mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into version
This commit is contained in:
@@ -384,6 +384,10 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
m.initChannel("wecom_app", "WeCom App")
|
||||
}
|
||||
|
||||
if channels.Weixin.Enabled && channels.Weixin.Token() != "" {
|
||||
m.initChannel("weixin", "Weixin")
|
||||
}
|
||||
|
||||
if channels.Pico.Enabled && channels.Pico.Token() != "" {
|
||||
m.initChannel("pico", "Pico")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
package qq
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const qqVoiceMaxDuration = 60 * time.Second
|
||||
|
||||
func qqAudioDuration(localPath, filename, contentType string) (time.Duration, bool, error) {
|
||||
if localPath == "" {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
switch qqAudioDurationFormat(localPath, filename, contentType) {
|
||||
case "wav":
|
||||
return qqWAVDuration(localPath)
|
||||
case "ogg":
|
||||
return qqOggDuration(localPath)
|
||||
default:
|
||||
return 0, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func qqAudioDurationFormat(localPath, filename, contentType string) string {
|
||||
contentType = strings.ToLower(contentType)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(contentType, "audio/wav"), strings.HasPrefix(contentType, "audio/x-wav"):
|
||||
return "wav"
|
||||
case strings.HasPrefix(contentType, "audio/ogg"),
|
||||
contentType == "application/ogg",
|
||||
contentType == "application/x-ogg":
|
||||
return "ogg"
|
||||
}
|
||||
|
||||
switch filepath.Ext(strings.ToLower(filename)) {
|
||||
case ".wav":
|
||||
return "wav"
|
||||
case ".ogg", ".opus":
|
||||
return "ogg"
|
||||
}
|
||||
|
||||
switch filepath.Ext(strings.ToLower(localPath)) {
|
||||
case ".wav":
|
||||
return "wav"
|
||||
case ".ogg", ".opus":
|
||||
return "ogg"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func qqWAVDuration(localPath string) (time.Duration, bool, error) {
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var header [12]byte
|
||||
if _, err := io.ReadFull(file, header[:]); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
var order binary.ByteOrder
|
||||
switch string(header[:4]) {
|
||||
case "RIFF":
|
||||
order = binary.LittleEndian
|
||||
case "RIFX":
|
||||
order = binary.BigEndian
|
||||
default:
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
if string(header[8:12]) != "WAVE" {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
var byteRate uint32
|
||||
var dataSize uint32
|
||||
var foundFmt bool
|
||||
var foundData bool
|
||||
|
||||
for {
|
||||
var chunkHeader [8]byte
|
||||
if _, err := io.ReadFull(file, chunkHeader[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
chunkSize := order.Uint32(chunkHeader[4:8])
|
||||
switch string(chunkHeader[:4]) {
|
||||
case "fmt ":
|
||||
chunkData := make([]byte, chunkSize)
|
||||
if _, err := io.ReadFull(file, chunkData); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
if len(chunkData) >= 12 {
|
||||
byteRate = order.Uint32(chunkData[8:12])
|
||||
foundFmt = true
|
||||
}
|
||||
case "data":
|
||||
dataSize = chunkSize
|
||||
foundData = true
|
||||
if _, err := io.CopyN(io.Discard, file, int64(chunkSize)); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
default:
|
||||
if _, err := io.CopyN(io.Discard, file, int64(chunkSize)); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
}
|
||||
|
||||
if chunkSize%2 == 1 {
|
||||
if _, err := io.CopyN(io.Discard, file, 1); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
}
|
||||
|
||||
if foundFmt && foundData {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFmt || !foundData || byteRate == 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
durationNS := int64(dataSize) * int64(time.Second) / int64(byteRate)
|
||||
return time.Duration(durationNS), true, nil
|
||||
}
|
||||
|
||||
func qqOggDuration(localPath string) (time.Duration, bool, error) {
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var firstPacket []byte
|
||||
var codec string
|
||||
var sampleRate uint32
|
||||
var lastGranule uint64
|
||||
var haveGranule bool
|
||||
|
||||
for {
|
||||
var header [27]byte
|
||||
if _, err := io.ReadFull(file, header[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
if string(header[:4]) != "OggS" {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
pageSegments := int(header[26])
|
||||
segments := make([]byte, pageSegments)
|
||||
if _, err := io.ReadFull(file, segments); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
payloadLen := 0
|
||||
for _, segLen := range segments {
|
||||
payloadLen += int(segLen)
|
||||
}
|
||||
|
||||
payload := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(file, payload); err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
granule := binary.LittleEndian.Uint64(header[6:14])
|
||||
if granule != ^uint64(0) {
|
||||
lastGranule = granule
|
||||
haveGranule = true
|
||||
}
|
||||
|
||||
if codec == "" {
|
||||
offset := 0
|
||||
for _, segLen := range segments {
|
||||
firstPacket = append(firstPacket, payload[offset:offset+int(segLen)]...)
|
||||
offset += int(segLen)
|
||||
if segLen < 255 {
|
||||
codec, sampleRate = qqParseOggCodec(firstPacket)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !haveGranule || codec == "" {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
switch codec {
|
||||
case "opus":
|
||||
return time.Duration(lastGranule) * time.Second / 48000, true, nil
|
||||
case "vorbis":
|
||||
if sampleRate == 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
return time.Duration(lastGranule) * time.Second / time.Duration(sampleRate), true, nil
|
||||
default:
|
||||
return 0, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func qqParseOggCodec(packet []byte) (string, uint32) {
|
||||
if len(packet) >= 8 && string(packet[:8]) == "OpusHead" {
|
||||
return "opus", 48000
|
||||
}
|
||||
|
||||
if len(packet) >= 16 && packet[0] == 0x01 && string(packet[1:7]) == "vorbis" {
|
||||
sampleRate := binary.LittleEndian.Uint32(packet[12:16])
|
||||
if sampleRate > 0 {
|
||||
return "vorbis", sampleRate
|
||||
}
|
||||
}
|
||||
|
||||
return "", 0
|
||||
}
|
||||
+53
-4
@@ -387,12 +387,11 @@ func (c *QQChannel) uploadMedia(
|
||||
}
|
||||
|
||||
func (c *QQChannel) buildMediaUpload(part bus.MediaPart) (*qqMediaUpload, error) {
|
||||
payload := &qqMediaUpload{
|
||||
FileType: qqFileType(part.Type),
|
||||
}
|
||||
payload := &qqMediaUpload{}
|
||||
|
||||
mediaRef := part.Ref
|
||||
if isHTTPURL(mediaRef) {
|
||||
payload.FileType = qqFileType(c.outboundMediaType(part, ""))
|
||||
payload.URL = mediaRef
|
||||
return payload, nil
|
||||
}
|
||||
@@ -402,15 +401,23 @@ func (c *QQChannel) buildMediaUpload(part bus.MediaPart) (*qqMediaUpload, error)
|
||||
return nil, fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
resolved, err := store.Resolve(part.Ref)
|
||||
resolved, meta, err := store.ResolveWithMeta(part.Ref)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("qq resolve media ref %q: %v: %w", part.Ref, err, channels.ErrSendFailed)
|
||||
}
|
||||
if part.Filename == "" {
|
||||
part.Filename = meta.Filename
|
||||
}
|
||||
if part.ContentType == "" {
|
||||
part.ContentType = meta.ContentType
|
||||
}
|
||||
|
||||
if isHTTPURL(resolved) {
|
||||
payload.FileType = qqFileType(c.outboundMediaType(part, ""))
|
||||
payload.URL = resolved
|
||||
return payload, nil
|
||||
}
|
||||
payload.FileType = qqFileType(c.outboundMediaType(part, resolved))
|
||||
|
||||
if limitBytes := c.maxBase64FileSizeBytes(); limitBytes > 0 {
|
||||
info, statErr := os.Stat(resolved)
|
||||
@@ -437,6 +444,48 @@ func (c *QQChannel) buildMediaUpload(part bus.MediaPart) (*qqMediaUpload, error)
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (c *QQChannel) outboundMediaType(part bus.MediaPart, localPath string) string {
|
||||
if part.Type != "audio" {
|
||||
return part.Type
|
||||
}
|
||||
|
||||
if localPath == "" {
|
||||
logger.InfoCF("qq", "Sending audio as file because duration is unavailable", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"filename": part.Filename,
|
||||
})
|
||||
return "file"
|
||||
}
|
||||
|
||||
duration, ok, err := qqAudioDuration(localPath, part.Filename, part.ContentType)
|
||||
if err != nil {
|
||||
logger.WarnCF("qq", "Failed to detect audio duration, sending as file", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"filename": part.Filename,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return "file"
|
||||
}
|
||||
if !ok {
|
||||
logger.InfoCF("qq", "Sending audio as file because duration is unavailable", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"filename": part.Filename,
|
||||
})
|
||||
return "file"
|
||||
}
|
||||
if duration > qqVoiceMaxDuration {
|
||||
logger.InfoCF("qq", "Sending audio as file because it exceeds QQ voice limit", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"filename": part.Filename,
|
||||
"duration_seconds": duration.Seconds(),
|
||||
"limit_seconds": qqVoiceMaxDuration.Seconds(),
|
||||
})
|
||||
return "file"
|
||||
}
|
||||
|
||||
return "audio"
|
||||
}
|
||||
|
||||
func (c *QQChannel) sendUploadedMedia(
|
||||
ctx context.Context,
|
||||
chatKind, chatID string,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package qq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
@@ -264,6 +266,142 @@ func TestSendMedia_UploadsLocalFileAsBase64(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_AudioAt60SecondsUsesVoiceUpload(t *testing.T) {
|
||||
assertAudioWAVUploadType(t, 60*time.Second, 3)
|
||||
}
|
||||
|
||||
func TestSendMedia_AudioOver60SecondsFallsBackToFileUpload(t *testing.T) {
|
||||
assertAudioWAVUploadType(t, 61*time.Second, 4)
|
||||
}
|
||||
|
||||
func assertAudioWAVUploadType(t *testing.T, duration time.Duration, wantFileType uint64) {
|
||||
t.Helper()
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
store := media.NewFileMediaStore()
|
||||
|
||||
localPath := writeWAVFile(t, t.TempDir(), "voice.wav", duration)
|
||||
ref, err := store.Store(localPath, media.MediaMeta{
|
||||
Filename: "voice.wav",
|
||||
ContentType: "audio/wav",
|
||||
}, "qq:test")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
api := &fakeQQAPI{
|
||||
transportResp: mustJSON(t, dto.Message{FileInfo: []byte("file-info")}),
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.SetMediaStore(store)
|
||||
ch.chatType.Store("group-1", "group")
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "group-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "audio",
|
||||
Ref: ref,
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(api.transportCalls) != 1 {
|
||||
t.Fatalf("transportCalls = %d, want 1", len(api.transportCalls))
|
||||
}
|
||||
if api.transportCalls[0].body.FileType != wantFileType {
|
||||
t.Fatalf("upload file_type = %d, want %d", api.transportCalls[0].body.FileType, wantFileType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_RemoteAudioFallsBackToFileUpload(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
api := &fakeQQAPI{
|
||||
transportResp: mustJSON(t, dto.Message{FileInfo: []byte("remote-file-info")}),
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.chatType.Store("user-1", "direct")
|
||||
|
||||
err := ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "user-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "audio",
|
||||
Ref: "https://cdn.example.com/voice.ogg",
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(api.transportCalls) != 1 {
|
||||
t.Fatalf("transportCalls = %d, want 1", len(api.transportCalls))
|
||||
}
|
||||
if api.transportCalls[0].body.FileType != 4 {
|
||||
t.Fatalf("upload file_type = %d, want 4", api.transportCalls[0].body.FileType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_LocalAudioWithUnknownDurationFallsBackToFileUpload(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
store := media.NewFileMediaStore()
|
||||
|
||||
localPath := writeTempFile(t, t.TempDir(), "voice.mp3", []byte("not-a-real-mp3"))
|
||||
ref, err := store.Store(localPath, media.MediaMeta{
|
||||
Filename: "voice.mp3",
|
||||
ContentType: "audio/mpeg",
|
||||
}, "qq:test")
|
||||
if err != nil {
|
||||
t.Fatalf("Store() error = %v", err)
|
||||
}
|
||||
|
||||
api := &fakeQQAPI{
|
||||
transportResp: mustJSON(t, dto.Message{FileInfo: []byte("file-info")}),
|
||||
}
|
||||
ch := &QQChannel{
|
||||
BaseChannel: channels.NewBaseChannel("qq", nil, messageBus, nil),
|
||||
api: api,
|
||||
dedup: make(map[string]time.Time),
|
||||
done: make(chan struct{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
ch.SetRunning(true)
|
||||
ch.SetMediaStore(store)
|
||||
ch.chatType.Store("group-1", "group")
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "group-1",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "audio",
|
||||
Ref: ref,
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
|
||||
if len(api.transportCalls) != 1 {
|
||||
t.Fatalf("transportCalls = %d, want 1", len(api.transportCalls))
|
||||
}
|
||||
if api.transportCalls[0].body.FileType != 4 {
|
||||
t.Fatalf("upload file_type = %d, want 4", api.transportCalls[0].body.FileType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_UsesRemoteURLUploadForC2C(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
api := &fakeQQAPI{
|
||||
@@ -494,3 +632,53 @@ func writeTempFile(t *testing.T, dir, name string, content []byte) string {
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func writeWAVFile(t *testing.T, dir, name string, duration time.Duration) string {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
sampleRate = 8000
|
||||
numChannels = 1
|
||||
bitsPerSample = 8
|
||||
)
|
||||
|
||||
dataSize := uint32(duration / time.Second * sampleRate * numChannels * (bitsPerSample / 8))
|
||||
byteRate := uint32(sampleRate * numChannels * (bitsPerSample / 8))
|
||||
blockAlign := uint16(numChannels * (bitsPerSample / 8))
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("RIFF")
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint32(36)+dataSize); err != nil {
|
||||
t.Fatalf("binary.Write(riff size) error = %v", err)
|
||||
}
|
||||
buf.WriteString("WAVE")
|
||||
buf.WriteString("fmt ")
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint32(16)); err != nil {
|
||||
t.Fatalf("binary.Write(fmt chunk size) error = %v", err)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint16(1)); err != nil {
|
||||
t.Fatalf("binary.Write(audio format) error = %v", err)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint16(numChannels)); err != nil {
|
||||
t.Fatalf("binary.Write(channels) error = %v", err)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint32(sampleRate)); err != nil {
|
||||
t.Fatalf("binary.Write(sample rate) error = %v", err)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, byteRate); err != nil {
|
||||
t.Fatalf("binary.Write(byte rate) error = %v", err)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, blockAlign); err != nil {
|
||||
t.Fatalf("binary.Write(block align) error = %v", err)
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint16(bitsPerSample)); err != nil {
|
||||
t.Fatalf("binary.Write(bits per sample) error = %v", err)
|
||||
}
|
||||
buf.WriteString("data")
|
||||
if err := binary.Write(&buf, binary.LittleEndian, dataSize); err != nil {
|
||||
t.Fatalf("binary.Write(data size) error = %v", err)
|
||||
}
|
||||
buf.Write(make([]byte, dataSize))
|
||||
|
||||
return writeTempFile(t, dir, name, buf.Bytes())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
package weixin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
)
|
||||
|
||||
type ApiClient struct {
|
||||
BaseURL string
|
||||
Token string
|
||||
HttpClient *http.Client
|
||||
}
|
||||
|
||||
func NewApiClient(baseURL, token string, proxy string) (*ApiClient, error) {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://ilinkai.weixin.qq.com/"
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
// Default timeout; will be overridden per context
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
proxyURL, err := url.Parse(proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", proxy, err)
|
||||
}
|
||||
|
||||
// Clone the default transport so we preserve all default settings (TLS, HTTP/2, timeouts, keep-alives)
|
||||
if defaultTransport, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
transport := defaultTransport.Clone()
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
client.Transport = transport
|
||||
} else {
|
||||
// Fallback: preserve previous behavior if DefaultTransport is not the expected type
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ApiClient{
|
||||
BaseURL: baseURL,
|
||||
Token: token,
|
||||
HttpClient: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func randomWechatUIN() string {
|
||||
var b [4]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
uint32Val := binary.BigEndian.Uint32(b[:])
|
||||
return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%d", uint32Val)))
|
||||
}
|
||||
|
||||
func (c *ApiClient) post(ctx context.Context, endpoint string, body any, responseObj any) error {
|
||||
u, err := url.Parse(c.BaseURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Path = path.Join(u.Path, endpoint)
|
||||
|
||||
jsonData, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if endpoint == "ilink/bot/get_bot_qrcode" || endpoint == "ilink/bot/get_qrcode_status" {
|
||||
// QR routes have different headers sometimes, but let's stick to base ones
|
||||
if endpoint == "ilink/bot/get_qrcode_status" {
|
||||
// Use direct map assignment to send exact header name the Tencent API expects
|
||||
req.Header["iLink-App-ClientVersion"] = []string{"1"}
|
||||
}
|
||||
} else {
|
||||
req.Header["AuthorizationType"] = []string{"ilink_bot_token"}
|
||||
req.Header["X-WECHAT-UIN"] = []string{randomWechatUIN()}
|
||||
if c.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.Token)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http POST %s failed: %w", endpoint, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("http %d %s: %s", resp.StatusCode, resp.Status, string(respBody))
|
||||
}
|
||||
|
||||
if responseObj != nil {
|
||||
if err := json.Unmarshal(respBody, responseObj); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal response: %w, body: %s", err, string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetUpdates(ctx context.Context, req GetUpdatesReq) (*GetUpdatesResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
var resp GetUpdatesResp
|
||||
err := c.post(ctx, "ilink/bot/getupdates", req, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) SendMessage(ctx context.Context, req SendMessageReq) (*SendMessageResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
var resp SendMessageResp
|
||||
if err := c.post(ctx, "ilink/bot/sendmessage", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetUploadUrl(ctx context.Context, req GetUploadUrlReq) (*GetUploadUrlResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
var resp GetUploadUrlResp
|
||||
err := c.post(ctx, "ilink/bot/getuploadurl", req, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetConfig(ctx context.Context, req GetConfigReq) (*GetConfigResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
var resp GetConfigResp
|
||||
if err := c.post(ctx, "ilink/bot/getconfig", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) SendTyping(ctx context.Context, req SendTypingReq) (*SendTypingResp, error) {
|
||||
req.BaseInfo = BaseInfo{ChannelVersion: "1.0.2"}
|
||||
var resp SendTypingResp
|
||||
if err := c.post(ctx, "ilink/bot/sendtyping", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetQRCode(ctx context.Context, botType string) (*QRCodeResponse, error) {
|
||||
// get_bot_qrcode is GET, not POST
|
||||
u, err := url.Parse(c.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "ilink/bot/get_bot_qrcode")
|
||||
q := u.Query()
|
||||
q.Set("bot_type", botType)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("get_bot_qrcode failed: %d %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var qrcodeResp QRCodeResponse
|
||||
if err := json.Unmarshal(respBody, &qrcodeResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &qrcodeResp, nil
|
||||
}
|
||||
|
||||
func (c *ApiClient) GetQRCodeStatus(ctx context.Context, qrcode string) (*StatusResponse, error) {
|
||||
// get_qrcode_status is GET
|
||||
u, err := url.Parse(c.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Path = path.Join(u.Path, "ilink/bot/get_qrcode_status")
|
||||
q := u.Query()
|
||||
q.Set("qrcode", qrcode)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header["iLink-App-ClientVersion"] = []string{"1"}
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("get_qrcode_status failed: %d %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var statusResp StatusResponse
|
||||
if err := json.Unmarshal(respBody, &statusResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &statusResp, nil
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package weixin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/mdp/qrterminal/v3"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// AuthFlowOpts configures the interactive QR login flow.
|
||||
type AuthFlowOpts struct {
|
||||
BaseURL string
|
||||
BotType string
|
||||
Timeout time.Duration
|
||||
Proxy string
|
||||
}
|
||||
|
||||
// PerformLoginInteractive starts the Weixin QR login flow and blocks until login is successful or times out.
|
||||
// It prints a QR code to the terminal for the user to scan.
|
||||
// Returns the BotToken, UserID, AccountID, and BaseUrl on success.
|
||||
func PerformLoginInteractive(
|
||||
ctx context.Context,
|
||||
opts AuthFlowOpts,
|
||||
) (botToken, userID, accountID, baseUrl string, err error) {
|
||||
if opts.BaseURL == "" {
|
||||
opts.BaseURL = "https://ilinkai.weixin.qq.com/"
|
||||
}
|
||||
if opts.BotType == "" {
|
||||
opts.BotType = "3" // Default iLink Bot Type
|
||||
}
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
api, err := NewApiClient(opts.BaseURL, "", opts.Proxy)
|
||||
if err != nil {
|
||||
return "", "", "", "", fmt.Errorf("failed to create api client: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoC("weixin", "Requesting Weixin QR code...")
|
||||
qrResp, err := api.GetQRCode(ctx, opts.BotType)
|
||||
if err != nil {
|
||||
return "", "", "", "", fmt.Errorf("failed to get qrcode: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n=======================================================")
|
||||
fmt.Println("Please scan the following QR code with WeChat to login:")
|
||||
fmt.Println("=======================================================")
|
||||
fmt.Println()
|
||||
|
||||
// Create Small QR
|
||||
qrconfig := qrterminal.Config{
|
||||
Level: qrterminal.L,
|
||||
Writer: os.Stdout,
|
||||
HalfBlocks: true,
|
||||
}
|
||||
qrterminal.GenerateWithConfig(qrResp.QrcodeImgContent, qrconfig)
|
||||
|
||||
fmt.Printf("\nQR Code Link: %s\n\n", qrResp.QrcodeImgContent)
|
||||
fmt.Println("Waiting for scan...")
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, opts.Timeout)
|
||||
defer cancel()
|
||||
|
||||
pollTicker := time.NewTicker(2 * time.Second)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
scannedPrinted := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return "", "", "", "", fmt.Errorf("login timeout")
|
||||
case <-pollTicker.C:
|
||||
statusResp, err := api.GetQRCodeStatus(timeoutCtx, qrResp.Qrcode)
|
||||
if err != nil {
|
||||
// Long poll timeout or temporary error
|
||||
continue
|
||||
}
|
||||
|
||||
switch statusResp.Status {
|
||||
case "wait":
|
||||
// still waiting
|
||||
case "scaned":
|
||||
if !scannedPrinted {
|
||||
fmt.Println("👀 QR Code scanned! Please confirm login on your WeChat app...")
|
||||
scannedPrinted = true
|
||||
}
|
||||
case "confirmed":
|
||||
if statusResp.BotToken == "" || statusResp.IlinkBotID == "" {
|
||||
return "", "", "", "", fmt.Errorf("login confirmed but missing bot_token or ilink_bot_id")
|
||||
}
|
||||
logger.InfoCF("weixin", "Login successful", map[string]any{
|
||||
"account_id": statusResp.IlinkBotID,
|
||||
})
|
||||
|
||||
return statusResp.BotToken, statusResp.IlinkUserID, statusResp.IlinkBotID, statusResp.Baseurl, nil
|
||||
case "expired":
|
||||
return "", "", "", "", fmt.Errorf("qrcode expired, please try again")
|
||||
default:
|
||||
logger.WarnCF("weixin", "Unknown QR code status", map[string]any{
|
||||
"status": statusResp.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,226 @@
|
||||
package weixin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
basechannels "github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
weixinDefaultCDNBaseURL = "https://novac2c.cdn.weixin.qq.com/c2c"
|
||||
weixinConfigCacheTTL = 24 * time.Hour
|
||||
weixinConfigRetryInitial = 2 * time.Second
|
||||
weixinConfigRetryMax = time.Hour
|
||||
weixinSessionPauseDuration = time.Hour
|
||||
weixinSessionExpiredCode = -14
|
||||
)
|
||||
|
||||
type typingTicketCacheEntry struct {
|
||||
ticket string
|
||||
nextFetchAt time.Time
|
||||
retryDelay time.Duration
|
||||
}
|
||||
|
||||
type syncCursorFile struct {
|
||||
GetUpdatesBuf string `json:"get_updates_buf"`
|
||||
}
|
||||
|
||||
func picoclawHomeDir() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
userHome, _ := os.UserHomeDir()
|
||||
return filepath.Join(userHome, ".picoclaw")
|
||||
}
|
||||
|
||||
func buildWeixinSyncBufPath(cfg config.WeixinConfig) string {
|
||||
key := "default"
|
||||
token := strings.TrimSpace(cfg.Token())
|
||||
if token != "" {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(cfg.BaseURL) + "|" + token))
|
||||
key = hex.EncodeToString(sum[:8])
|
||||
}
|
||||
return filepath.Join(picoclawHomeDir(), "channels", "weixin", "sync", key+".json")
|
||||
}
|
||||
|
||||
func loadGetUpdatesBuf(path string) (string, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
var decoded syncCursorFile
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return decoded.GetUpdatesBuf, nil
|
||||
}
|
||||
|
||||
func saveGetUpdatesBuf(path, cursor string) error {
|
||||
data, err := json.Marshal(syncCursorFile{GetUpdatesBuf: cursor})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) cdnBaseURL() string {
|
||||
if base := strings.TrimSpace(c.config.CDNBaseURL); base != "" {
|
||||
return strings.TrimRight(base, "/")
|
||||
}
|
||||
return weixinDefaultCDNBaseURL
|
||||
}
|
||||
|
||||
func isSessionExpiredStatus(ret, errcode int) bool {
|
||||
return ret == weixinSessionExpiredCode || errcode == weixinSessionExpiredCode
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) pauseSession(operation string, ret, errcode int, errmsg string) time.Duration {
|
||||
c.pauseMu.Lock()
|
||||
defer c.pauseMu.Unlock()
|
||||
|
||||
until := time.Now().Add(weixinSessionPauseDuration)
|
||||
if until.After(c.pauseUntil) {
|
||||
c.pauseUntil = until
|
||||
}
|
||||
|
||||
remaining := time.Until(c.pauseUntil)
|
||||
logger.ErrorCF("weixin", "Session expired; pausing Weixin channel", map[string]any{
|
||||
"operation": operation,
|
||||
"ret": ret,
|
||||
"errcode": errcode,
|
||||
"errmsg": errmsg,
|
||||
"until": c.pauseUntil.Format(time.RFC3339),
|
||||
"minutes": int((remaining + time.Minute - 1) / time.Minute),
|
||||
})
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) remainingPause() time.Duration {
|
||||
c.pauseMu.Lock()
|
||||
defer c.pauseMu.Unlock()
|
||||
|
||||
if c.pauseUntil.IsZero() {
|
||||
return 0
|
||||
}
|
||||
remaining := time.Until(c.pauseUntil)
|
||||
if remaining <= 0 {
|
||||
c.pauseUntil = time.Time{}
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) waitWhileSessionPaused(ctx context.Context) error {
|
||||
remaining := c.remainingPause()
|
||||
if remaining <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
timer := time.NewTimer(remaining)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) ensureSessionActive() error {
|
||||
remaining := c.remainingPause()
|
||||
if remaining <= 0 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(
|
||||
"weixin session paused (%d min remaining): %w",
|
||||
int((remaining+time.Minute-1)/time.Minute),
|
||||
basechannels.ErrSendFailed,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) getTypingTicket(ctx context.Context, userID string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
c.typingMu.Lock()
|
||||
entry, ok := c.typingCache[userID]
|
||||
if ok && now.Before(entry.nextFetchAt) {
|
||||
ticket := entry.ticket
|
||||
c.typingMu.Unlock()
|
||||
return ticket, nil
|
||||
}
|
||||
cachedTicket := entry.ticket
|
||||
retryDelay := entry.retryDelay
|
||||
c.typingMu.Unlock()
|
||||
|
||||
contextToken := ""
|
||||
if v, ok := c.contextTokens.Load(userID); ok {
|
||||
contextToken, _ = v.(string)
|
||||
}
|
||||
|
||||
resp, err := c.api.GetConfig(ctx, GetConfigReq{
|
||||
IlinkUserID: userID,
|
||||
ContextToken: contextToken,
|
||||
})
|
||||
if err == nil && resp != nil && resp.Ret == 0 && resp.Errcode == 0 {
|
||||
ticket := strings.TrimSpace(resp.TypingTicket)
|
||||
c.typingMu.Lock()
|
||||
c.typingCache[userID] = typingTicketCacheEntry{
|
||||
ticket: ticket,
|
||||
nextFetchAt: now.Add(weixinConfigCacheTTL),
|
||||
retryDelay: weixinConfigRetryInitial,
|
||||
}
|
||||
c.typingMu.Unlock()
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
if resp != nil && isSessionExpiredStatus(resp.Ret, resp.Errcode) {
|
||||
c.pauseSession("getconfig", resp.Ret, resp.Errcode, resp.Errmsg)
|
||||
}
|
||||
|
||||
if retryDelay <= 0 {
|
||||
retryDelay = weixinConfigRetryInitial
|
||||
} else {
|
||||
retryDelay *= 2
|
||||
if retryDelay > weixinConfigRetryMax {
|
||||
retryDelay = weixinConfigRetryMax
|
||||
}
|
||||
}
|
||||
|
||||
c.typingMu.Lock()
|
||||
c.typingCache[userID] = typingTicketCacheEntry{
|
||||
ticket: cachedTicket,
|
||||
nextFetchAt: now.Add(retryDelay),
|
||||
retryDelay: retryDelay,
|
||||
}
|
||||
c.typingMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return cachedTicket, err
|
||||
}
|
||||
if resp == nil {
|
||||
return cachedTicket, fmt.Errorf("getconfig returned nil response")
|
||||
}
|
||||
return cachedTicket, fmt.Errorf(
|
||||
"getconfig failed: ret=%d errcode=%d errmsg=%s",
|
||||
resp.Ret,
|
||||
resp.Errcode,
|
||||
resp.Errmsg,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
package weixin
|
||||
|
||||
// BaseInfo is attached to every outgoing CGI request
|
||||
type BaseInfo struct {
|
||||
ChannelVersion string `json:"channel_version,omitempty"`
|
||||
}
|
||||
|
||||
type APIStatus struct {
|
||||
Ret int `json:"ret,omitempty"`
|
||||
Errcode int `json:"errcode,omitempty"`
|
||||
Errmsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
// UploadMediaType constants
|
||||
const (
|
||||
UploadMediaTypeImage = 1
|
||||
UploadMediaTypeVideo = 2
|
||||
UploadMediaTypeFile = 3
|
||||
UploadMediaTypeVoice = 4
|
||||
)
|
||||
|
||||
type GetUploadUrlReq struct {
|
||||
Filekey string `json:"filekey,omitempty"`
|
||||
MediaType int `json:"media_type,omitempty"`
|
||||
ToUserID string `json:"to_user_id,omitempty"`
|
||||
Rawsize int64 `json:"rawsize,omitempty"`
|
||||
RawfileMD5 string `json:"rawfilemd5,omitempty"`
|
||||
Filesize int64 `json:"filesize,omitempty"`
|
||||
ThumbRawsize int64 `json:"thumb_rawsize,omitempty"`
|
||||
ThumbRawfileMD5 string `json:"thumb_rawfilemd5,omitempty"`
|
||||
ThumbFilesize int64 `json:"thumb_filesize,omitempty"`
|
||||
NoNeedThumb bool `json:"no_need_thumb,omitempty"`
|
||||
Aeskey string `json:"aeskey,omitempty"` // hex-encoded 16-byte AES key
|
||||
BaseInfo BaseInfo `json:"base_info,omitempty"`
|
||||
}
|
||||
|
||||
type GetUploadUrlResp struct {
|
||||
APIStatus
|
||||
UploadParam string `json:"upload_param,omitempty"`
|
||||
ThumbUploadParam string `json:"thumb_upload_param,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
MessageTypeNone = 0
|
||||
MessageTypeUser = 1
|
||||
MessageTypeBot = 2
|
||||
)
|
||||
|
||||
const (
|
||||
MessageItemTypeNone = 0
|
||||
MessageItemTypeText = 1
|
||||
MessageItemTypeImage = 2
|
||||
MessageItemTypeVoice = 3
|
||||
MessageItemTypeFile = 4
|
||||
MessageItemTypeVideo = 5
|
||||
)
|
||||
|
||||
const (
|
||||
MessageStateNew = 0
|
||||
MessageStateGenerating = 1
|
||||
MessageStateFinish = 2
|
||||
)
|
||||
|
||||
type TextItem struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type CDNMedia struct {
|
||||
EncryptQueryParam string `json:"encrypt_query_param,omitempty"`
|
||||
AesKey string `json:"aes_key,omitempty"` // base64 encoded
|
||||
EncryptType int `json:"encrypt_type,omitempty"`
|
||||
}
|
||||
|
||||
type ImageItem struct {
|
||||
Media *CDNMedia `json:"media,omitempty"`
|
||||
ThumbMedia *CDNMedia `json:"thumb_media,omitempty"`
|
||||
Aeskey string `json:"aeskey,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
MidSize int64 `json:"mid_size,omitempty"`
|
||||
ThumbSize int64 `json:"thumb_size,omitempty"`
|
||||
ThumbHeight int `json:"thumb_height,omitempty"`
|
||||
ThumbWidth int `json:"thumb_width,omitempty"`
|
||||
HDSize int64 `json:"hd_size,omitempty"`
|
||||
}
|
||||
|
||||
type VoiceItem struct {
|
||||
Media *CDNMedia `json:"media,omitempty"`
|
||||
EncodeType int `json:"encode_type,omitempty"`
|
||||
BitsPerSample int `json:"bits_per_sample,omitempty"`
|
||||
SampleRate int `json:"sample_rate,omitempty"`
|
||||
Playtime int `json:"playtime,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type FileItem struct {
|
||||
Media *CDNMedia `json:"media,omitempty"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
MD5 string `json:"md5,omitempty"`
|
||||
Len string `json:"len,omitempty"`
|
||||
}
|
||||
|
||||
type VideoItem struct {
|
||||
Media *CDNMedia `json:"media,omitempty"`
|
||||
VideoSize int64 `json:"video_size,omitempty"`
|
||||
PlayLength int `json:"play_length,omitempty"`
|
||||
VideoMD5 string `json:"video_md5,omitempty"`
|
||||
ThumbMedia *CDNMedia `json:"thumb_media,omitempty"`
|
||||
ThumbSize int64 `json:"thumb_size,omitempty"`
|
||||
ThumbHeight int `json:"thumb_height,omitempty"`
|
||||
ThumbWidth int `json:"thumb_width,omitempty"`
|
||||
}
|
||||
|
||||
type RefMessage struct {
|
||||
MessageItem *MessageItem `json:"message_item,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type MessageItem struct {
|
||||
Type int `json:"type,omitempty"`
|
||||
CreateTimeMs int64 `json:"create_time_ms,omitempty"`
|
||||
UpdateTimeMs int64 `json:"update_time_ms,omitempty"`
|
||||
IsCompleted bool `json:"is_completed,omitempty"`
|
||||
MsgID string `json:"msg_id,omitempty"`
|
||||
RefMsg *RefMessage `json:"ref_msg,omitempty"`
|
||||
TextItem *TextItem `json:"text_item,omitempty"`
|
||||
ImageItem *ImageItem `json:"image_item,omitempty"`
|
||||
VoiceItem *VoiceItem `json:"voice_item,omitempty"`
|
||||
FileItem *FileItem `json:"file_item,omitempty"`
|
||||
VideoItem *VideoItem `json:"video_item,omitempty"`
|
||||
}
|
||||
|
||||
type WeixinMessage struct {
|
||||
Seq int `json:"seq,omitempty"`
|
||||
MessageID int64 `json:"message_id,omitempty"`
|
||||
FromUserID string `json:"from_user_id,omitempty"`
|
||||
ToUserID string `json:"to_user_id,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
CreateTimeMs int64 `json:"create_time_ms,omitempty"`
|
||||
UpdateTimeMs int64 `json:"update_time_ms,omitempty"`
|
||||
DeleteTimeMs int64 `json:"delete_time_ms,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
GroupID string `json:"group_id,omitempty"`
|
||||
MessageType int `json:"message_type,omitempty"`
|
||||
MessageState int `json:"message_state,omitempty"`
|
||||
ItemList []MessageItem `json:"item_list,omitempty"`
|
||||
ContextToken string `json:"context_token,omitempty"`
|
||||
}
|
||||
|
||||
type GetUpdatesReq struct {
|
||||
SyncBuf string `json:"sync_buf,omitempty"`
|
||||
GetUpdatesBuf string `json:"get_updates_buf,omitempty"`
|
||||
BaseInfo BaseInfo `json:"base_info,omitempty"`
|
||||
}
|
||||
|
||||
type GetUpdatesResp struct {
|
||||
APIStatus
|
||||
Msgs []WeixinMessage `json:"msgs,omitempty"`
|
||||
SyncBuf string `json:"sync_buf,omitempty"`
|
||||
GetUpdatesBuf string `json:"get_updates_buf,omitempty"`
|
||||
LongpollingTimeoutMs int `json:"longpolling_timeout_ms,omitempty"`
|
||||
}
|
||||
|
||||
type SendMessageReq struct {
|
||||
Msg WeixinMessage `json:"msg,omitempty"`
|
||||
BaseInfo BaseInfo `json:"base_info,omitempty"`
|
||||
}
|
||||
|
||||
type SendMessageResp struct {
|
||||
APIStatus
|
||||
}
|
||||
|
||||
type GetConfigReq struct {
|
||||
IlinkUserID string `json:"ilink_user_id,omitempty"`
|
||||
ContextToken string `json:"context_token,omitempty"`
|
||||
BaseInfo BaseInfo `json:"base_info,omitempty"`
|
||||
}
|
||||
|
||||
type GetConfigResp struct {
|
||||
APIStatus
|
||||
TypingTicket string `json:"typing_ticket,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
TypingStatusTyping = 1
|
||||
TypingStatusCancel = 2
|
||||
)
|
||||
|
||||
type SendTypingReq struct {
|
||||
IlinkUserID string `json:"ilink_user_id,omitempty"`
|
||||
TypingTicket string `json:"typing_ticket,omitempty"`
|
||||
Status int `json:"status,omitempty"` // 1=typing, 2=cancel
|
||||
BaseInfo BaseInfo `json:"base_info,omitempty"`
|
||||
}
|
||||
|
||||
type SendTypingResp struct {
|
||||
APIStatus
|
||||
}
|
||||
|
||||
type QRCodeResponse struct {
|
||||
Qrcode string `json:"qrcode"`
|
||||
QrcodeImgContent string `json:"qrcode_img_content"`
|
||||
}
|
||||
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"` // "wait", "scaned", "confirmed", "expired"
|
||||
BotToken string `json:"bot_token,omitempty"`
|
||||
IlinkBotID string `json:"ilink_bot_id,omitempty"`
|
||||
Baseurl string `json:"baseurl,omitempty"`
|
||||
IlinkUserID string `json:"ilink_user_id,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,359 @@
|
||||
package weixin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// WeixinChannel is the Weixin channel implementation over Tencent iLink REST API.
|
||||
type WeixinChannel struct {
|
||||
*channels.BaseChannel
|
||||
api *ApiClient
|
||||
config config.WeixinConfig
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
bus *bus.MessageBus
|
||||
// contextTokens stores the last context_token per user (from_user_id → context_token).
|
||||
// This is required by the iLink API to associate replies with the right chat session.
|
||||
contextTokens sync.Map
|
||||
typingMu sync.Mutex
|
||||
typingCache map[string]typingTicketCacheEntry
|
||||
pauseMu sync.Mutex
|
||||
pauseUntil time.Time
|
||||
syncBufPath string
|
||||
}
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("weixin", func(cfg *config.Config, bus *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewWeixinChannel(cfg.Channels.Weixin, bus)
|
||||
})
|
||||
}
|
||||
|
||||
// NewWeixinChannel creates a new WeixinChannel from config.
|
||||
func NewWeixinChannel(cfg config.WeixinConfig, messageBus *bus.MessageBus) (*WeixinChannel, error) {
|
||||
api, err := NewApiClient(cfg.BaseURL, cfg.Token(), cfg.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("weixin: failed to create API client: %w", err)
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"weixin",
|
||||
cfg,
|
||||
messageBus,
|
||||
cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(4000),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &WeixinChannel{
|
||||
BaseChannel: base,
|
||||
api: api,
|
||||
config: cfg,
|
||||
bus: messageBus,
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
syncBufPath: buildWeixinSyncBufPath(cfg),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("weixin", "Starting Weixin channel")
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
c.SetRunning(true)
|
||||
go c.pollLoop(c.ctx)
|
||||
logger.InfoC("weixin", "Weixin channel started")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WeixinChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("weixin", "Stopping Weixin channel")
|
||||
c.SetRunning(false)
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pollLoop is the long-poll receive loop. It runs until ctx is canceled.
|
||||
func (c *WeixinChannel) pollLoop(ctx context.Context) {
|
||||
const (
|
||||
defaultPollTimeoutMs = 35_000
|
||||
retryDelay = 2 * time.Second
|
||||
backoffDelay = 30 * time.Second
|
||||
maxConsecutiveFails = 3
|
||||
)
|
||||
|
||||
consecutiveFails := 0
|
||||
getUpdatesBuf, err := loadGetUpdatesBuf(c.syncBufPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("weixin", "Failed to load persisted get_updates_buf", map[string]any{
|
||||
"path": c.syncBufPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
getUpdatesBuf = ""
|
||||
} else if getUpdatesBuf != "" {
|
||||
logger.InfoCF("weixin", "Resuming persisted get_updates_buf", map[string]any{
|
||||
"path": c.syncBufPath,
|
||||
"bytes": len(getUpdatesBuf),
|
||||
"source": "disk",
|
||||
})
|
||||
}
|
||||
nextTimeoutMs := defaultPollTimeoutMs
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.InfoC("weixin", "Weixin poll loop stopped")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err := c.waitWhileSessionPaused(ctx); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Build a context with timeout slightly longer than the long-poll
|
||||
pollCtx, pollCancel := context.WithTimeout(ctx, time.Duration(nextTimeoutMs+5000)*time.Millisecond)
|
||||
|
||||
resp, err := c.api.GetUpdates(pollCtx, GetUpdatesReq{
|
||||
GetUpdatesBuf: getUpdatesBuf,
|
||||
})
|
||||
pollCancel()
|
||||
|
||||
if err != nil {
|
||||
// Check if we're shutting down
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
consecutiveFails++
|
||||
logger.WarnCF("weixin", "getUpdates failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
"attempt": consecutiveFails,
|
||||
})
|
||||
|
||||
if consecutiveFails >= maxConsecutiveFails {
|
||||
logger.ErrorCF("weixin", "Too many consecutive failures, backing off", map[string]any{
|
||||
"duration": backoffDelay,
|
||||
})
|
||||
consecutiveFails = 0
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(backoffDelay):
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(retryDelay):
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isSessionExpiredStatus(resp.Ret, resp.Errcode) {
|
||||
remaining := c.pauseSession("getupdates", resp.Ret, resp.Errcode, resp.Errmsg)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(remaining):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.Errcode != 0 || resp.Ret != 0 {
|
||||
consecutiveFails++
|
||||
logger.ErrorCF("weixin", "getUpdates API error", map[string]any{
|
||||
"ret": resp.Ret,
|
||||
"errcode": resp.Errcode,
|
||||
"errmsg": resp.Errmsg,
|
||||
})
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(retryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
consecutiveFails = 0
|
||||
|
||||
// Update the long-poll timeout from server hint
|
||||
if resp.LongpollingTimeoutMs > 0 {
|
||||
nextTimeoutMs = resp.LongpollingTimeoutMs
|
||||
}
|
||||
|
||||
// Advance cursor
|
||||
if resp.GetUpdatesBuf != "" {
|
||||
getUpdatesBuf = resp.GetUpdatesBuf
|
||||
if err := saveGetUpdatesBuf(c.syncBufPath, getUpdatesBuf); err != nil {
|
||||
logger.WarnCF("weixin", "Failed to persist get_updates_buf", map[string]any{
|
||||
"path": c.syncBufPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch messages
|
||||
for _, msg := range resp.Msgs {
|
||||
c.handleInboundMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleInboundMessage converts a WeixinMessage to a bus.InboundMessage.
|
||||
func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMessage) {
|
||||
fromUserID := msg.FromUserID
|
||||
if fromUserID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
messageID := msg.ClientID
|
||||
if messageID == "" {
|
||||
messageID = uuid.New().String()
|
||||
}
|
||||
|
||||
// Build text content from item_list
|
||||
var parts []string
|
||||
for _, item := range msg.ItemList {
|
||||
switch item.Type {
|
||||
case MessageItemTypeText:
|
||||
if item.TextItem != nil && item.TextItem.Text != "" {
|
||||
parts = append(parts, item.TextItem.Text)
|
||||
}
|
||||
case MessageItemTypeVoice:
|
||||
if item.VoiceItem != nil && item.VoiceItem.Text != "" {
|
||||
// Use voice → text transcription from server
|
||||
parts = append(parts, item.VoiceItem.Text)
|
||||
} else {
|
||||
parts = append(parts, "[audio]")
|
||||
}
|
||||
case MessageItemTypeImage:
|
||||
parts = append(parts, "[image]")
|
||||
case MessageItemTypeFile:
|
||||
if item.FileItem != nil && item.FileItem.FileName != "" {
|
||||
parts = append(parts, fmt.Sprintf("[file: %s]", item.FileItem.FileName))
|
||||
} else {
|
||||
parts = append(parts, "[file]")
|
||||
}
|
||||
case MessageItemTypeVideo:
|
||||
parts = append(parts, "[video]")
|
||||
}
|
||||
}
|
||||
|
||||
var mediaRefs []string
|
||||
if mediaItem := selectInboundMediaItem(msg); mediaItem != nil {
|
||||
ref, err := c.downloadMediaFromItem(ctx, fromUserID, messageID, mediaItem)
|
||||
if err != nil {
|
||||
logger.ErrorCF("weixin", "Failed to download inbound media", map[string]any{
|
||||
"from_user_id": fromUserID,
|
||||
"message_id": messageID,
|
||||
"type": mediaItem.Type,
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else if ref != "" {
|
||||
mediaRefs = append(mediaRefs, ref)
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.Join(parts, "\n")
|
||||
if content == "" && len(mediaRefs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "weixin",
|
||||
PlatformID: fromUserID,
|
||||
CanonicalID: identity.BuildCanonicalID("weixin", fromUserID),
|
||||
Username: fromUserID,
|
||||
DisplayName: fromUserID,
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
logger.DebugCF("weixin", "Message rejected by allowlist", map[string]any{
|
||||
"from_user_id": fromUserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: "direct", ID: fromUserID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"from_user_id": fromUserID,
|
||||
"context_token": msg.ContextToken,
|
||||
"session_id": msg.SessionID,
|
||||
}
|
||||
|
||||
logger.DebugCF("weixin", "Received message", map[string]any{
|
||||
"from_user_id": fromUserID,
|
||||
"content_len": len(content),
|
||||
"media_count": len(mediaRefs),
|
||||
})
|
||||
|
||||
// Store context_token for outbound reply association
|
||||
if msg.ContextToken != "" {
|
||||
c.contextTokens.Store(fromUserID, msg.ContextToken)
|
||||
}
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender)
|
||||
}
|
||||
|
||||
// Send implements channels.Channel by sending a text message to the WeChat user.
|
||||
func (c *WeixinChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
if err := c.ensureSessionActive(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Content == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We need a context_token to send a reply. It should be stored in the conversation metadata.
|
||||
// The chat_id is the weixin user_id (from_user_id).
|
||||
toUserID := msg.ChatID
|
||||
|
||||
// Retrieve context_token from our per-user map (stored on last inbound)
|
||||
contextToken := ""
|
||||
if ct, ok := c.contextTokens.Load(toUserID); ok {
|
||||
contextToken, _ = ct.(string)
|
||||
}
|
||||
|
||||
// If we don't have a context token for this user, we cannot send a valid reply.
|
||||
// Treat this as a non-temporary error so the manager doesn't keep retrying.
|
||||
if contextToken == "" {
|
||||
logger.ErrorCF("weixin", "Missing context token, cannot send message", map[string]any{
|
||||
"to_user_id": toUserID,
|
||||
})
|
||||
return fmt.Errorf("weixin send: %w: missing context token for chat %s", channels.ErrSendFailed, toUserID)
|
||||
}
|
||||
|
||||
if err := c.sendTextMessage(ctx, toUserID, contextToken, msg.Content); err != nil {
|
||||
logger.ErrorCF("weixin", "Failed to send message", map[string]any{
|
||||
"to_user_id": toUserID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
if c.remainingPause() > 0 {
|
||||
return fmt.Errorf("weixin send: %w", channels.ErrSendFailed)
|
||||
}
|
||||
return fmt.Errorf("weixin send: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package weixin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
basechannels "github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestParseWeixinMediaAESKey(t *testing.T) {
|
||||
raw := []byte("1234567890abcdef")
|
||||
|
||||
got, err := parseWeixinMediaAESKey(base64.StdEncoding.EncodeToString(raw))
|
||||
if err != nil {
|
||||
t.Fatalf("parseWeixinMediaAESKey(raw) error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, raw) {
|
||||
t.Fatalf("parseWeixinMediaAESKey(raw) = %x, want %x", got, raw)
|
||||
}
|
||||
|
||||
hexEncoded := base64.StdEncoding.EncodeToString([]byte("31323334353637383930616263646566"))
|
||||
got, err = parseWeixinMediaAESKey(hexEncoded)
|
||||
if err != nil {
|
||||
t.Fatalf("parseWeixinMediaAESKey(hex-string) error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, raw) {
|
||||
t.Fatalf("parseWeixinMediaAESKey(hex-string) = %x, want %x", got, raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAndDecryptCDNBuffer(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("hello weixin")
|
||||
ciphertext, err := encryptAESECB(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("encryptAESECB() error = %v", err)
|
||||
}
|
||||
|
||||
ch := &WeixinChannel{
|
||||
api: &ApiClient{
|
||||
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path != "/download" {
|
||||
t.Fatalf("download path = %q, want /download", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("encrypted_query_param") != "token" {
|
||||
t.Fatalf("encrypted_query_param = %q, want token", r.URL.Query().Get("encrypted_query_param"))
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(ciphertext)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.downloadAndDecryptCDNBuffer(context.Background(), "token", key)
|
||||
if err != nil {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, plaintext) {
|
||||
t.Fatalf("downloadAndDecryptCDNBuffer() = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadBufferToCDN(t *testing.T) {
|
||||
key := []byte("1234567890abcdef")
|
||||
plaintext := []byte("upload me")
|
||||
wantCipher, err := encryptAESECB(plaintext, key)
|
||||
if err != nil {
|
||||
t.Fatalf("encryptAESECB() error = %v", err)
|
||||
}
|
||||
|
||||
ch := &WeixinChannel{
|
||||
api: &ApiClient{
|
||||
HttpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path != "/upload" {
|
||||
t.Fatalf("upload path = %q, want /upload", r.URL.Path)
|
||||
}
|
||||
if got := r.URL.Query().Get("encrypted_query_param"); got != "upload-param" {
|
||||
t.Fatalf("encrypted_query_param = %q, want upload-param", got)
|
||||
}
|
||||
if got := r.URL.Query().Get("filekey"); got != "file-key" {
|
||||
t.Fatalf("filekey = %q, want file-key", got)
|
||||
}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
if !bytes.Equal(body, wantCipher) {
|
||||
t.Fatalf("upload body = %x, want %x", body, wantCipher)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||
Header: http.Header{
|
||||
"X-Encrypted-Param": []string{"download-param"},
|
||||
},
|
||||
}, nil
|
||||
})},
|
||||
},
|
||||
config: config.WeixinConfig{
|
||||
CDNBaseURL: "https://cdn.example.com",
|
||||
},
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
got, err := ch.uploadBufferToCDN(context.Background(), plaintext, "upload-param", "file-key", key)
|
||||
if err != nil {
|
||||
t.Fatalf("uploadBufferToCDN() error = %v", err)
|
||||
}
|
||||
if got != "download-param" {
|
||||
t.Fatalf("uploadBufferToCDN() = %q, want download-param", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSaveGetUpdatesBuf(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "sync.json")
|
||||
|
||||
if err := saveGetUpdatesBuf(path, "cursor-123"); err != nil {
|
||||
t.Fatalf("saveGetUpdatesBuf() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := loadGetUpdatesBuf(path)
|
||||
if err != nil {
|
||||
t.Fatalf("loadGetUpdatesBuf() error = %v", err)
|
||||
}
|
||||
if got != "cursor-123" {
|
||||
t.Fatalf("loadGetUpdatesBuf() = %q, want cursor-123", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWeixinSyncBufPathUsesPicoclawHome(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv(config.EnvHome, home)
|
||||
|
||||
wxCfg := config.WeixinConfig{
|
||||
BaseURL: "https://ilinkai.weixin.qq.com/",
|
||||
}
|
||||
wxCfg.SetToken("token-123")
|
||||
got := buildWeixinSyncBufPath(wxCfg)
|
||||
if filepath.Dir(got) != filepath.Join(home, "channels", "weixin", "sync") {
|
||||
t.Fatalf("sync path dir = %q", filepath.Dir(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionPauseGuard(t *testing.T) {
|
||||
ch := &WeixinChannel{
|
||||
typingCache: make(map[string]typingTicketCacheEntry),
|
||||
}
|
||||
|
||||
ch.pauseSession("getupdates", 0, weixinSessionExpiredCode, "expired")
|
||||
|
||||
if err := ch.ensureSessionActive(); !errors.Is(err, basechannels.ErrSendFailed) {
|
||||
t.Fatalf("ensureSessionActive() error = %v, want ErrSendFailed", err)
|
||||
}
|
||||
|
||||
ch.pauseMu.Lock()
|
||||
ch.pauseUntil = time.Now().Add(-time.Second)
|
||||
ch.pauseMu.Unlock()
|
||||
|
||||
if err := ch.ensureSessionActive(); err != nil {
|
||||
t.Fatalf("ensureSessionActive() after expiry error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectInboundMediaItemFallsBackToRefMessage(t *testing.T) {
|
||||
msg := WeixinMessage{
|
||||
ItemList: []MessageItem{
|
||||
{
|
||||
Type: MessageItemTypeText,
|
||||
TextItem: &TextItem{
|
||||
Text: "look",
|
||||
},
|
||||
RefMsg: &RefMessage{
|
||||
MessageItem: &MessageItem{
|
||||
Type: MessageItemTypeImage,
|
||||
ImageItem: &ImageItem{
|
||||
Media: &CDNMedia{
|
||||
EncryptQueryParam: "abc",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
item := selectInboundMediaItem(msg)
|
||||
if item == nil {
|
||||
t.Fatal("selectInboundMediaItem() = nil, want ref media item")
|
||||
}
|
||||
if item.Type != MessageItemTypeImage {
|
||||
t.Fatalf("selectInboundMediaItem().Type = %d, want %d", item.Type, MessageItemTypeImage)
|
||||
}
|
||||
}
|
||||
+53
-19
@@ -259,6 +259,7 @@ type AgentDefaults struct {
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
|
||||
LogLevel string `json:"log_level,omitempty" env:"PICOCLAW_LOG_LEVEL"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -307,6 +308,7 @@ type ChannelsConfig struct {
|
||||
WeCom WeComConfig `json:"wecom"`
|
||||
WeComApp WeComAppConfig `json:"wecom_app"`
|
||||
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
|
||||
Weixin WeixinConfig `json:"weixin"`
|
||||
Pico PicoConfig `json:"pico"`
|
||||
PicoClient PicoClientConfig `json:"pico_client"`
|
||||
IRC IRCConfig `json:"irc"`
|
||||
@@ -751,6 +753,27 @@ func (c *WeComAIBotConfig) SetSecret(secret string) {
|
||||
c.secDirty = true
|
||||
}
|
||||
|
||||
type WeixinConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WEIXIN_ENABLED"`
|
||||
token string
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_WEIXIN_BASE_URL"`
|
||||
CDNBaseURL string `json:"cdn_base_url" env:"PICOCLAW_CHANNELS_WEIXIN_CDN_BASE_URL"`
|
||||
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_WEIXIN_PROXY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WEIXIN_ALLOW_FROM"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WEIXIN_REASONING_CHANNEL_ID"`
|
||||
secDirty bool
|
||||
}
|
||||
|
||||
func (c *WeixinConfig) Token() string {
|
||||
return c.token
|
||||
}
|
||||
|
||||
func (c *WeixinConfig) SetToken(token string) *WeixinConfig {
|
||||
c.token = token
|
||||
c.secDirty = true
|
||||
return c
|
||||
}
|
||||
|
||||
type PicoConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
|
||||
token string
|
||||
@@ -1391,82 +1414,87 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error {
|
||||
|
||||
// Handle Telegram token
|
||||
if sec.Channels.Telegram != nil && sec.Channels.Telegram.Token != "" {
|
||||
cfg.Channels.Telegram.SetToken(sec.Channels.Telegram.Token)
|
||||
cfg.Channels.Telegram.token = sec.Channels.Telegram.Token
|
||||
}
|
||||
|
||||
// Handle Feishu credentials
|
||||
if sec.Channels.Feishu != nil {
|
||||
if sec.Channels.Feishu.AppSecret != "" {
|
||||
cfg.Channels.Feishu.SetAppSecret(sec.Channels.Feishu.AppSecret)
|
||||
cfg.Channels.Feishu.appSecret = sec.Channels.Feishu.AppSecret
|
||||
}
|
||||
if sec.Channels.Feishu.EncryptKey != "" {
|
||||
cfg.Channels.Feishu.SetEncryptKey(sec.Channels.Feishu.EncryptKey)
|
||||
cfg.Channels.Feishu.encryptKey = sec.Channels.Feishu.EncryptKey
|
||||
}
|
||||
if sec.Channels.Feishu.VerificationToken != "" {
|
||||
cfg.Channels.Feishu.SetVerificationToken(sec.Channels.Feishu.VerificationToken)
|
||||
cfg.Channels.Feishu.verificationToken = sec.Channels.Feishu.VerificationToken
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Discord token
|
||||
if sec.Channels.Discord != nil && sec.Channels.Discord.Token != "" {
|
||||
cfg.Channels.Discord.SetToken(sec.Channels.Discord.Token)
|
||||
cfg.Channels.Discord.token = sec.Channels.Discord.Token
|
||||
}
|
||||
|
||||
// Handle Weixin token
|
||||
if sec.Channels.Weixin != nil && sec.Channels.Weixin.Token != "" {
|
||||
cfg.Channels.Discord.token = sec.Channels.Discord.Token
|
||||
}
|
||||
|
||||
// Handle DingTalk client secret
|
||||
if sec.Channels.DingTalk != nil && sec.Channels.DingTalk.ClientSecret != "" {
|
||||
cfg.Channels.DingTalk.SetClientSecret(sec.Channels.DingTalk.ClientSecret)
|
||||
cfg.Channels.DingTalk.clientSecret = sec.Channels.DingTalk.ClientSecret
|
||||
}
|
||||
|
||||
// Handle Slack tokens
|
||||
if sec.Channels.Slack != nil {
|
||||
if sec.Channels.Slack.BotToken != "" {
|
||||
cfg.Channels.Slack.SetBotToken(sec.Channels.Slack.BotToken)
|
||||
cfg.Channels.Slack.botToken = sec.Channels.Slack.BotToken
|
||||
}
|
||||
if sec.Channels.Slack.AppToken != "" {
|
||||
cfg.Channels.Slack.SetAppToken(sec.Channels.Slack.AppToken)
|
||||
cfg.Channels.Slack.appToken = sec.Channels.Slack.AppToken
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Matrix access token
|
||||
if sec.Channels.Matrix != nil && sec.Channels.Matrix.AccessToken != "" {
|
||||
cfg.Channels.Matrix.SetAccessToken(sec.Channels.Matrix.AccessToken)
|
||||
cfg.Channels.Matrix.accessToken = sec.Channels.Matrix.AccessToken
|
||||
}
|
||||
|
||||
// Handle LINE credentials
|
||||
if sec.Channels.LINE != nil {
|
||||
if sec.Channels.LINE.ChannelSecret != "" {
|
||||
cfg.Channels.LINE.SetChannelSecret(sec.Channels.LINE.ChannelSecret)
|
||||
cfg.Channels.LINE.channelSecret = sec.Channels.LINE.ChannelSecret
|
||||
}
|
||||
if sec.Channels.LINE.ChannelAccessToken != "" {
|
||||
cfg.Channels.LINE.SetChannelAccessToken(sec.Channels.LINE.ChannelAccessToken)
|
||||
cfg.Channels.LINE.channelAccessToken = sec.Channels.LINE.ChannelAccessToken
|
||||
}
|
||||
}
|
||||
|
||||
// Handle OneBot access token
|
||||
if sec.Channels.OneBot != nil && sec.Channels.OneBot.AccessToken != "" {
|
||||
cfg.Channels.OneBot.SetAccessToken(sec.Channels.OneBot.AccessToken)
|
||||
cfg.Channels.OneBot.accessToken = sec.Channels.OneBot.AccessToken
|
||||
}
|
||||
|
||||
// Handle WeCom token and encoding key
|
||||
if sec.Channels.WeCom != nil {
|
||||
if sec.Channels.WeCom.Token != "" {
|
||||
cfg.Channels.WeCom.SetToken(sec.Channels.WeCom.Token)
|
||||
cfg.Channels.WeCom.token = sec.Channels.WeCom.Token
|
||||
}
|
||||
if sec.Channels.WeCom.EncodingAESKey != "" {
|
||||
cfg.Channels.WeCom.SetEncodingAESKey(sec.Channels.WeCom.EncodingAESKey)
|
||||
cfg.Channels.WeCom.encodingAESKey = sec.Channels.WeCom.EncodingAESKey
|
||||
}
|
||||
}
|
||||
|
||||
// Handle WeCom App credentials
|
||||
if sec.Channels.WeComApp != nil {
|
||||
if sec.Channels.WeComApp.CorpSecret != "" {
|
||||
cfg.Channels.WeComApp.SetCorpSecret(sec.Channels.WeComApp.CorpSecret)
|
||||
cfg.Channels.WeComApp.corpSecret = sec.Channels.WeComApp.CorpSecret
|
||||
}
|
||||
if sec.Channels.WeComApp.Token != "" {
|
||||
cfg.Channels.WeComApp.SetToken(sec.Channels.WeComApp.Token)
|
||||
cfg.Channels.WeComApp.token = sec.Channels.WeComApp.Token
|
||||
}
|
||||
if sec.Channels.WeComApp.EncodingAESKey != "" {
|
||||
cfg.Channels.WeComApp.SetEncodingAESKey(sec.Channels.WeComApp.EncodingAESKey)
|
||||
cfg.Channels.WeComApp.encodingAESKey = sec.Channels.WeComApp.EncodingAESKey
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1485,7 +1513,7 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error {
|
||||
|
||||
// Handle Pico channel token
|
||||
if sec.Channels.Pico != nil && sec.Channels.Pico.Token != "" {
|
||||
cfg.Channels.Pico.SetToken(sec.Channels.Pico.Token)
|
||||
cfg.Channels.Pico.token = sec.Channels.Pico.Token
|
||||
}
|
||||
|
||||
// Handle IRC passwords
|
||||
@@ -1503,7 +1531,7 @@ func applySecurityConfig(cfg *Config, sec *SecurityConfig) error {
|
||||
|
||||
// Handle QQ app secret
|
||||
if sec.Channels.QQ != nil && sec.Channels.QQ.AppSecret != "" {
|
||||
cfg.Channels.QQ.SetAppSecret(sec.Channels.QQ.AppSecret)
|
||||
cfg.Channels.QQ.appSecret = sec.Channels.QQ.AppSecret
|
||||
}
|
||||
|
||||
cfg.security = sec
|
||||
@@ -1649,6 +1677,12 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
}
|
||||
cfg.Channels.Discord.secDirty = false
|
||||
}
|
||||
if cfg.Channels.Weixin.secDirty {
|
||||
cfg.security.Channels.Weixin = &WeixinSecurity{
|
||||
Token: cfg.Channels.Weixin.Token(),
|
||||
}
|
||||
cfg.Channels.Discord.secDirty = false
|
||||
}
|
||||
if cfg.Channels.QQ.secDirty {
|
||||
cfg.security.Channels.QQ = &QQSecurity{
|
||||
AppSecret: cfg.Channels.QQ.AppSecret(),
|
||||
|
||||
@@ -88,6 +88,7 @@ type channelsConfigV0 struct {
|
||||
Feishu feishuConfigV0 `json:"feishu"`
|
||||
Discord discordConfigV0 `json:"discord"`
|
||||
MaixCam maixcamConfigV0 `json:"maixcam"`
|
||||
Weixin weixinConfigV0 `json:"weixin"`
|
||||
QQ qqConfigV0 `json:"qq"`
|
||||
DingTalk dingtalkConfigV0 `json:"dingtalk"`
|
||||
Slack slackConfigV0 `json:"slack"`
|
||||
@@ -107,6 +108,7 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity)
|
||||
discord, discordSecurity := v.Discord.ToDiscordConfig()
|
||||
maixcam := v.MaixCam.ToMaixCamConfig()
|
||||
qq, qqSecurity := v.QQ.ToQQConfig()
|
||||
weixin, weixinSecurity := v.Weixin.ToWeiXinConfig()
|
||||
dingtalk, dingtalkSecurity := v.DingTalk.ToDingTalkConfig()
|
||||
slack, slackSecurity := v.Slack.ToSlackConfig()
|
||||
matrix, matrixSecurity := v.Matrix.ToMatrixConfig()
|
||||
@@ -125,6 +127,7 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity)
|
||||
Discord: discord,
|
||||
MaixCam: maixcam,
|
||||
QQ: qq,
|
||||
Weixin: weixin,
|
||||
DingTalk: dingtalk,
|
||||
Slack: slack,
|
||||
Matrix: matrix,
|
||||
@@ -140,6 +143,7 @@ func (v *channelsConfigV0) ToChannelsConfig() (ChannelsConfig, ChannelsSecurity)
|
||||
Feishu: &feishuSecurity,
|
||||
Discord: &discordSecurity,
|
||||
QQ: &qqSecurity,
|
||||
Weixin: &weixinSecurity,
|
||||
DingTalk: &dingtalkSecurity,
|
||||
Slack: &slackSecurity,
|
||||
Matrix: &matrixSecurity,
|
||||
@@ -463,6 +467,30 @@ func (v *wecomConfigV0) ToWeComConfig() (WeComConfig, WeComSecurity) {
|
||||
}
|
||||
}
|
||||
|
||||
type weixinConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WEIXIN_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_WEIXIN_TOKEN"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_WEIXIN_BASE_URL"`
|
||||
CDNBaseURL string `json:"cdn_base_url" env:"PICOCLAW_CHANNELS_WEIXIN_CDN_BASE_URL"`
|
||||
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_WEIXIN_PROXY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WEIXIN_ALLOW_FROM"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WEIXIN_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
func (v *weixinConfigV0) ToWeiXinConfig() (WeixinConfig, WeixinSecurity) {
|
||||
return WeixinConfig{
|
||||
Enabled: v.Enabled,
|
||||
token: v.Token,
|
||||
BaseURL: v.BaseURL,
|
||||
CDNBaseURL: v.CDNBaseURL,
|
||||
Proxy: v.Proxy,
|
||||
AllowFrom: v.AllowFrom,
|
||||
ReasoningChannelID: v.ReasoningChannelID,
|
||||
}, WeixinSecurity{
|
||||
Token: v.Token,
|
||||
}
|
||||
}
|
||||
|
||||
type wecomappConfigV0 struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
|
||||
CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
|
||||
|
||||
@@ -443,6 +443,13 @@ func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_LogLevel(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Agents.Defaults.LogLevel != "fatal" {
|
||||
t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Agents.Defaults.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -1052,3 +1059,38 @@ func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey(), plainKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigParsesLogLevel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{"version":1,"agents":{"defaults":{"log_level":"debug"}}}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
if cfg.Agents.Defaults.LogLevel != "debug" {
|
||||
t.Errorf("LogLevel = %q, want \"debug\"", cfg.Agents.Defaults.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigLogLevelEmpty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
// When config omits log_level, the DefaultConfig value ("fatal") is preserved.
|
||||
if cfg.Agents.Defaults.LogLevel != "fatal" {
|
||||
t.Errorf("LogLevel = %q, want \"fatal\"", cfg.Agents.Defaults.LogLevel)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ func DefaultConfig() *Config {
|
||||
Version: CurrentVersion,
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
LogLevel: "fatal",
|
||||
Workspace: workspacePath,
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
@@ -155,6 +156,13 @@ func DefaultConfig() *Config {
|
||||
WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
|
||||
ProcessingMessage: DefaultWeComAIBotProcessingMessage,
|
||||
},
|
||||
Weixin: WeixinConfig{
|
||||
Enabled: false,
|
||||
BaseURL: "https://ilinkai.weixin.qq.com/",
|
||||
CDNBaseURL: "https://novac2c.cdn.weixin.qq.com/c2c",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
Proxy: "",
|
||||
},
|
||||
Pico: PicoConfig{
|
||||
Enabled: false,
|
||||
PingInterval: 30,
|
||||
|
||||
@@ -47,6 +47,7 @@ type ChannelsSecurity struct {
|
||||
Telegram *TelegramSecurity `yaml:"telegram,omitempty"`
|
||||
Feishu *FeishuSecurity `yaml:"feishu,omitempty"`
|
||||
Discord *DiscordSecurity `yaml:"discord,omitempty"`
|
||||
Weixin *WeixinSecurity `yaml:"weixin,omitempty"`
|
||||
QQ *QQSecurity `yaml:"qq,omitempty"`
|
||||
DingTalk *DingTalkSecurity `yaml:"dingtalk,omitempty"`
|
||||
Slack *SlackSecurity `yaml:"slack,omitempty"`
|
||||
@@ -74,6 +75,10 @@ type DiscordSecurity struct {
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
|
||||
}
|
||||
|
||||
type WeixinSecurity struct {
|
||||
Token string `yaml:"token,omitempty" env:"PICOCLAW_CHANNELS_WEIXIN_TOKEN"`
|
||||
}
|
||||
|
||||
type QQSecurity struct {
|
||||
AppSecret string `yaml:"app_secret,omitempty" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/weixin"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp_native"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -79,16 +80,18 @@ func (p *startupBlockedProvider) GetDefaultModel() string {
|
||||
|
||||
// Run starts the gateway runtime using the configuration loaded from configPath.
|
||||
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %w", err)
|
||||
}
|
||||
|
||||
logger.SetLevelFromString(cfg.Agents.Defaults.LogLevel)
|
||||
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
const (
|
||||
minIntervalMinutes = 5
|
||||
defaultIntervalMinutes = 30
|
||||
userTasksMarker = "Add your heartbeat tasks below this line:"
|
||||
)
|
||||
|
||||
// HeartbeatHandler is the function type for handling heartbeat.
|
||||
@@ -232,7 +233,7 @@ func (hs *HeartbeatService) buildPrompt() string {
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if len(content) == 0 {
|
||||
if !heartbeatHasUserTasks(content) {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -284,6 +285,32 @@ Add your heartbeat tasks below this line:
|
||||
}
|
||||
}
|
||||
|
||||
func heartbeatHasUserTasks(content string) bool {
|
||||
trimmed := strings.TrimSpace(content)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
markerIdx := strings.Index(content, userTasksMarker)
|
||||
if markerIdx < 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
tasksSection := content[markerIdx+len(userTasksMarker):]
|
||||
for _, line := range strings.Split(tasksSection, "\n") {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
if trimmedLine == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimmedLine, "#") {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// sendResponse sends the heartbeat response to the last channel
|
||||
func (hs *HeartbeatService) sendResponse(response string) {
|
||||
hs.mu.RLock()
|
||||
|
||||
@@ -3,6 +3,7 @@ package heartbeat
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -203,3 +204,47 @@ func TestHeartbeatFilePath(t *testing.T) {
|
||||
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_DefaultTemplateStaysIdle(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.createDefaultHeartbeatTemplate()
|
||||
|
||||
if prompt := hs.buildPrompt(); prompt != "" {
|
||||
t.Fatalf("buildPrompt() = %q, want empty prompt for untouched default template", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_UserTasksAfterMarkerProducePrompt(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.createDefaultHeartbeatTemplate()
|
||||
|
||||
path := filepath.Join(tmpDir, "HEARTBEAT.md")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read HEARTBEAT.md: %v", err)
|
||||
}
|
||||
updated := string(data) + "\n- Check unread Feishu messages\n"
|
||||
if err := os.WriteFile(path, []byte(updated), 0o644); err != nil {
|
||||
t.Fatalf("Failed to update HEARTBEAT.md: %v", err)
|
||||
}
|
||||
|
||||
prompt := hs.buildPrompt()
|
||||
if prompt == "" {
|
||||
t.Fatal("buildPrompt() = empty, want non-empty prompt when user tasks are present")
|
||||
}
|
||||
if !strings.Contains(prompt, "Check unread Feishu messages") {
|
||||
t.Fatalf("prompt = %q, want user task content", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,13 +94,18 @@ func MatchAllowed(sender bus.SenderInfo, allowed string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// isNumeric returns true if s consists entirely of digits.
|
||||
// isNumeric returns true if s consists entirely of digits, allowing for an optional leading minus sign
|
||||
// (required for Telegram group/channel IDs like -1001234567890).
|
||||
func isNumeric(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if r < '0' || r > '9' {
|
||||
start := 0
|
||||
if s[0] == '-' && len(s) > 1 {
|
||||
start = 1
|
||||
}
|
||||
for i := start; i < len(s); i++ {
|
||||
if s[i] < '0' || s[i] > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +97,15 @@ func TestMatchAllowed(t *testing.T) {
|
||||
allowed: "654321",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negative numeric ID matches PlatformID",
|
||||
sender: bus.SenderInfo{
|
||||
Platform: "telegram",
|
||||
PlatformID: "-1001234567890",
|
||||
},
|
||||
allowed: "-1001234567890",
|
||||
want: true,
|
||||
},
|
||||
// Username matching
|
||||
{
|
||||
name: "@username matches Username",
|
||||
@@ -238,6 +247,9 @@ func TestIsNumeric(t *testing.T) {
|
||||
{"abc", false},
|
||||
{"12a34", false},
|
||||
{"telegram", false},
|
||||
{"-1001234567890", true},
|
||||
{"-", false},
|
||||
{"-12a34", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -106,6 +106,36 @@ func GetLevel() LogLevel {
|
||||
return currentLevel
|
||||
}
|
||||
|
||||
// ParseLevel converts a case-insensitive level name to a LogLevel.
|
||||
// Returns the level and true if valid, or (INFO, false) if unrecognized.
|
||||
func ParseLevel(s string) (LogLevel, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "debug":
|
||||
return DEBUG, true
|
||||
case "info":
|
||||
return INFO, true
|
||||
case "warn", "warning":
|
||||
return WARN, true
|
||||
case "error":
|
||||
return ERROR, true
|
||||
case "fatal":
|
||||
return FATAL, true
|
||||
default:
|
||||
return INFO, false
|
||||
}
|
||||
}
|
||||
|
||||
// SetLevelFromString sets the log level from a string value.
|
||||
// If the string is empty or not a recognized level name, the current level is kept.
|
||||
func SetLevelFromString(s string) {
|
||||
if s == "" {
|
||||
return
|
||||
}
|
||||
if level, ok := ParseLevel(s); ok {
|
||||
SetLevel(level)
|
||||
}
|
||||
}
|
||||
|
||||
func EnableFileLogging(filePath string) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
@@ -252,3 +252,88 @@ func TestFormatFieldValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLevelIsInfo(t *testing.T) {
|
||||
// The package-level default (before any SetLevel call) should be INFO.
|
||||
// Because earlier tests may have changed it, we just verify the constant is wired correctly.
|
||||
if logLevelNames[INFO] != "INFO" {
|
||||
t.Errorf("INFO constant mapped to %q, want \"INFO\"", logLevelNames[INFO])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLevelValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want LogLevel
|
||||
}{
|
||||
{"debug", DEBUG},
|
||||
{"DEBUG", DEBUG},
|
||||
{"Debug", DEBUG},
|
||||
{"info", INFO},
|
||||
{"INFO", INFO},
|
||||
{"warn", WARN},
|
||||
{"WARN", WARN},
|
||||
{"warning", WARN},
|
||||
{"WARNING", WARN},
|
||||
{"error", ERROR},
|
||||
{"ERROR", ERROR},
|
||||
{"fatal", FATAL},
|
||||
{"FATAL", FATAL},
|
||||
{" info ", INFO},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got, ok := ParseLevel(tt.input)
|
||||
if !ok {
|
||||
t.Fatalf("ParseLevel(%q) returned ok=false, want true", tt.input)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("ParseLevel(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLevelInvalid(t *testing.T) {
|
||||
tests := []string{"", "garbage", "verbose", "trace", "critical"}
|
||||
|
||||
for _, input := range tests {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
_, ok := ParseLevel(input)
|
||||
if ok {
|
||||
t.Errorf("ParseLevel(%q) returned ok=true, want false", input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLevelFromString(t *testing.T) {
|
||||
initialLevel := GetLevel()
|
||||
defer SetLevel(initialLevel)
|
||||
|
||||
// Valid string changes the level
|
||||
SetLevel(INFO)
|
||||
SetLevelFromString("error")
|
||||
if got := GetLevel(); got != ERROR {
|
||||
t.Errorf("after SetLevelFromString(\"error\"): GetLevel() = %v, want ERROR", got)
|
||||
}
|
||||
|
||||
// Empty string is a no-op
|
||||
SetLevelFromString("")
|
||||
if got := GetLevel(); got != ERROR {
|
||||
t.Errorf("after SetLevelFromString(\"\"): GetLevel() = %v, want ERROR (unchanged)", got)
|
||||
}
|
||||
|
||||
// Invalid string is a no-op
|
||||
SetLevelFromString("garbage")
|
||||
if got := GetLevel(); got != ERROR {
|
||||
t.Errorf("after SetLevelFromString(\"garbage\"): GetLevel() = %v, want ERROR (unchanged)", got)
|
||||
}
|
||||
|
||||
// Case-insensitive
|
||||
SetLevelFromString("FATAL")
|
||||
if got := GetLevel(); got != FATAL {
|
||||
t.Errorf("after SetLevelFromString(\"FATAL\"): GetLevel() = %v, want FATAL", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user