Merge branch 'upstream-main' into feat/subturn-poc

This commit is contained in:
Administrator
2026-03-18 22:57:01 +08:00
117 changed files with 14857 additions and 7091 deletions
+24 -5
View File
@@ -52,7 +52,7 @@ func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuil
}
func getGlobalConfigDir() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
home, err := os.UserHomeDir()
@@ -65,7 +65,7 @@ func getGlobalConfigDir() string {
func NewContextBuilder(workspace string) *ContextBuilder {
// builtin skills: skills directory in current project
// Use the skills/ directory under the current working directory
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
builtinSkillsDir := strings.TrimSpace(os.Getenv(config.EnvBuiltinSkills))
if builtinSkillsDir == "" {
wd, _ := os.Getwd()
builtinSkillsDir = filepath.Join(wd, "skills")
@@ -458,7 +458,23 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string {
//
// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
// See: https://platform.openai.com/docs/guides/prompt-caching
func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
// formatCurrentSenderLine renders the single "Current sender: ..." line that
// is placed under the "## Current Sender" heading of the dynamic context.
//
// Both values are taken from the inbound chat message and are therefore
// untrusted text that ends up inside the system prompt. Every run of
// whitespace (spaces, tabs, CR/LF, Unicode separators) is collapsed to a
// single space and leading/trailing whitespace is dropped, so a crafted
// display name or ID cannot break out of its line and start new prompt
// lines or headings.
//
// Result:
//   - both present:      "Current sender: <name> (ID: <id>)"
//   - only display name: "Current sender: <name>"
//   - only ID:           "Current sender: <id>"
//   - neither present:   "" (the caller omits the section)
func formatCurrentSenderLine(senderID, senderDisplayName string) string {
	// strings.Fields splits on the same whitespace set TrimSpace trims, so
	// this is a strict superset of the previous trimming behavior.
	senderID = strings.Join(strings.Fields(senderID), " ")
	senderDisplayName = strings.Join(strings.Fields(senderDisplayName), " ")

	switch {
	case senderDisplayName != "" && senderID != "":
		return fmt.Sprintf("Current sender: %s (ID: %s)", senderDisplayName, senderID)
	case senderDisplayName != "":
		return fmt.Sprintf("Current sender: %s", senderDisplayName)
	case senderID != "":
		return fmt.Sprintf("Current sender: %s", senderID)
	default:
		return ""
	}
}
func (cb *ContextBuilder) buildDynamicContext(channel, chatID, senderID, senderDisplayName string) string {
now := time.Now().Format("2006-01-02 15:04 (Monday)")
rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
@@ -468,6 +484,9 @@ func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
if channel != "" && chatID != "" {
fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
}
if senderLine := formatCurrentSenderLine(senderID, senderDisplayName); senderLine != "" {
fmt.Fprintf(&sb, "\n\n## Current Sender\n%s", senderLine)
}
return sb.String()
}
@@ -477,7 +496,7 @@ func (cb *ContextBuilder) BuildMessages(
summary string,
currentMessage string,
media []string,
channel, chatID string,
channel, chatID, senderID, senderDisplayName string,
) []providers.Message {
messages := []providers.Message{}
@@ -493,7 +512,7 @@ func (cb *ContextBuilder) BuildMessages(
staticPrompt := cb.BuildSystemPromptWithCache()
// Build short dynamic context (time, runtime, session) — changes per request
dynamicCtx := cb.buildDynamicContext(channel, chatID)
dynamicCtx := cb.buildDynamicContext(channel, chatID, senderID, senderDisplayName)
// Compose a single system message: static (cached) + dynamic + optional summary.
// Keeping all system content in one message ensures every provider adapter can
+65 -3
View File
@@ -82,7 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1")
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
systemCount := 0
for _, m := range msgs {
@@ -126,6 +126,68 @@ func TestSingleSystemMessage(t *testing.T) {
}
}
// TestBuildMessages_CurrentSenderDynamicContext checks that BuildMessages
// adds a "## Current Sender" section to the first message it returns when a
// sender ID and/or display name is supplied, and leaves the section out when
// both are empty.
func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
	tmpDir := setupWorkspace(t, map[string]string{
		"IDENTITY.md": "# Identity\nTest agent.",
	})
	defer os.RemoveAll(tmpDir)

	cb := NewContextBuilder(tmpDir)

	tests := []struct {
		name              string
		senderID          string
		senderDisplayName string
		wantLine          string
		wantSection       bool
	}{
		{
			name:              "both id and display name",
			senderID:          "feishu:ou_xxx",
			senderDisplayName: "Zhang San",
			wantLine:          "Current sender: Zhang San (ID: feishu:ou_xxx)",
			wantSection:       true,
		},
		{
			name:              "display name only",
			senderDisplayName: "Alice",
			wantLine:          "Current sender: Alice",
			wantSection:       true,
		},
		{
			name:        "id only",
			senderID:    "discord:123",
			wantLine:    "Current sender: discord:123",
			wantSection: true,
		},
		{
			name:        "no sender info",
			wantSection: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName)
			// Fail cleanly instead of panicking on msgs[0] if nothing came back.
			if len(msgs) == 0 {
				t.Fatal("BuildMessages returned no messages")
			}
			sys := msgs[0].Content

			if tt.wantSection {
				if !strings.Contains(sys, "## Current Sender") {
					t.Fatalf("system prompt missing Current Sender section:\n%s", sys)
				}
				if !strings.Contains(sys, tt.wantLine) {
					t.Fatalf("system prompt missing sender line %q:\n%s", tt.wantLine, sys)
				}
				return
			}

			if strings.Contains(sys, "## Current Sender") {
				t.Fatalf("system prompt should omit Current Sender section:\n%s", sys)
			}
		})
	}
}
// TestMtimeAutoInvalidation verifies that the cache detects source file changes
// via mtime without requiring explicit InvalidateCache().
// Fix: original implementation had no auto-invalidation — edits to bootstrap files,
@@ -576,7 +638,7 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
}
// Also exercise BuildMessages concurrently
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat")
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat", "", "")
if len(msgs) < 2 {
errs <- "BuildMessages returned fewer than 2 messages"
return
@@ -664,6 +726,6 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test")
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test", "", "")
}
}
+72 -15
View File
@@ -61,6 +61,8 @@ type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
SenderID string // Current sender ID for dynamic context
SenderDisplayName string // Current sender display name for dynamic context
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
DefaultResponse string // Response when LLM returns empty
@@ -166,7 +168,12 @@ func registerSharedTools(
}
}
if cfg.Tools.IsToolEnabled("web_fetch") {
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
fetchTool, err := tools.NewWebFetchToolWithProxy(
50000,
cfg.Tools.Web.Proxy,
cfg.Tools.Web.Format,
cfg.Tools.Web.FetchLimitBytes,
cfg.Tools.Web.PrivateHostWhitelist)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
@@ -338,10 +345,9 @@ func (al *AgentLoop) Run(ctx context.Context) error {
select {
case <-ctx.Done():
return nil
default:
msg, ok := al.bus.ConsumeInbound(ctx)
case msg, ok := <-al.bus.InboundChan():
if !ok {
continue
return nil
}
// Start a goroutine that drains the bus while processMessage is
@@ -408,6 +414,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
}
}()
default:
time.Sleep(time.Microsecond * 200)
}
}
@@ -419,9 +427,15 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
for {
msg, ok := al.bus.ConsumeInbound(ctx)
if !ok {
var msg bus.InboundMessage
select {
case <-ctx.Done():
return
case m, ok := <-al.bus.InboundChan():
if !ok {
return
}
msg = m
}
// Transcribe audio if needed before steering, so the agent sees text.
@@ -861,14 +875,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
})
opts := processOptions{
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SessionKey: sessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
SenderID: msg.SenderID,
SenderDisplayName: msg.Sender.DisplayName,
UserMessage: msg.Content,
Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
}
// context-dependent commands check their own Runtime fields and report
@@ -1039,6 +1055,8 @@ func (al *AgentLoop) runAgentLoop(
opts.Media,
opts.Channel,
opts.ChatID,
opts.SenderID,
opts.SenderDisplayName,
)
// Resolve media:// refs: images→base64 data URLs, non-images→local paths in content
@@ -1256,6 +1274,19 @@ func (al *AgentLoop) runLLMIteration(
// Build tool definitions
providerToolDefs := agent.Tools.ToProviderDefs()
// Determine whether the provider's native web search should replace
// the client-side web_search tool for this request. Only enable when web
// search is actually enabled and registered (so users who disabled web
// access do not get provider-side search or billing).
_, hasWebSearch := agent.Tools.Get("web_search")
useNativeSearch := al.cfg.Tools.Web.PreferNative &&
isNativeSearchProvider(agent.Provider) &&
hasWebSearch
if useNativeSearch {
providerToolDefs = filterClientWebSearch(providerToolDefs)
}
// Log LLM request details
logger.DebugCF("agent", "LLM request",
map[string]any{
@@ -1264,6 +1295,7 @@ func (al *AgentLoop) runLLMIteration(
"model": activeModel,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"native_search": useNativeSearch,
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
"system_prompt_len": len(messages[0].Content),
@@ -1286,6 +1318,9 @@ func (al *AgentLoop) runLLMIteration(
"temperature": agent.Temperature,
"prompt_cache_key": agent.ID,
}
if useNativeSearch {
llmOpts["native_search"] = true
}
// parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
// so checking != ThinkingOff is sufficient.
if agent.ThinkingLevel != ThinkingOff {
@@ -1387,7 +1422,7 @@ func (al *AgentLoop) runLLMIteration(
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
messages = agent.ContextBuilder.BuildMessages(
newHistory, newSummary, "",
nil, opts.Channel, opts.ChatID,
nil, opts.Channel, opts.ChatID, opts.SenderID, opts.SenderDisplayName,
)
continue
}
@@ -2246,6 +2281,28 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
}
// isNativeSearchProvider reports whether p implements
// providers.NativeSearchCapable and its SupportsNativeSearch method returns
// true. Providers that do not implement the interface yield false.
func isNativeSearchProvider(p providers.LLMProvider) bool {
	capable, implemented := p.(providers.NativeSearchCapable)
	return implemented && capable.SupportsNativeSearch()
}
// filterClientWebSearch builds a new slice holding every definition from
// tools except those whose function name equals "web_search" (compared
// case-insensitively). The input slice is not modified; the result is always
// non-nil, even for nil or empty input.
func filterClientWebSearch(tools []providers.ToolDefinition) []providers.ToolDefinition {
	kept := make([]providers.ToolDefinition, 0, len(tools))
	for i := range tools {
		if !strings.EqualFold(tools[i].Function.Name, "web_search") {
			kept = append(kept, tools[i])
		}
	}
	return kept
}
// Helper to extract provider from registry for cleanup
func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) {
if registry == nil {
+228 -39
View File
@@ -30,6 +30,28 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
// recordingProvider is a test double for an LLM provider: it remembers the
// messages passed to the most recent Chat call and always answers with a
// fixed response.
type recordingProvider struct {
	lastMessages []providers.Message
}

// Chat stores a copy of messages (so later mutation by the caller does not
// affect what was recorded) and returns the canned "Mock response" with no
// tool calls. It never returns an error.
func (r *recordingProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	var snapshot []providers.Message
	snapshot = append(snapshot, messages...)
	r.lastMessages = snapshot

	canned := &providers.LLMResponse{
		Content:   "Mock response",
		ToolCalls: []providers.ToolCall{},
	}
	return canned, nil
}

// GetDefaultModel returns the fixed model name used by this test double.
func (r *recordingProvider) GetDefaultModel() string {
	return "mock-model"
}
func newTestAgentLoop(
t *testing.T,
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
@@ -54,6 +76,59 @@ func newTestAgentLoop(
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
}
// TestProcessMessage_IncludesCurrentSenderInDynamicContext runs a single
// inbound message through processMessage with a recording provider and
// checks that the first message sent to the provider contains the
// "Current Sender" section, while the final (user) message is passed through
// unchanged.
func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
	workspace, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workspace)

	defaults := config.AgentDefaults{
		Workspace:         workspace,
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{Agents: config.AgentsConfig{Defaults: defaults}}

	msgBus := bus.NewMessageBus()
	provider := &recordingProvider{}
	al := NewAgentLoop(cfg, msgBus, provider)

	inbound := bus.InboundMessage{
		Channel:  "discord",
		SenderID: "discord:123",
		Sender:   bus.SenderInfo{DisplayName: "Alice"},
		ChatID:   "group-1",
		Content:  "hello",
	}
	response, err := al.processMessage(context.Background(), inbound)
	if err != nil {
		t.Fatalf("processMessage() error = %v", err)
	}
	if response != "Mock response" {
		t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
	}

	recorded := provider.lastMessages
	if len(recorded) == 0 {
		t.Fatal("provider did not receive any messages")
	}

	// The sender section must appear in the first recorded message.
	const wantSender = "## Current Sender\nCurrent sender: Alice (ID: discord:123)"
	if systemPrompt := recorded[0].Content; !strings.Contains(systemPrompt, wantSender) {
		t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt)
	}

	// The user's own text must not be rewritten with sender information.
	lastMessage := recorded[len(recorded)-1]
	if !(lastMessage.Role == "user" && lastMessage.Content == "hello") {
		t.Fatalf("last provider message = %+v, want unchanged user message", lastMessage)
	}
}
func TestRecordLastChannel(t *testing.T) {
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
defer cleanup()
@@ -922,10 +997,25 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
t.Fatalf("expected no outbound message, got %+v", msg)
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, got %+v", msg)
}
if msg.Content == "reasoning" {
t.Fatalf("expected no message for empty chatID, got %+v", msg)
}
return
case <-ctx.Done():
t.Log("expected an outbound message, got none within timeout")
return
default:
// Continue to check for message
time.Sleep(5 * time.Millisecond) // Avoid busy loop
}
}
})
@@ -933,9 +1023,7 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
msg, ok := <-msgBus.OutboundChan()
if !ok {
t.Fatal("expected an outbound message")
}
@@ -949,35 +1037,52 @@ func TestHandleReasoning(t *testing.T) {
reasoning := "hello telegram reasoning"
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if !ok {
t.Fatal("expected outbound message")
}
for {
select {
case <-ctx.Done():
t.Fatal("expected an outbound message, got none within timeout")
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatal("expected outbound message")
}
if msg.Channel != "telegram" {
t.Fatalf("expected telegram channel message, got %+v", msg)
}
if msg.ChatID != "tg-chat" {
t.Fatalf("expected chatID tg-chat, got %+v", msg)
}
if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
if msg.Channel != "telegram" {
t.Fatalf("expected telegram channel message, got %+v", msg)
}
if msg.ChatID != "tg-chat" {
t.Fatalf("expected chatID tg-chat, got %+v", msg)
}
if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
}
return
}
}
})
t.Run("expired ctx", func(t *testing.T) {
al, msgBus := newLoop(t)
reasoning := "hello telegram reasoning"
ctx, cancel := context.WithCancel(context.Background())
cancel()
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if ok {
t.Fatalf("expected no outbound message, got %+v", msg)
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer consumeCancel()
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, but received: %+v", msg)
}
t.Logf("Received unexpected outbound message: %+v", msg)
return
case <-consumeCtx.Done():
t.Fatalf("failed: no message received within timeout")
return
}
}
})
@@ -1017,20 +1122,23 @@ func TestHandleReasoning(t *testing.T) {
// Drain the bus and verify the reasoning message was NOT published
// (it should have been dropped due to timeout).
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer drainCancel()
foundReasoning := false
timeer := time.After(1 * time.Second)
for {
msg, ok := msgBus.SubscribeOutbound(drainCtx)
if !ok {
break
select {
case <-timeer:
t.Logf(
"no reasoning message received after draining bus for 1s, as expected,length=%d",
len(msgBus.OutboundChan()),
)
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
break
}
if msg.Content == "should timeout" {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
}
if msg.Content == "should timeout" {
foundReasoning = true
}
}
if foundReasoning {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
})
}
@@ -1318,3 +1426,84 @@ func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) {
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
}
}
// --- Native search helper tests ---
// nativeSearchProvider is a test provider that also exposes
// SupportsNativeSearch; the supported field controls what that method reports.
type nativeSearchProvider struct {
	supported bool
}

// Chat ignores its arguments and returns a fixed "ok" response.
func (p *nativeSearchProvider) Chat(
	ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
	model string, opts map[string]any,
) (*providers.LLMResponse, error) {
	resp := &providers.LLMResponse{Content: "ok"}
	return resp, nil
}

// GetDefaultModel returns the fixed model name used by this test provider.
func (p *nativeSearchProvider) GetDefaultModel() string {
	return "test-model"
}

// SupportsNativeSearch reports the value configured in the supported field.
func (p *nativeSearchProvider) SupportsNativeSearch() bool {
	return p.supported
}
// plainProvider is a test provider with only Chat and GetDefaultModel; it
// deliberately has no SupportsNativeSearch method.
type plainProvider struct{}

// Chat ignores its arguments and returns a fixed "ok" response.
func (p *plainProvider) Chat(
	ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
	model string, opts map[string]any,
) (*providers.LLMResponse, error) {
	resp := &providers.LLMResponse{Content: "ok"}
	return resp, nil
}

// GetDefaultModel returns the fixed model name used by this test provider.
func (p *plainProvider) GetDefaultModel() string {
	return "test-model"
}
func TestIsNativeSearchProvider_Supported(t *testing.T) {
if !isNativeSearchProvider(&nativeSearchProvider{supported: true}) {
t.Fatal("expected true for provider that supports native search")
}
}
func TestIsNativeSearchProvider_NotSupported(t *testing.T) {
if isNativeSearchProvider(&nativeSearchProvider{supported: false}) {
t.Fatal("expected false for provider that does not support native search")
}
}
func TestIsNativeSearchProvider_NoInterface(t *testing.T) {
if isNativeSearchProvider(&plainProvider{}) {
t.Fatal("expected false for provider that does not implement NativeSearchCapable")
}
}
// TestFilterClientWebSearch_RemovesWebSearch feeds three tool definitions,
// one of them web_search, and expects only the other two to remain.
func TestFilterClientWebSearch_RemovesWebSearch(t *testing.T) {
	names := []string{"web_search", "read_file", "exec"}
	defs := make([]providers.ToolDefinition, 0, len(names))
	for _, name := range names {
		defs = append(defs, providers.ToolDefinition{
			Type:     "function",
			Function: providers.ToolFunctionDefinition{Name: name},
		})
	}

	result := filterClientWebSearch(defs)
	if len(result) != 2 {
		t.Fatalf("len(result) = %d, want 2", len(result))
	}
	for i := range result {
		if result[i].Function.Name == "web_search" {
			t.Fatal("web_search should be filtered out")
		}
	}
}
// TestFilterClientWebSearch_NoWebSearch expects the number of definitions to
// be unchanged when none of them is web_search.
func TestFilterClientWebSearch_NoWebSearch(t *testing.T) {
	names := []string{"read_file", "exec"}
	defs := make([]providers.ToolDefinition, 0, len(names))
	for _, name := range names {
		defs = append(defs, providers.ToolDefinition{
			Type:     "function",
			Function: providers.ToolFunctionDefinition{Name: name},
		})
	}

	if result := filterClientWebSearch(defs); len(result) != 2 {
		t.Fatalf("len(result) = %d, want 2", len(result))
	}
}
// TestFilterClientWebSearch_EmptyInput expects an empty result for nil input.
func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
	if n := len(filterClientWebSearch(nil)); n != 0 {
		t.Fatalf("len(result) = %d, want 0", n)
	}
}
+2 -1
View File
@@ -6,6 +6,7 @@ import (
"path/filepath"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
@@ -39,7 +40,7 @@ func (c *AuthCredential) NeedsRefresh() bool {
}
func authFilePath() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
if home := os.Getenv(config.EnvHome); home != "" {
return filepath.Join(home, "auth.json")
}
home, _ := os.UserHomeDir()
+60 -93
View File
@@ -3,6 +3,7 @@ package bus
import (
"context"
"errors"
"sync"
"sync/atomic"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -17,8 +18,11 @@ type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
outboundMedia chan OutboundMediaMessage
done chan struct{}
closed atomic.Bool
closeOnce sync.Once
done chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
func NewMessageBus() *MessageBus {
@@ -30,128 +34,91 @@ func NewMessageBus() *MessageBus {
}
}
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error {
// check bus closed before acquiring wg, to avoid unnecessary wg.Add and potential deadlock
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
// check again,before sending message, to avoid sending to closed channel
select {
case mb.inbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
case <-mb.done:
return ErrBusClosed
default:
}
mb.wg.Add(1)
defer mb.wg.Done()
select {
case ch <- msg:
return nil
case <-ctx.Done():
return ctx.Err()
case <-mb.done:
return ErrBusClosed
}
}
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) {
select {
case msg, ok := <-mb.inbound:
return msg, ok
case <-mb.done:
return InboundMessage{}, false
case <-ctx.Done():
return InboundMessage{}, false
}
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
return publish(ctx, mb, mb.inbound, msg)
}
func (mb *MessageBus) InboundChan() <-chan InboundMessage {
return mb.inbound
}
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
return publish(ctx, mb, mb.outbound, msg)
}
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) {
select {
case msg, ok := <-mb.outbound:
return msg, ok
case <-mb.done:
return OutboundMessage{}, false
case <-ctx.Done():
return OutboundMessage{}, false
}
func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
return mb.outbound
}
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outboundMedia <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
return publish(ctx, mb, mb.outboundMedia, msg)
}
func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) {
select {
case msg, ok := <-mb.outboundMedia:
return msg, ok
case <-mb.done:
return OutboundMediaMessage{}, false
case <-ctx.Done():
return OutboundMediaMessage{}, false
}
func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
return mb.outboundMedia
}
func (mb *MessageBus) Close() {
if mb.closed.CompareAndSwap(false, true) {
mb.closeOnce.Do(func() {
// notify all blocked publishers to exit
close(mb.done)
// Drain buffered channels so messages aren't silently lost.
// Channels are NOT closed to avoid send-on-closed panics from concurrent publishers.
// because every publisher will check mb.closed before acquiring wg
// so we can be sure that new publishers will not be added new messages after this point
mb.closed.Store(true)
// wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited
mb.wg.Wait()
// close channels safely
close(mb.inbound)
close(mb.outbound)
close(mb.outboundMedia)
// clean up any remaining messages in channels
drained := 0
for {
select {
case <-mb.inbound:
drained++
default:
goto doneInbound
}
for range mb.inbound {
drained++
}
doneInbound:
for {
select {
case <-mb.outbound:
drained++
default:
goto doneOutbound
}
for range mb.outbound {
drained++
}
doneOutbound:
for {
select {
case <-mb.outboundMedia:
drained++
default:
goto doneMedia
}
for range mb.outboundMedia {
drained++
}
doneMedia:
if drained > 0 {
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
"count": drained,
})
}
}
})
}
+35 -17
View File
@@ -24,7 +24,7 @@ func TestPublishConsume(t *testing.T) {
t.Fatalf("PublishInbound failed: %v", err)
}
got, ok := mb.ConsumeInbound(ctx)
got, ok := <-mb.InboundChan()
if !ok {
t.Fatal("ConsumeInbound returned ok=false")
}
@@ -52,7 +52,7 @@ func TestPublishOutboundSubscribe(t *testing.T) {
t.Fatalf("PublishOutbound failed: %v", err)
}
got, ok := mb.SubscribeOutbound(ctx)
got, ok := <-mb.OutboundChan()
if !ok {
t.Fatal("SubscribeOutbound returned ok=false")
}
@@ -108,27 +108,48 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
func TestConsumeInbound_ContextCancel(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
_, ok := mb.ConsumeInbound(ctx)
if ok {
t.Fatal("expected ok=false when context is canceled")
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
select {
case <-ctx.Done():
t.Log("context canceled, as expected")
case msg, ok := <-mb.InboundChan():
if !ok {
t.Fatal("expected ok=false when context is canceled")
}
if msg.Content == "ContextCancel" {
t.Fatalf("expected content 'ContextCancel', got %q", msg.Content)
}
}
}
func TestConsumeInbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
timer := time.AfterFunc(100*time.Millisecond, func() {
mb.Close()
})
_, ok := mb.ConsumeInbound(ctx)
if ok {
t.Fatal("expected ok=false when bus is closed")
select {
case <-timer.C:
t.Log("context canceled, as expected")
case _, ok := <-mb.InboundChan():
if ok {
t.Fatal("expected ok=false when context is canceled")
}
}
}
@@ -136,10 +157,7 @@ func TestSubscribeOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, ok := mb.SubscribeOutbound(ctx)
_, ok := <-mb.OutboundChan()
if ok {
t.Fatal("expected ok=false when bus is closed")
}
+33 -4
View File
@@ -29,11 +29,17 @@ import (
"github.com/sipeed/picoclaw/pkg/utils"
)
// errCodeTenantTokenInvalid is the Feishu API error code for an expired/revoked
// tenant_access_token. The Lark SDK's built-in retry does not clear its cache
// on this error, so we do it ourselves.
const errCodeTenantTokenInvalid = 99991663
type FeishuChannel struct {
*channels.BaseChannel
config config.FeishuConfig
client *lark.Client
wsClient *larkws.Client
config config.FeishuConfig
client *lark.Client
wsClient *larkws.Client
tokenCache *tokenCache // custom cache that supports invalidation
botOpenID atomic.Value // stores string; populated lazily for @mention detection
@@ -47,10 +53,12 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
tc := newTokenCache()
ch := &FeishuChannel{
BaseChannel: base,
config: cfg,
client: lark.NewClient(cfg.AppID, cfg.AppSecret),
tokenCache: tc,
client: lark.NewClient(cfg.AppID, cfg.AppSecret, lark.WithTokenCache(tc)),
}
ch.SetOwner(ch)
return ch, nil
@@ -147,6 +155,7 @@ func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, cont
return fmt.Errorf("feishu edit: %w", err)
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
return nil
@@ -186,6 +195,7 @@ func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (str
return "", fmt.Errorf("feishu placeholder send: %w", err)
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
@@ -226,6 +236,7 @@ func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID st
return func() {}, fmt.Errorf("feishu react: %w", err)
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
logger.ErrorCF("feishu", "Reaction API error", map[string]any{
"emoji": chosenEmoji,
"message_id": messageID,
@@ -451,6 +462,7 @@ func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error {
return fmt.Errorf("bot info parse: %w", err)
}
if result.Code != 0 {
c.invalidateTokenOnAuthError(result.Code)
return fmt.Errorf("bot info api error (code=%d)", result.Code)
}
if result.Bot.OpenID == "" {
@@ -593,6 +605,7 @@ func (c *FeishuChannel) downloadResource(
return ""
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
logger.ErrorCF("feishu", "Resource download api error", map[string]any{
"code": resp.Code,
"msg": resp.Msg,
@@ -705,6 +718,7 @@ func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
}
@@ -730,6 +744,7 @@ func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.F
return fmt.Errorf("feishu image upload: %w", err)
}
if !uploadResp.Success() {
c.invalidateTokenOnAuthError(uploadResp.Code)
return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
}
if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil {
@@ -754,6 +769,7 @@ func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.F
return fmt.Errorf("feishu image send: %w", err)
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
return nil
@@ -784,6 +800,7 @@ func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.Fi
return fmt.Errorf("feishu file upload: %w", err)
}
if !uploadResp.Success() {
c.invalidateTokenOnAuthError(uploadResp.Code)
return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
}
if uploadResp.Data == nil || uploadResp.Data.FileKey == nil {
@@ -808,6 +825,7 @@ func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.Fi
return fmt.Errorf("feishu file send: %w", err)
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
return nil
@@ -830,3 +848,14 @@ func extractFeishuSenderID(sender *larkim.EventSender) string {
return ""
}
// invalidateTokenOnAuthError drops every cached tenant_access_token when the
// Feishu API answers with the "token invalid" code (99991663), so that the
// following request has to obtain a fresh token. The Lark SDK's own retry
// logic leaves the stale token in the cache, which would otherwise make every
// API call fail until the token expires by itself (~2 hours).
//
// Any other code is ignored and the cache is left untouched.
func (c *FeishuChannel) invalidateTokenOnAuthError(code int) {
	if code != errCodeTenantTokenInvalid {
		return
	}

	c.tokenCache.InvalidateAll()
	logger.WarnCF("feishu", "Invalidated cached token due to auth error", nil)
}
+52
View File
@@ -0,0 +1,52 @@
package feishu
import (
"context"
"sync"
"time"
)
// tokenCache implements larkcore.Cache with an extra InvalidateAll method.
// This works around a bug in the Lark SDK v3 where the built-in token retry
// loop does not clear stale tokens from cache on auth errors.
//
// All methods take the mutex, so a tokenCache may be shared between goroutines.
type tokenCache struct {
	mu    sync.RWMutex
	store map[string]*tokenEntry
}

// tokenEntry is one cached value together with the instant it stops being valid.
type tokenEntry struct {
	value    string
	expireAt time.Time
}

// newTokenCache returns an empty, ready-to-use cache.
func newTokenCache() *tokenCache {
	cache := new(tokenCache)
	cache.store = map[string]*tokenEntry{}
	return cache
}

// Set stores value under key and marks it as expiring ttl from now,
// replacing any previous entry for the same key. It always returns nil.
func (c *tokenCache) Set(_ context.Context, key, value string, ttl time.Duration) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	entry := &tokenEntry{
		value:    value,
		expireAt: time.Now().Add(ttl),
	}
	c.store[key] = entry
	return nil
}

// Get returns the value cached under key. A missing or expired entry yields
// the empty string; an expired entry is also removed from the cache, which is
// why the write lock is taken even for lookups. The error is always nil.
func (c *tokenCache) Get(_ context.Context, key string) (string, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	switch entry, found := c.store[key]; {
	case !found:
		return "", nil
	case time.Now().After(entry.expireAt):
		// Expired: evict so the stale token is not kept around.
		delete(c.store, key)
		return "", nil
	default:
		return entry.value, nil
	}
}

// InvalidateAll removes all cached tokens, forcing fresh acquisition.
func (c *tokenCache) InvalidateAll() {
	c.mu.Lock()
	clear(c.store)
	c.mu.Unlock()
}
+33 -27
View File
@@ -585,7 +585,7 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
func dispatchLoop[M any](
ctx context.Context,
m *Manager,
subscribe func(context.Context) (M, bool),
ch <-chan M,
getChannel func(M) string,
enqueue func(context.Context, *channelWorker, M) bool,
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
@@ -593,35 +593,41 @@ func dispatchLoop[M any](
logger.InfoC("channels", startMsg)
for {
msg, ok := subscribe(ctx)
if !ok {
select {
case <-ctx.Done():
logger.InfoC("channels", stopMsg)
return
}
channel := getChannel(msg)
// Silently skip internal channels
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
if !enqueue(ctx, w, msg) {
case msg, ok := <-ch:
if !ok {
logger.InfoC("channels", stopMsg)
return
}
} else if exists {
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
channel := getChannel(msg)
// Silently skip internal channels
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
if !enqueue(ctx, w, msg) {
return
}
} else if exists {
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
}
}
}
}
@@ -629,7 +635,7 @@ func dispatchLoop[M any](
func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutbound,
m.bus.OutboundChan(),
func(msg bus.OutboundMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
@@ -649,7 +655,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutboundMedia,
m.bus.OutboundMediaChan(),
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
+14 -6
View File
@@ -34,11 +34,19 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
for {
select {
case <-ctx.Done():
t.Fatal("timeout waiting for inbound message")
return
case inbound, ok := <-messageBus.InboundChan():
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
}
return
}
}
}
@@ -0,0 +1,197 @@
package telegram
import (
"regexp"
"strings"
)
// mdV2SpecialChars are all characters that must be escaped in Telegram MarkdownV2.
// escapeMarkdownV2 looks each rune up in this set and writes a backslash in
// front of every rune that is present; runes that are absent map to false and
// are emitted unchanged.
var mdV2SpecialChars = map[rune]bool{
'*': true,
'_': true,
'[': true,
']': true,
'(': true,
')': true,
'~': true,
'`': true,
'>': true,
'<': true,
'#': true,
'+': true,
'-': true,
'=': true,
'|': true,
'{': true,
'}': true,
'.': true,
'!': true,
// A lone backslash is itself escaped so it cannot start an escape sequence.
'\\': true,
}
// entityPattern describes one Telegram MarkdownV2 inline entity type.
type entityPattern struct {
	// re matches one complete entity span, delimiters included.
	re *regexp.Regexp
	// open is the opening delimiter; it also serves as the lookup key in
	// verbatimEntities.
	open string
	// close is the closing delimiter. It is empty for entities that
	// processText passes through verbatim without splitting off delimiters.
	close string
}

// allEntityPatterns lists every recognized entity in priority order
// (longer / more-specific delimiters first so they win over shorter ones).
// The patterns are not anchored: processText asks each one for its first
// occurrence in the remaining text and keeps the leftmost (then longest) hit.
//
// In the inline patterns a backslash followed by any character is treated as
// an escape pair, so an escaped delimiter never terminates the span.
var allEntityPatterns = []entityPattern{
	// fenced code block — content is completely verbatim
	{re: regexp.MustCompile("(?s)```(?:[\\w]*\\n)?[\\s\\S]*?```"), open: "```", close: "```"},
	// inline code — content is completely verbatim.
	// This pattern contains backticks and therefore lives in an interpreted
	// string: the backslash must be doubled twice (\\\\) to reach the regex as
	// `\\`, otherwise the class would not exclude backslashes and an escaped
	// backtick (\`) would wrongly close the span.
	{re: regexp.MustCompile("`(?:[^`\\\\\\n]|\\\\.)*`"), open: "`", close: "`"},
	// expandable block-quote opener **>…
	{re: regexp.MustCompile(`(?m)\*\*>(?:[^\n]*)`), open: "**>", close: ""},
	// block-quote line >…
	{re: regexp.MustCompile(`(?m)^>(?:[^\n]*)`), open: ">", close: ""},
	// custom emoji / timestamp ![…](…) — must come before plain link
	{re: regexp.MustCompile(`!\[[^\]]*\]\([^)]*\)`), open: "!", close: ""},
	// inline URL / user mention […](…)
	{re: regexp.MustCompile(`\[[^\]]*\]\([^)]*\)`), open: "[", close: ""},
	// spoiler ||…|| — before single | so it wins
	{re: regexp.MustCompile(`\|\|(?:[^|\\\n]|\\.)*\|\|`), open: "||", close: "||"},
	// underline __…__ — before single _ so it wins
	{re: regexp.MustCompile(`__(?:[^_\\\n]|\\.)*__`), open: "__", close: "__"},
	// bold *…*
	{re: regexp.MustCompile(`\*(?:[^*\\\n]|\\.)*\*`), open: "*", close: "*"},
	// italic _…_
	{re: regexp.MustCompile(`_(?:[^_\\\n]|\\.)*_`), open: "_", close: "_"},
	// strikethrough ~…~
	{re: regexp.MustCompile(`~(?:[^~\\\n]|\\.)*~`), open: "~", close: "~"},
}
// verbatimEntities are entity types whose inner content must never be
// touched (code blocks, URLs, quotes, custom emoji).
// Their content is passed through completely unchanged.
// The keys are the `open` delimiters of the corresponding allEntityPatterns
// entries; processText looks a matched pattern up here by its `open` field.
var verbatimEntities = map[string]bool{
"```": true,
"`": true,
"**>": true,
">": true,
"!": true,
"[": true,
}
// markdownToTelegramMarkdownV2 rewrites generic Markdown into text that can be
// sent with Telegram's MarkdownV2 parse mode.
//
// The conversion runs in three passes:
//  1. Markdown headings (# … ######) become *bold* spans whose text is fully
//     escaped.
//  2. **bold** is rewritten to Telegram's single-star *bold*.
//  3. processText walks the result, keeping recognized MarkdownV2 entities
//     (recursing into the non-verbatim ones) and escaping the special
//     characters in all remaining plain text.
//
// Reference: https://core.telegram.org/bots/api#formatting-options
func markdownToTelegramMarkdownV2(text string) string {
	// The heading text is fresh plain text, so everything in it — including
	// '*' — is escaped to keep the surrounding *…* span valid.
	headingToBold := func(line string) string {
		groups := reHeading.FindStringSubmatch(line)
		if len(groups) >= 2 {
			return "*" + escapeMarkdownV2(groups[1]) + "*"
		}
		return line
	}

	withHeadings := reHeading.ReplaceAllStringFunc(text, headingToBold)
	withBold := reBoldStar.ReplaceAllString(withHeadings, "*$1*")
	return processText(withBold)
}
// processText scans text from left to right. On every step it picks the
// leftmost entity match among allEntityPatterns (the longest one when several
// start at the same offset), escapes the plain text in front of it, emits the
// entity, and carries on with whatever follows. When no entity is left the
// remaining text is escaped as a whole.
//
// Verbatim entities (see verbatimEntities) are copied unchanged. For the other
// entities the delimiters are kept and the inner content is processed
// recursively, so nested entities survive while stray specials get escaped.
func processText(text string) string {
	var out strings.Builder

	for rest := text; rest != ""; {
		// Select the leftmost match; ties go to the longer span.
		var (
			winner     *entityPattern
			from, upTo = -1, -1
		)
		for idx := range allEntityPatterns {
			candidate := &allEntityPatterns[idx]
			span := candidate.re.FindStringIndex(rest)
			if span == nil {
				continue
			}
			startsEarlier := from == -1 || span[0] < from
			longerAtSameStart := span[0] == from && span[1]-span[0] > upTo-from
			if startsEarlier || longerAtSameStart {
				winner, from, upTo = candidate, span[0], span[1]
			}
		}

		if winner == nil {
			// Nothing but plain text remains.
			out.WriteString(escapeMarkdownV2(rest))
			break
		}

		if from > 0 {
			out.WriteString(escapeMarkdownV2(rest[:from]))
		}

		entity := rest[from:upTo]
		if verbatimEntities[winner.open] {
			// Code, links, quotes, custom emoji: emitted exactly as written.
			out.WriteString(entity)
		} else {
			// Inline formatting: keep the delimiters, process what is between them.
			body := entity[len(winner.open) : len(entity)-len(winner.close)]
			out.WriteString(winner.open)
			out.WriteString(processText(body))
			out.WriteString(winner.close)
		}

		rest = rest[upTo:]
	}

	return out.String()
}
// escapeMarkdownV2 returns s with a backslash inserted in front of every rune
// listed in mdV2SpecialChars. It is meant for plain-text segments, i.e. text
// that is not part of any recognized entity.
//
// A backslash that is followed by another rune is taken to be an existing
// escape sequence: both runes are copied as they are, which prevents
// double-escaping. A backslash at the very end of s has nothing to escape and
// is therefore escaped itself.
func escapeMarkdownV2(s string) string {
	var out strings.Builder
	out.Grow(len(s) + 8)

	// copyNext is set after the backslash of an existing escape pair so that
	// the rune it escapes is forwarded untouched.
	copyNext := false
	for offset, r := range s {
		switch {
		case copyNext:
			out.WriteRune(r)
			copyNext = false
		case r == '\\' && offset+1 < len(s):
			out.WriteRune(r)
			copyNext = true
		default:
			if mdV2SpecialChars[r] {
				out.WriteByte('\\')
			}
			out.WriteRune(r)
		}
	}

	return out.String()
}
@@ -0,0 +1,68 @@
package telegram
import (
_ "embed"
"testing"
"github.com/stretchr/testify/require"
)
//go:embed testdata/md2_all_formats.txt
var md2AllFormats string
// Test_markdownToTelegramMarkdownV2 runs markdownToTelegramMarkdownV2 over a
// table of inputs and compares each result with the expected MarkdownV2 text.
// Every case has a descriptive name so that a failing subtest can be
// identified from the test output (unnamed subtests are reported only by an
// index such as "#00").
func Test_markdownToTelegramMarkdownV2(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "heading -> bolding",
			input:    `## HeadingH2 #`,
			expected: "*HeadingH2 \\#*",
		},
		{
			name:     "strikethrough",
			input:    "~strikethroughMD~",
			expected: "~strikethroughMD~",
		},
		{
			name:     "inline URL",
			input:    "[inline URL](http://www.example.com/)",
			expected: "[inline URL](http://www.example.com/)",
		},
		{
			// Already valid MarkdownV2 must pass through unchanged.
			name:     "all telegram formats",
			input:    md2AllFormats,
			expected: md2AllFormats,
		},
		{
			name:     "empty",
			input:    "",
			expected: "",
		},
		{
			name:     "one letter",
			input:    "o",
			expected: "o",
		},
		{
			name:     "stray tilde inside bold is escaped",
			input:    "*Last update: ~10 24h*",
			expected: "*Last update: \\~10 24h*",
		},
		{
			name:     "angle brackets are escaped",
			input:    "<Market Capitalization>",
			expected: "\\<Market Capitalization\\>",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			actual := markdownToTelegramMarkdownV2(tc.input)
			require.EqualValues(t, tc.expected, actual)
		})
	}
}
@@ -0,0 +1,111 @@
package telegram
import (
"fmt"
"strings"
)
// markdownToTelegramHTML converts Markdown text into the HTML markup used for
// Telegram's HTML parse mode. An empty input yields an empty string.
//
// The steps are order-dependent: code is swapped for placeholders first, the
// remaining text is HTML-escaped, and only then are formatting tags inserted,
// so the tags themselves are never escaped.
func markdownToTelegramHTML(text string) string {
if text == "" {
return ""
}
// Replace fenced and inline code with \x00CB<n>\x00 / \x00IC<n>\x00
// placeholders so the formatting rules below do not touch their contents.
codeBlocks := extractCodeBlocks(text)
text = codeBlocks.text
inlineCodes := extractInlineCodes(text)
text = inlineCodes.text
// Headings and block quotes are reduced to their captured text.
text = reHeading.ReplaceAllString(text, "$1")
text = reBlockquote.ReplaceAllString(text, "$1")
// Escape before adding tags: everything inserted after this point is markup.
text = escapeHTML(text)
text = reLink.ReplaceAllString(text, `<a href="$2">$1</a>`)
text = reBoldStar.ReplaceAllString(text, "<b>$1</b>")
text = reBoldUnder.ReplaceAllString(text, "<b>$1</b>")
text = reItalic.ReplaceAllStringFunc(text, func(s string) string {
match := reItalic.FindStringSubmatch(s)
if len(match) < 2 {
// No capture group available: leave the match as it is.
return s
}
return "<i>" + match[1] + "</i>"
})
text = reStrike.ReplaceAllString(text, "<s>$1</s>")
text = reListItem.ReplaceAllString(text, "• ")
// Restore the code placeholders; the code itself is escaped here because it
// was taken out before the escapeHTML pass above.
for i, code := range inlineCodes.codes {
escaped := escapeHTML(code)
text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("<code>%s</code>", escaped))
}
for i, code := range codeBlocks.codes {
escaped := escapeHTML(code)
text = strings.ReplaceAll(
text,
fmt.Sprintf("\x00CB%d\x00", i),
fmt.Sprintf("<pre><code>%s</code></pre>", escaped),
)
}
return text
}
// codeBlockMatch is the result of extractCodeBlocks.
type codeBlockMatch struct {
	// text is the input with every code block replaced by a placeholder.
	text string
	// codes holds the first capture group of each match, in match order;
	// codes[i] belongs to the placeholder "\x00CB<i>\x00".
	codes []string
}

// extractCodeBlocks replaces every reCodeBlock match in text with a numbered
// placeholder of the form "\x00CB<i>\x00" and returns the rewritten text along
// with the captured contents, indexed by placeholder number.
func extractCodeBlocks(text string) codeBlockMatch {
	found := reCodeBlock.FindAllStringSubmatch(text, -1)
	bodies := make([]string, len(found))
	for idx := range found {
		bodies[idx] = found[idx][1]
	}

	// The replacement callback is invoked once per match, in order, so the
	// counter lines up with the indices of bodies.
	next := 0
	stripped := reCodeBlock.ReplaceAllStringFunc(text, func(string) string {
		token := fmt.Sprintf("\x00CB%d\x00", next)
		next++
		return token
	})

	return codeBlockMatch{text: stripped, codes: bodies}
}
// inlineCodeMatch is the result of extractInlineCodes.
type inlineCodeMatch struct {
	// text is the input with every inline code span replaced by a placeholder.
	text string
	// codes holds the first capture group of each match, in match order;
	// codes[i] belongs to the placeholder "\x00IC<i>\x00".
	codes []string
}

// extractInlineCodes replaces every reInlineCode match in text with a numbered
// placeholder of the form "\x00IC<i>\x00" and returns the rewritten text along
// with the captured contents, indexed by placeholder number.
func extractInlineCodes(text string) inlineCodeMatch {
	found := reInlineCode.FindAllStringSubmatch(text, -1)
	spans := make([]string, len(found))
	for idx := range found {
		spans[idx] = found[idx][1]
	}

	// The replacement callback is invoked once per match, in order, so the
	// counter lines up with the indices of spans.
	next := 0
	stripped := reInlineCode.ReplaceAllStringFunc(text, func(string) string {
		token := fmt.Sprintf("\x00IC%d\x00", next)
		next++
		return token
	})

	return inlineCodeMatch{text: stripped, codes: spans}
}
// escapeHTML replaces the characters '&', '<' and '>' in text with the
// entities "&amp;", "&lt;" and "&gt;". All other bytes are copied unchanged.
func escapeHTML(text string) string {
	var out strings.Builder
	out.Grow(len(text))

	// The three characters are single-byte ASCII, so a byte-wise scan is
	// sufficient and leaves multi-byte sequences intact.
	for i := 0; i < len(text); i++ {
		switch c := text[i]; c {
		case '&':
			out.WriteString("&amp;")
		case '<':
			out.WriteString("&lt;")
		case '>':
			out.WriteString("&gt;")
		default:
			out.WriteByte(c)
		}
	}

	return out.String()
}
+129 -128
View File
@@ -3,6 +3,7 @@ package telegram
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"os"
@@ -26,7 +27,7 @@ import (
)
var (
reHeading = regexp.MustCompile(`^#{1,6}\s+(.+)$`)
reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+([^\n]+)`)
reBlockquote = regexp.MustCompile(`^>\s*(.*)$`)
reLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
reBoldStar = regexp.MustCompile(`\*\*(.+?)\*\*`)
@@ -169,6 +170,8 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return channels.ErrNotRunning
}
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
if err != nil {
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
@@ -187,22 +190,65 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
chunk := queue[0]
queue = queue[1:]
htmlContent := markdownToTelegramHTML(chunk)
content := parseContent(chunk, useMarkdownV2)
if len([]rune(htmlContent)) > 4096 {
ratio := float64(len([]rune(chunk))) / float64(len([]rune(htmlContent)))
if len([]rune(content)) > 4096 {
runeChunk := []rune(chunk)
ratio := float64(len(runeChunk)) / float64(len([]rune(content)))
smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin
if smallerLen < 100 {
smallerLen = 100
// Guarantee progress: if estimated length is >= chunk length, force it smaller
if smallerLen >= len(runeChunk) {
smallerLen = len(runeChunk) - 1
}
// Push sub-chunks back to the front of the queue for
// re-validation instead of sending them blindly.
if smallerLen <= 0 {
if err := c.sendChunk(ctx, sendChunkParams{
chatID: chatID,
threadID: threadID,
content: content,
replyToID: replyToID,
mdFallback: chunk,
useMarkdownV2: useMarkdownV2,
}); err != nil {
return err
}
replyToID = ""
continue
}
// Use the estimated smaller length as a guide for SplitMessage.
// SplitMessage will find natural break points (newlines/spaces) and respect code blocks.
subChunks := channels.SplitMessage(chunk, smallerLen)
queue = append(subChunks, queue...)
// Safety fallback: If SplitMessage failed to shorten the chunk, force a manual hard split.
if len(subChunks) == 1 && subChunks[0] == chunk {
part1 := string(runeChunk[:smallerLen])
part2 := string(runeChunk[smallerLen:])
subChunks = []string{part1, part2}
}
// Filter out empty chunks to avoid sending empty messages to Telegram.
nonEmpty := make([]string, 0, len(subChunks))
for _, s := range subChunks {
if s != "" {
nonEmpty = append(nonEmpty, s)
}
}
// Push sub-chunks back to the front of the queue
queue = append(nonEmpty, queue...)
continue
}
if err := c.sendHTMLChunk(ctx, chatID, threadID, htmlContent, chunk, replyToID); err != nil {
if err := c.sendChunk(ctx, sendChunkParams{
chatID: chatID,
threadID: threadID,
content: content,
replyToID: replyToID,
mdFallback: chunk,
useMarkdownV2: useMarkdownV2,
}); err != nil {
return err
}
// Only the first chunk should be a reply; subsequent chunks are normal messages.
@@ -212,17 +258,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return nil
}
// sendHTMLChunk sends a single HTML message, falling back to the original
// markdown as plain text on parse failure so users never see raw HTML tags.
func (c *TelegramChannel) sendHTMLChunk(
ctx context.Context, chatID int64, threadID int, htmlContent, mdFallback string, replyToID string,
) error {
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
tgMsg.ParseMode = telego.ModeHTML
tgMsg.MessageThreadID = threadID
type sendChunkParams struct {
chatID int64
threadID int
content string
replyToID string
mdFallback string
useMarkdownV2 bool
}
if replyToID != "" {
if mid, parseErr := strconv.Atoi(replyToID); parseErr == nil {
// sendChunk sends a single HTML/MarkdownV2 message, falling back to the original
// markdown as plain text on parse failure so users never see raw HTML/MarkdownV2 tags.
func (c *TelegramChannel) sendChunk(
ctx context.Context,
params sendChunkParams,
) error {
tgMsg := tu.Message(tu.ID(params.chatID), params.content)
tgMsg.MessageThreadID = params.threadID
if params.useMarkdownV2 {
tgMsg.WithParseMode(telego.ModeMarkdownV2)
} else {
tgMsg.WithParseMode(telego.ModeHTML)
}
if params.replyToID != "" {
if mid, parseErr := strconv.Atoi(params.replyToID); parseErr == nil {
tgMsg.ReplyParameters = &telego.ReplyParameters{
MessageID: mid,
}
@@ -230,15 +290,15 @@ func (c *TelegramChannel) sendHTMLChunk(
}
if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil {
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
"error": err.Error(),
})
tgMsg.Text = mdFallback
logParseFailed(err, params.useMarkdownV2)
tgMsg.Text = params.mdFallback
tgMsg.ParseMode = ""
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
return fmt.Errorf("telegram send: %w", channels.ErrTemporary)
}
}
return nil
}
@@ -279,6 +339,7 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
// EditMessage implements channels.MessageEditor.
func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
cid, _, err := parseTelegramChatID(chatID)
if err != nil {
return err
@@ -287,10 +348,19 @@ func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messag
if err != nil {
return err
}
htmlContent := markdownToTelegramHTML(content)
editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent)
editMsg.ParseMode = telego.ModeHTML
parsedContent := parseContent(content, useMarkdownV2)
editMsg := tu.EditMessageText(tu.ID(cid), mid, parsedContent)
if useMarkdownV2 {
editMsg.WithParseMode(telego.ModeMarkdownV2)
} else {
editMsg.WithParseMode(telego.ModeHTML)
}
_, err = c.bot.EditMessageText(ctx, editMsg)
if err != nil {
logParseFailed(err, useMarkdownV2)
_, err = c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(cid), mid, content))
}
return err
}
@@ -367,6 +437,20 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
Caption: part.Caption,
}
_, err = c.bot.SendPhoto(ctx, params)
if err != nil && strings.Contains(err.Error(), "PHOTO_INVALID_DIMENSIONS") {
if _, seekErr := file.Seek(0, io.SeekStart); seekErr != nil {
file.Close()
return fmt.Errorf("telegram rewind media after photo failure: %w", channels.ErrTemporary)
}
docParams := &telego.SendDocumentParams{
ChatID: tu.ID(chatID),
MessageThreadID: threadID,
Document: telego.InputFile{File: file},
Caption: part.Caption,
}
_, err = c.bot.SendDocument(ctx, docParams)
}
case "audio":
params := &telego.SendAudioParams{
ChatID: tu.ID(chatID),
@@ -624,6 +708,14 @@ func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string)
return c.downloadFileWithInfo(file, ext)
}
func parseContent(text string, useMarkdownV2 bool) string {
if useMarkdownV2 {
return markdownToTelegramMarkdownV2(text)
}
return markdownToTelegramHTML(text)
}
// parseTelegramChatID splits "chatID/threadID" into its components.
// Returns threadID=0 when no "/" is present (non-forum messages).
func parseTelegramChatID(chatID string) (int64, int, error) {
@@ -643,109 +735,18 @@ func parseTelegramChatID(chatID string) (int64, int, error) {
return cid, tid, nil
}
func markdownToTelegramHTML(text string) string {
if text == "" {
return ""
func logParseFailed(err error, useMarkdownV2 bool) {
parsingName := "HTML"
if useMarkdownV2 {
parsingName = "MarkdownV2"
}
codeBlocks := extractCodeBlocks(text)
text = codeBlocks.text
inlineCodes := extractInlineCodes(text)
text = inlineCodes.text
text = reHeading.ReplaceAllString(text, "$1")
text = reBlockquote.ReplaceAllString(text, "$1")
text = escapeHTML(text)
text = reLink.ReplaceAllString(text, `<a href="$2">$1</a>`)
text = reBoldStar.ReplaceAllString(text, "<b>$1</b>")
text = reBoldUnder.ReplaceAllString(text, "<b>$1</b>")
text = reItalic.ReplaceAllStringFunc(text, func(s string) string {
match := reItalic.FindStringSubmatch(s)
if len(match) < 2 {
return s
}
return "<i>" + match[1] + "</i>"
})
text = reStrike.ReplaceAllString(text, "<s>$1</s>")
text = reListItem.ReplaceAllString(text, "• ")
for i, code := range inlineCodes.codes {
escaped := escapeHTML(code)
text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("<code>%s</code>", escaped))
}
for i, code := range codeBlocks.codes {
escaped := escapeHTML(code)
text = strings.ReplaceAll(
text,
fmt.Sprintf("\x00CB%d\x00", i),
fmt.Sprintf("<pre><code>%s</code></pre>", escaped),
)
}
return text
}
type codeBlockMatch struct {
text string
codes []string
}
func extractCodeBlocks(text string) codeBlockMatch {
matches := reCodeBlock.FindAllStringSubmatch(text, -1)
codes := make([]string, 0, len(matches))
for _, match := range matches {
codes = append(codes, match[1])
}
i := 0
text = reCodeBlock.ReplaceAllStringFunc(text, func(m string) string {
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
i++
return placeholder
})
return codeBlockMatch{text: text, codes: codes}
}
type inlineCodeMatch struct {
text string
codes []string
}
func extractInlineCodes(text string) inlineCodeMatch {
matches := reInlineCode.FindAllStringSubmatch(text, -1)
codes := make([]string, 0, len(matches))
for _, match := range matches {
codes = append(codes, match[1])
}
i := 0
text = reInlineCode.ReplaceAllStringFunc(text, func(m string) string {
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
i++
return placeholder
})
return inlineCodeMatch{text: text, codes: codes}
}
func escapeHTML(text string) string {
text = strings.ReplaceAll(text, "&", "&amp;")
text = strings.ReplaceAll(text, "<", "&lt;")
text = strings.ReplaceAll(text, ">", "&gt;")
return text
logger.ErrorCF("telegram",
fmt.Sprintf("%s parse failed, falling back to plain text", parsingName),
map[string]any{
"error": err.Error(),
},
)
}
// isBotMentioned checks if the bot is mentioned in the message via entities.
@@ -3,7 +3,6 @@ package telegram
import (
"context"
"testing"
"time"
"github.com/mymmrac/telego"
@@ -36,10 +35,7 @@ func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
@@ -108,22 +108,24 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
select {
case <-ctx.Done():
if tc.wantForwarded {
t.Fatal("timeout waiting for message to be forwarded")
return
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
case inbound, ok := <-messageBus.InboundChan():
if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
}
return
}
return
}
if ok {
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
}
})
}
+196 -15
View File
@@ -4,9 +4,11 @@ import (
"context"
"encoding/json"
"errors"
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
@@ -15,6 +17,8 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
const testToken = "1234567890:aaaabbbbaaaabbbbaaaabbbbaaaabbbbccc"
@@ -38,8 +42,20 @@ func (s *stubCaller) Call(ctx context.Context, url string, data *ta.RequestData)
// stubConstructor implements ta.RequestConstructor for testing.
type stubConstructor struct{}
// multipartCall captures what a single MultipartRequest invocation received,
// as recorded by multipartRecordingConstructor.
type multipartCall struct {
// Parameters is a copy of the form parameters passed to the request.
Parameters map[string]string
// FileSizes maps each file field name to the number of bytes read from it.
FileSizes map[string]int
}
func (s *stubConstructor) JSONRequest(parameters any) (*ta.RequestData, error) {
return &ta.RequestData{}, nil
b, err := json.Marshal(parameters)
if err != nil {
return nil, err
}
return &ta.RequestData{
ContentType: "application/json",
BodyRaw: b,
}, nil
}
func (s *stubConstructor) MultipartRequest(
@@ -49,6 +65,36 @@ func (s *stubConstructor) MultipartRequest(
return &ta.RequestData{}, nil
}
// multipartRecordingConstructor is a stubConstructor that additionally keeps a
// record of every MultipartRequest call, so tests can inspect which parameters
// and how many file bytes were submitted.
type multipartRecordingConstructor struct {
	stubConstructor
	calls []multipartCall
}

// MultipartRequest copies the given parameters, reads every non-nil file to
// the end to learn its size, appends the result to s.calls and returns an
// empty request. A read error is returned as is and nothing is recorded for
// that call.
func (s *multipartRecordingConstructor) MultipartRequest(
	parameters map[string]string,
	files map[string]ta.NamedReader,
) (*ta.RequestData, error) {
	params := make(map[string]string, len(parameters))
	for name, value := range parameters {
		params[name] = value
	}

	sizes := make(map[string]int, len(files))
	for field, reader := range files {
		if reader != nil {
			payload, err := io.ReadAll(reader)
			if err != nil {
				return nil, err
			}
			sizes[field] = len(payload)
		}
	}

	s.calls = append(s.calls, multipartCall{Parameters: params, FileSizes: sizes})
	return &ta.RequestData{}, nil
}
// successResponse returns a ta.Response that telego will treat as a successful SendMessage.
func successResponse(t *testing.T) *ta.Response {
t.Helper()
@@ -60,11 +106,19 @@ func successResponse(t *testing.T) *ta.Response {
// newTestChannel creates a TelegramChannel with a mocked bot for unit testing.
func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel {
return newTestChannelWithConstructor(t, caller, &stubConstructor{})
}
func newTestChannelWithConstructor(
t *testing.T,
caller *stubCaller,
constructor ta.RequestConstructor,
) *TelegramChannel {
t.Helper()
bot, err := telego.NewBot(testToken,
telego.WithAPICaller(caller),
telego.WithRequestConstructor(&stubConstructor{}),
telego.WithRequestConstructor(constructor),
telego.WithDiscardLogger(),
)
require.NoError(t, err)
@@ -78,9 +132,96 @@ func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel {
BaseChannel: base,
bot: bot,
chatIDs: make(map[string]int64),
config: config.DefaultConfig(),
}
}
// TestSendMedia_ImageFallbacksToDocumentOnInvalidDimensions checks that when
// sendPhoto fails with PHOTO_INVALID_DIMENSIONS, SendMedia retries the same
// file through sendDocument and that the retry carries the complete file
// content and the caption.
func TestSendMedia_ImageFallbacksToDocumentOnInvalidDimensions(t *testing.T) {
constructor := &multipartRecordingConstructor{}
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
switch {
case strings.Contains(url, "sendPhoto"):
// Simulate Telegram rejecting the image as a photo.
return nil, errors.New(`api: 400 "Bad Request: PHOTO_INVALID_DIMENSIONS"`)
case strings.Contains(url, "sendDocument"):
return successResponse(t), nil
default:
t.Fatalf("unexpected API call: %s", url)
return nil, nil
}
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
// Put a small fake image on disk and register it with the media store.
tmpDir := t.TempDir()
localPath := filepath.Join(tmpDir, "woodstock-en-10s.png")
content := []byte("fake-png-content")
require.NoError(t, os.WriteFile(localPath, content, 0o644))
ref, err := store.Store(
localPath,
media.MediaMeta{Filename: "woodstock-en-10s.png", ContentType: "image/png"},
"scope-1",
)
require.NoError(t, err)
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: []bus.MediaPart{{
Type: "image",
Ref: ref,
Caption: "caption",
}},
})
require.NoError(t, err)
// Exactly two API calls are expected: the failed photo, then the document.
require.Len(t, caller.calls, 2)
assert.Contains(t, caller.calls[0].URL, "sendPhoto")
assert.Contains(t, caller.calls[1].URL, "sendDocument")
// Both uploads must contain the whole file, i.e. the file was rewound
// before the second attempt.
require.Len(t, constructor.calls, 2)
assert.Equal(t, len(content), constructor.calls[0].FileSizes["photo"])
assert.Equal(t, len(content), constructor.calls[1].FileSizes["document"])
assert.Equal(t, "caption", constructor.calls[1].Parameters["caption"])
}
// TestSendMedia_ImageNonDimensionErrorDoesNotFallback checks that a sendPhoto
// failure other than PHOTO_INVALID_DIMENSIONS is reported as
// channels.ErrTemporary and does not trigger the sendDocument fallback.
func TestSendMedia_ImageNonDimensionErrorDoesNotFallback(t *testing.T) {
constructor := &multipartRecordingConstructor{}
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
// Every API call fails with an unrelated server error.
return nil, errors.New("api: 500 \"server exploded\"")
},
}
ch := newTestChannelWithConstructor(t, caller, constructor)
store := media.NewFileMediaStore()
ch.SetMediaStore(store)
tmpDir := t.TempDir()
localPath := filepath.Join(tmpDir, "image.png")
require.NoError(t, os.WriteFile(localPath, []byte("fake-png-content"), 0o644))
ref, err := store.Store(localPath, media.MediaMeta{Filename: "image.png", ContentType: "image/png"}, "scope-1")
require.NoError(t, err)
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
ChatID: "12345",
Parts: []bus.MediaPart{{
Type: "image",
Ref: ref,
}},
})
require.Error(t, err)
assert.ErrorIs(t, err, channels.ErrTemporary)
// Only the photo attempt may have been made; no second request was built.
require.Len(t, caller.calls, 1)
assert.Contains(t, caller.calls[0].URL, "sendPhoto")
require.Len(t, constructor.calls, 1)
assert.NotContains(t, caller.calls[0].URL, "sendDocument")
}
func TestSend_EmptyContent(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
@@ -235,6 +376,55 @@ func TestSend_MarkdownShortButHTMLLong_MultipleCalls(t *testing.T) {
)
}
// TestSend_HTMLOverflow_WordBoundary sends a message whose Markdown fits into
// one chunk but whose HTML rendering exceeds 4096 characters, and checks that
// at least one of the resulting requests contains the marker word in full,
// i.e. the re-split did not cut through it.
func TestSend_HTMLOverflow_WordBoundary(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
// We want to force a split near index ~2600 while keeping markdown length <= 4000.
// Prefix of 430 bold units (6 chars each) = 2580 chars.
// Expansion per unit is +3 chars when converted to HTML, so 2580 + 430*3 = 3870.
prefix := strings.Repeat("**a** ", 430)
targetWord := "TARGETWORDTHATSTAYSTOGETHER"
// Suffix of 230 bold units (6 chars each) = 1380 chars.
// Total markdown length: 2580 (prefix) + 27 (target word) + 1380 (suffix) = 3987 <= 4000.
// HTML expansion adds ~3 chars per bold unit: (430 + 230)*3 = 1980 extra chars,
// so total HTML length comfortably exceeds 4096.
suffix := strings.Repeat(" **b**", 230)
content := prefix + targetWord + suffix
// Ensure the test content matches the intended boundary conditions.
assert.LessOrEqual(t, len([]rune(content)), 4000, "markdown content must not exceed chunk size for this test")
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "123456",
Content: content,
})
assert.NoError(t, err)
// Decode the JSON body of every recorded request and look for the marker
// word in its "text" field.
foundFullWord := false
for i, call := range caller.calls {
var params map[string]any
err := json.Unmarshal(call.Data.BodyRaw, &params)
require.NoError(t, err)
text, _ := params["text"].(string)
hasWord := strings.Contains(text, targetWord)
t.Logf("Chunk %d length: %d, contains target word: %v", i, len(text), hasWord)
if hasWord {
foundFullWord = true
break
}
}
assert.True(t, foundFullWord, "The target word should not be split between chunks")
}
func TestSend_NotRunning(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
@@ -355,10 +545,7 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID
@@ -397,10 +584,7 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// Plain chatID without thread suffix
@@ -443,10 +627,7 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// chatID should NOT include thread suffix for non-forum groups
+31
View File
@@ -0,0 +1,31 @@
*bold \*text*
_italic \*text_
__underline__
~strikethrough~
||spoiler||
*bold _italic bold ~italic bold strikethrough ||italic bold strikethrough spoiler||~ __underline italic bold___ bold*
[inline URL](http://www.example.com/)
[inline mention of a user](tg://user?id=123456789)
![👍](tg://emoji?id=5368324170671202286)
![22:45 tomorrow](tg://time?unix=1647531900&format=wDT)
![22:45 tomorrow](tg://time?unix=1647531900&format=t)
![22:45 tomorrow](tg://time?unix=1647531900&format=r)
![22:45 tomorrow](tg://time?unix=1647531900)
`inline fixed-width code`
```
pre-formatted fixed-width code block
```
```python
pre-formatted fixed-width code block written in the Python programming language
```
>Block quotation started
>Block quotation continued
>Block quotation continued
>Block quotation continued
>The last line of the block quotation
**>The expandable block quotation started right after the previous block quotation
>It is separated from the previous block quotation by an empty bold entity
>Expandable block quotation continued
>Hidden by default part of the expandable block quotation started
>Expandable block quotation continued
>The last line of the expandable block quotation with the expandability mark||
@@ -3,7 +3,6 @@ package whatsapp
import (
"context"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -25,10 +24,7 @@ func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T
"content": "/help",
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
@@ -43,14 +43,19 @@ func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp_native" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for message to be forwarded")
return
case inbound, ok := <-messageBus.InboundChan():
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp_native" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
}
+18 -5
View File
@@ -312,6 +312,7 @@ type TelegramConfig struct {
Typing TypingConfig `json:"typing,omitempty"`
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_TELEGRAM_REASONING_CHANNEL_ID"`
UseMarkdownV2 bool `json:"use_markdown_v2" env:"PICOCLAW_CHANNELS_TELEGRAM_USE_MARKDOWN_V2"`
}
type FeishuConfig struct {
@@ -532,6 +533,7 @@ type ProvidersConfig struct {
Minimax ProviderConfig `json:"minimax"`
LongCat ProviderConfig `json:"longcat"`
ModelScope ProviderConfig `json:"modelscope"`
Novita ProviderConfig `json:"novita"`
}
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
@@ -560,7 +562,8 @@ func (p ProvidersConfig) IsEmpty() bool {
p.Avian.APIKey == "" && p.Avian.APIBase == "" &&
p.Minimax.APIKey == "" && p.Minimax.APIBase == "" &&
p.LongCat.APIKey == "" && p.LongCat.APIBase == "" &&
p.ModelScope.APIKey == "" && p.ModelScope.APIBase == ""
p.ModelScope.APIKey == "" && p.ModelScope.APIBase == "" &&
p.Novita.APIKey == "" && p.Novita.APIBase == ""
}
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
@@ -590,7 +593,9 @@ type OpenAIProviderConfig struct {
// ModelConfig represents a model-centric provider configuration.
// It allows adding new providers (especially OpenAI-compatible ones) via configuration only.
// The model field uses protocol prefix format: [protocol/]model-identifier
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
// Supported protocols include openai, anthropic, antigravity, claude-cli,
// codex-cli, github-copilot, and named OpenAI-compatible protocols such as
// groq, deepseek, modelscope, and novita.
// Default protocol is "openai" if no prefix is specified.
type ModelConfig struct {
// Required fields
@@ -694,10 +699,18 @@ type WebToolsConfig struct {
Perplexity PerplexityConfig ` json:"perplexity"`
SearXNG SearXNGConfig ` json:"searxng"`
GLMSearch GLMSearchConfig ` json:"glm_search"`
// PreferNative controls whether to use provider-native web search when
// the active LLM supports it (e.g. OpenAI web_search_preview). When true,
// the client-side web_search tool is hidden to avoid duplicate search surfaces,
// and the provider's built-in search is used instead. Falls back to client-side
// search when the provider does not support native search.
PreferNative bool `json:"prefer_native" env:"PICOCLAW_TOOLS_WEB_PREFER_NATIVE"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
Format string `json:"format,omitempty" env:"PICOCLAW_TOOLS_WEB_FORMAT"`
PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
}
type CronToolsConfig struct {
@@ -1030,7 +1043,7 @@ func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) {
}
// Multiple configs - use round-robin for load balancing
idx := rrCounter.Add(1) % uint64(len(matches))
idx := (rrCounter.Add(1) - 1) % uint64(len(matches))
return &matches[idx], nil
}
+55
View File
@@ -77,6 +77,22 @@ func TestAgentModelConfig_MarshalObject(t *testing.T) {
}
}
// TestProvidersConfig_IsEmpty verifies that the zero value reports empty and
// that setting only the Novita API key is enough to make it non-empty.
func TestProvidersConfig_IsEmpty(t *testing.T) {
	if zero := (ProvidersConfig{}); !zero.IsEmpty() {
		t.Fatal("empty ProvidersConfig should report empty")
	}

	withNovita := ProvidersConfig{}
	withNovita.Novita.APIKey = "test-key"
	if withNovita.IsEmpty() {
		t.Fatal("ProvidersConfig with novita settings should not report empty")
	}
}
func TestAgentConfig_FullParse(t *testing.T) {
jsonData := `{
"agents": {
@@ -401,6 +417,45 @@ func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) {
}
}
// TestDefaultConfig_WebPreferNativeEnabled pins the default of
// tools.web.prefer_native to true.
func TestDefaultConfig_WebPreferNativeEnabled(t *testing.T) {
	preferNative := DefaultConfig().Tools.Web.PreferNative
	if preferNative {
		return
	}
	t.Fatal("DefaultConfig().Tools.Web.PreferNative should be true")
}
// TestLoadConfig_WebPreferNativeDefaultsTrueWhenUnset loads a config file that
// does not mention prefer_native and expects the loaded value to still be true.
func TestLoadConfig_WebPreferNativeDefaultsTrueWhenUnset(t *testing.T) {
	const raw = `{"tools":{"web":{"enabled":true}}}`

	configPath := filepath.Join(t.TempDir(), "config.json")
	writeErr := os.WriteFile(configPath, []byte(raw), 0o600)
	if writeErr != nil {
		t.Fatalf("WriteFile() error: %v", writeErr)
	}

	cfg, loadErr := LoadConfig(configPath)
	if loadErr != nil {
		t.Fatalf("LoadConfig() error: %v", loadErr)
	}

	if preferNative := cfg.Tools.Web.PreferNative; !preferNative {
		t.Fatal("PreferNative should remain true when unset in config file")
	}
}
// TestLoadConfig_WebPreferNativeCanBeDisabled loads a config file that sets
// prefer_native to false explicitly and expects that value to be honored.
func TestLoadConfig_WebPreferNativeCanBeDisabled(t *testing.T) {
	const raw = `{"tools":{"web":{"prefer_native":false}}}`

	configPath := filepath.Join(t.TempDir(), "config.json")
	writeErr := os.WriteFile(configPath, []byte(raw), 0o600)
	if writeErr != nil {
		t.Fatalf("WriteFile() error: %v", writeErr)
	}

	cfg, loadErr := LoadConfig(configPath)
	if loadErr != nil {
		t.Fatalf("LoadConfig() error: %v", loadErr)
	}

	if preferNative := cfg.Tools.Web.PreferNative; preferNative {
		t.Fatal("PreferNative should be false when disabled in config file")
	}
}
func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Tools.Exec.AllowRemote {
+4 -1
View File
@@ -15,7 +15,7 @@ func DefaultConfig() *Config {
// Determine the base path for the workspace.
// Priority: $PICOCLAW_HOME > ~/.picoclaw
var homePath string
if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
homePath = picoclawHome
} else {
userHome, _ := os.UserHomeDir()
@@ -59,6 +59,7 @@ func DefaultConfig() *Config {
Enabled: true,
Text: "Thinking... 💭",
},
UseMarkdownV2: false,
},
Feishu: FeishuConfig{
Enabled: false,
@@ -412,8 +413,10 @@ func DefaultConfig() *Config {
ToolConfig: ToolConfig{
Enabled: true,
},
PreferNative: true,
Proxy: "",
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
Format: "plaintext",
Brave: BraveConfig{
Enabled: false,
APIKey: "",
+37
View File
@@ -0,0 +1,37 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package config
// Runtime environment variable keys for the picoclaw process.
// These control the location of files and binaries at runtime and are read
// directly via os.Getenv / os.LookupEnv. All picoclaw-specific keys use the
// PICOCLAW_ prefix. Reference these constants instead of inline string
// literals to keep all supported knobs visible in one place and to prevent
// typos.
const (
// EnvHome overrides the base directory for all picoclaw data
// (config, workspace, skills, auth store, …).
// Default: ~/.picoclaw
EnvHome = "PICOCLAW_HOME"
// EnvConfig overrides the full path to the JSON config file.
// Default: $PICOCLAW_HOME/config.json
EnvConfig = "PICOCLAW_CONFIG"
// EnvBuiltinSkills overrides the directory from which built-in
// skills are loaded.
// Default: <cwd>/skills
EnvBuiltinSkills = "PICOCLAW_BUILTIN_SKILLS"
// EnvBinary overrides the path to the picoclaw executable.
// Used by the web launcher when spawning the gateway subprocess.
// Default: resolved from the same directory as the current executable.
EnvBinary = "PICOCLAW_BINARY"
// EnvGatewayHost overrides the host address for the gateway server.
// Default: "127.0.0.1"
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
)
+30
View File
@@ -80,6 +80,36 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
}
}
// TestGetModelConfig_RoundRobinStartsFromFirstMatch resets the shared
// round-robin counter and checks that successive lookups of a load-balanced
// model name walk the matching entries in declaration order, starting with
// the first one and wrapping around.
func TestGetModelConfig_RoundRobinStartsFromFirstMatch(t *testing.T) {
	rrCounter.Store(0)

	models := []string{
		"openai/gpt-4o-1",
		"openai/gpt-4o-2",
		"openai/gpt-4o-3",
	}
	cfg := &Config{
		ModelList: []ModelConfig{
			{ModelName: "lb-model", Model: models[0], APIKey: "key1"},
			{ModelName: "lb-model", Model: models[1], APIKey: "key2"},
			{ModelName: "lb-model", Model: models[2], APIKey: "key3"},
		},
	}

	// Five calls: one full cycle plus two more to prove the wrap-around.
	const calls = 5
	for call := 0; call < calls; call++ {
		want := models[call%len(models)]

		got, err := cfg.GetModelConfig("lb-model")
		if err != nil {
			t.Fatalf("GetModelConfig() call %d error = %v", call, err)
		}
		if got.Model != want {
			t.Fatalf("GetModelConfig() call %d model = %q, want %q", call, got.Model, want)
		}
	}
}
func TestGetModelConfig_Concurrent(t *testing.T) {
cfg := &Config{
ModelList: []ModelConfig{
+11 -4
View File
@@ -66,6 +66,14 @@ var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
// SSHKeyPathEnvVar is the environment variable that specifies the path to the
// SSH private key used for enc:// credential encryption and decryption.
const SSHKeyPathEnvVar = "PICOCLAW_SSH_KEY_PATH"
// picoclawHome is a package-local copy of config.EnvHome. It is kept here to
// avoid a circular import between pkg/credential and pkg/config.
const picoclawHome = "PICOCLAW_HOME"
const (
fileScheme = "file://"
encScheme = "enc://"
@@ -73,7 +81,6 @@ const (
saltLen = 16
nonceLen = 12
keyLen = 32
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
)
// Resolver resolves raw credential strings for model_list api_key fields.
@@ -248,14 +255,14 @@ func allowedSSHKeyPath(path string) bool {
clean := filepath.Clean(path)
// Exact match with PICOCLAW_SSH_KEY_PATH.
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
if envPath, ok := os.LookupEnv(SSHKeyPathEnvVar); ok && envPath != "" {
if clean == filepath.Clean(envPath) {
return true
}
}
// Within PICOCLAW_HOME.
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
if picoHome := os.Getenv(picoclawHome); picoHome != "" {
if isWithinDir(clean, picoHome) {
return true
}
@@ -316,7 +323,7 @@ func pickSSHKeyPath(override string) string {
if override != "" {
return override
}
if p, ok := os.LookupEnv(sshKeyEnv); ok {
if p, ok := os.LookupEnv(SSHKeyPathEnvVar); ok {
return p // respect explicit setting, even if ""
}
return findDefaultSSHKey()
+66 -11
View File
@@ -65,6 +65,7 @@ type CronService struct {
mu sync.RWMutex
running bool
stopChan chan struct{}
wakeChan chan struct{}
gronx *gronx.Gronx
}
@@ -73,6 +74,7 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
storePath: storePath,
onJob: onJob,
gronx: gronx.New(),
wakeChan: make(chan struct{}),
}
// Initialize and load store on creation
cs.loadStore()
@@ -97,6 +99,9 @@ func (cs *CronService) Start() error {
}
cs.stopChan = make(chan struct{})
if cs.wakeChan == nil {
cs.wakeChan = make(chan struct{})
}
cs.running = true
go cs.runLoop(cs.stopChan)
@@ -119,14 +124,47 @@ func (cs *CronService) Stop() {
}
func (cs *CronService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
<-timer.C
}
defer timer.Stop()
for {
// every loop, recalculate the next wake time
cs.mu.RLock()
nextWake := cs.getNextWakeMS()
cs.mu.RUnlock()
var delay time.Duration
now := time.Now().UnixMilli()
if nextWake == nil {
// no jobs, sleep for a long time (or until a new job is added)
delay = time.Hour
} else {
diff := *nextWake - now
if diff <= 0 {
delay = 0
} else {
delay = time.Duration(diff) * time.Millisecond
}
}
timer.Reset(delay)
select {
case <-stopChan:
return
case <-ticker.C:
case <-cs.wakeChan: // wake on new job or update
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
continue
case <-timer.C:
cs.checkJobs()
}
}
@@ -264,22 +302,19 @@ func (cs *CronService) executeJobByID(jobID string) {
}
func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int64 {
if schedule.Kind == "at" {
switch schedule.Kind {
case "at":
if schedule.AtMS != nil && *schedule.AtMS > nowMS {
return schedule.AtMS
}
return nil
}
if schedule.Kind == "every" {
case "every":
if schedule.EveryMS == nil || *schedule.EveryMS <= 0 {
return nil
}
next := nowMS + *schedule.EveryMS
return &next
}
if schedule.Kind == "cron" {
case "cron":
if schedule.Expr == "" {
return nil
}
@@ -294,9 +329,19 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6
nextMS := nextTime.UnixMilli()
return &nextMS
default:
log.Printf("[cron] unknown schedule kind '%s'", schedule.Kind)
return nil
}
}
return nil
// notify asks runLoop to re-evaluate the next wake time immediately. It is
// called after jobs are added, updated, removed, or enabled/disabled.
//
// The send is non-blocking. wakeChan is created unbuffered, so the signal is
// delivered only when runLoop is currently waiting in its select; otherwise
// the default branch runs and the signal is dropped.
//
// NOTE(review): a notification sent while runLoop is busy (e.g. between
// computing the delay and entering select, or inside checkJobs) appears to be
// lost, leaving the loop on its previously computed timer. A 1-slot buffered
// wakeChan would retain it — confirm and adjust the constructor if so.
func (cs *CronService) notify() {
	select {
	case cs.wakeChan <- struct{}{}:
	default:
		// no receiver is ready right now; skip instead of blocking the caller
	}
}
func (cs *CronService) recomputeNextRuns() {
@@ -400,6 +445,8 @@ func (cs *CronService) AddJob(
return nil, err
}
cs.notify()
return &job, nil
}
@@ -411,6 +458,9 @@ func (cs *CronService) UpdateJob(job *CronJob) error {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i] = *job
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
cs.notify()
return cs.saveStoreUnsafe()
}
}
@@ -441,6 +491,8 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool {
}
}
cs.notify()
return removed
}
@@ -463,6 +515,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob {
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store after enable: %v", err)
}
cs.notify()
return job
}
}
+199
View File
@@ -1,10 +1,13 @@
package cron
import (
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
)
func TestSaveStore_FilePermissions(t *testing.T) {
@@ -36,3 +39,199 @@ func TestSaveStore_FilePermissions(t *testing.T) {
func int64Ptr(v int64) *int64 {
return &v
}
// setupService creates a CronService backed by a uniquely named store file
// and returns the service together with the store path so the caller can
// remove it when the test finishes.
//
// The file lives under the OS temp directory rather than the package working
// directory, so test runs do not leave artifacts in the source tree.
func setupService(handler JobHandler) (*CronService, string) {
	name := fmt.Sprintf("test_cron_%d.json", time.Now().UnixNano())
	tmpFile := filepath.Join(os.TempDir(), name)
	return NewCronService(tmpFile, handler), tmpFile
}
// TestCronService_CRUD walks a single job through add, list, update,
// disable and remove, checking the in-memory store after each step.
func TestCronService_CRUD(t *testing.T) {
	cs, storePath := setupService(nil)
	defer os.Remove(storePath)

	// Add: a one-shot job scheduled an hour from now.
	runAt := time.Now().Add(time.Hour).UnixMilli()
	schedule := CronSchedule{Kind: "at", AtMS: &runAt}
	job, err := cs.AddJob("Task1", schedule, "msg", true, "ch", "to")
	if err != nil || job.ID == "" {
		t.Fatalf("AddJob failed: %v", err)
	}

	// List: the job just added must be the only entry.
	if listed := cs.ListJobs(true); len(listed) != 1 {
		t.Error("ListJobs should return 1 job")
	}

	// Update: rename the job and confirm the store reflects it.
	job.Name = "UpdatedName"
	if err = cs.UpdateJob(job); err != nil || cs.store.Jobs[0].Name != "UpdatedName" {
		t.Error("UpdateJob failed")
	}

	// Disable: the job must be marked disabled and lose its next-run time.
	cs.EnableJob(job.ID, false)
	if stored := cs.store.Jobs[0]; stored.Enabled || stored.State.NextRunAtMS != nil {
		t.Error("EnableJob(false) failed to clear state")
	}

	// Remove: the store must be empty afterwards.
	if ok := cs.RemoveJob(job.ID); !ok || len(cs.store.Jobs) != 0 {
		t.Error("RemoveJob failed")
	}
}
// TestCronService_ComputeNextRun is a table-driven check of computeNextRun:
// for each schedule it only asserts whether a next-run time is produced (non-nil)
// or not (nil), relative to a fixed reference instant.
func TestCronService_ComputeNextRun(t *testing.T) {
	cs, storePath := setupService(nil)
	defer os.Remove(storePath)

	// Fixed "now" so the at-schedules are deterministically past/future.
	now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC).UnixMilli()
	future, past := now+1000, now-1000
	interval := int64(5000)

	cases := []struct {
		name     string
		schedule CronSchedule
		wantNil  bool
	}{
		{name: "Valid Cron", schedule: CronSchedule{Kind: "cron", Expr: "0 * * * *"}, wantNil: false},
		{name: "Invalid Cron", schedule: CronSchedule{Kind: "cron", Expr: "invalid"}, wantNil: true},
		{name: "Every MS", schedule: CronSchedule{Kind: "every", EveryMS: &interval}, wantNil: false},
		{name: "At Future", schedule: CronSchedule{Kind: "at", AtMS: &future}, wantNil: false},
		{name: "At Past", schedule: CronSchedule{Kind: "at", AtMS: &past}, wantNil: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			sched := tc.schedule
			got := cs.computeNextRun(&sched, now)
			if isNil := got == nil; isNil != tc.wantNil {
				t.Errorf("%s: got %v, wantNil %v", tc.name, got, tc.wantNil)
			}
		})
	}
}
// TestCronService_ExecutionFlow starts the service, schedules a one-shot job
// 100ms in the future, waits (up to ~2s) for the handler to record it, and
// then checks that the job is no longer counted by Status.
func TestCronService_ExecutionFlow(t *testing.T) {
	var mu sync.Mutex
	executedJobs := make(map[string]bool)

	// The handler only records which job IDs were run; mu guards the map
	// because the handler is invoked from the service's goroutine.
	handler := func(job *CronJob) (string, error) {
		mu.Lock()
		executedJobs[job.ID] = true
		mu.Unlock()
		return "ok", nil
	}

	cs, path := setupService(handler)
	defer os.Remove(path)

	// Start the service
	if err := cs.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer cs.Stop()

	// Add a job that runs 100ms from now. Fail fast on error: job would be
	// unusable below and dereferencing job.ID would panic.
	target := time.Now().Add(100 * time.Millisecond).UnixMilli()
	job, err := cs.AddJob("FastJob", CronSchedule{Kind: "at", AtMS: &target}, "", false, "", "")
	if err != nil {
		t.Fatalf("AddJob failed: %v", err)
	}

	// Poll for job execution with a timeout (20 x 100ms).
	success := false
	for range 20 {
		mu.Lock()
		if executedJobs[job.ID] {
			success = true
			mu.Unlock()
			break
		}
		mu.Unlock()
		time.Sleep(100 * time.Millisecond)
	}

	if !success {
		t.Error("Job was not executed in time")
	}

	// The one-shot job is expected to be gone from the store once it has run.
	status := cs.Status()
	if status["jobs"].(int) != 0 {
		t.Errorf("Job should be deleted after run, got count: %v", status["jobs"])
	}
}
// TestCronService_PersistenceIntegrity verifies that a job written by one
// service instance is read back intact by a second instance on the same store
// file, and that loading a store containing invalid JSON returns an error.
func TestCronService_PersistenceIntegrity(t *testing.T) {
	// Keep the store in a per-test temp directory: it is cleaned up
	// automatically and cannot collide with other runs or pollute the
	// package directory.
	tmpFile := filepath.Join(t.TempDir(), "persist_test.json")

	// write a job and persist
	cs1 := NewCronService(tmpFile, nil)
	at := int64(2000000000000)
	if _, err := cs1.AddJob("PersistMe", CronSchedule{Kind: "at", AtMS: &at}, "payload", true, "ch1", ""); err != nil {
		t.Fatalf("AddJob failed: %v", err)
	}

	// check file exists
	if _, err := os.Stat(tmpFile); os.IsNotExist(err) {
		t.Fatal("Store file was not created")
	}

	// reload and check data integrity
	cs2 := NewCronService(tmpFile, nil)
	if err := cs2.Load(); err != nil {
		t.Fatalf("Failed to load store: %v", err)
	}
	jobs := cs2.ListJobs(true)
	if len(jobs) != 1 || jobs[0].Name != "PersistMe" {
		t.Errorf("Data corruption after reload. Got: %+v", jobs)
	}

	// test loading invalid JSON; the write itself must succeed, otherwise
	// the assertion below would be checking the previous (valid) content.
	if err := os.WriteFile(tmpFile, []byte("{invalid json}"), 0o644); err != nil {
		t.Fatalf("WriteFile failed: %v", err)
	}
	cs3 := NewCronService(tmpFile, nil)
	if err := cs3.loadStore(); err == nil {
		t.Error("Should return error when loading invalid JSON")
	}
}
// TestCronService_ConcurrentAccess hammers a running service with concurrent
// AddJob calls on one set of goroutines and ListJobs/EnableJob calls on
// another. It makes no assertions of its own beyond completing; it is meant
// to be run with the race detector.
func TestCronService_ConcurrentAccess(t *testing.T) {
	cs, path := setupService(nil)
	defer os.Remove(path)

	// A service that failed to start would make the rest of the test
	// meaningless, so surface the error instead of ignoring it.
	if err := cs.Start(); err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer cs.Stop()

	var wg sync.WaitGroup
	workers := 10
	iterations := 50

	// One writer and one reader/updater goroutine per worker.
	wg.Add(workers * 2)

	// add jobs concurrently
	for i := range workers {
		go func(id int) {
			defer wg.Done()
			for j := range iterations {
				at := time.Now().Add(time.Hour).UnixMilli()
				cs.AddJob(fmt.Sprintf("Job-%d-%d", id, j), CronSchedule{Kind: "at", AtMS: &at}, "", false, "", "")
				time.Sleep(100 * time.Microsecond)
			}
		}(i)
	}

	// read and update jobs concurrently
	for range workers {
		go func() {
			defer wg.Done()
			for j := range iterations {
				jobs := cs.ListJobs(true)
				if len(jobs) > 0 {
					// Alternate between enabling and disabling the first job.
					cs.EnableJob(jobs[0].ID, j%2 == 0)
				}
				time.Sleep(100 * time.Microsecond)
			}
		}()
	}

	wg.Wait()
}
+17 -13
View File
@@ -51,7 +51,7 @@ func init() {
FormatFieldValue: formatFieldValue,
}
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
logger = zerolog.New(consoleWriter).With().Timestamp().Caller().Logger()
fileLogger = zerolog.Logger{}
})
}
@@ -94,6 +94,12 @@ func SetLevel(level LogLevel) {
zerolog.SetGlobalLevel(level)
}
// SetConsoleLevel replaces the package-level console logger with a copy of
// itself whose minimum level is set to level. The swap is done while holding
// mu for writing.
func SetConsoleLevel(level LogLevel) {
	mu.Lock()
	defer mu.Unlock()
	logger = logger.Level(level)
}
func GetLevel() LogLevel {
mu.RLock()
defer mu.RUnlock()
@@ -134,9 +140,9 @@ func DisableFileLogging() {
fileLogger = zerolog.Logger{}
}
func getCallerInfo() (string, int, string) {
func getCallerSkip() int {
for i := 2; i < 15; i++ {
pc, file, line, ok := runtime.Caller(i)
pc, file, _, ok := runtime.Caller(i)
if !ok {
continue
}
@@ -158,10 +164,10 @@ func getCallerInfo() (string, int, string) {
continue
}
return filepath.Base(file), line, filepath.Base(funcName)
return i - 1
}
return "???", 0, "???"
return 3
}
//nolint:zerologlint
@@ -187,19 +193,16 @@ func logMessage(level LogLevel, component string, message string, fields map[str
return
}
callerFile, callerLine, callerFunc := getCallerInfo()
skip := getCallerSkip()
event := getEvent(logger, level)
// Build combined field with component and caller
if component != "" {
event.Str("caller", fmt.Sprintf("%-6s %s:%d (%s)", component, callerFile, callerLine, callerFunc))
} else {
event.Str("caller", fmt.Sprintf("<none> %s:%d (%s)", callerFile, callerLine, callerFunc))
event.Str("component", component)
}
appendFields(event, fields)
event.Msg(message)
event.CallerSkipFrame(skip).Msg(message)
// Also log to file if enabled
if fileLogger.GetLevel() != zerolog.NoLevel {
@@ -208,9 +211,10 @@ func logMessage(level LogLevel, component string, message string, fields map[str
if component != "" {
fileEvent.Str("component", component)
}
// fileEvent.Str("caller", fmt.Sprintf("%s:%d (%s)", callerFile, callerLine, callerFunc))
appendFields(event, fields)
fileEvent.Msg(message)
appendFields(fileEvent, fields)
fileEvent.CallerSkipFrame(skip).Msg(message)
}
if level == FATAL {
+3 -1
View File
@@ -5,13 +5,15 @@ import (
"io"
"os"
"path/filepath"
"github.com/sipeed/picoclaw/pkg/config"
)
func ResolveTargetHome(override string) (string, error) {
if override != "" {
return ExpandHome(override), nil
}
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
if envHome := os.Getenv(config.EnvHome); envHome != "" {
return ExpandHome(envHome), nil
}
home, err := os.UserHomeDir()
+15 -11
View File
@@ -132,11 +132,12 @@ type OpenClawChannels struct {
}
type OpenClawTelegramConfig struct {
BotToken *string `json:"botToken"`
AllowFrom []string `json:"allowFrom"`
GroupPolicy *string `json:"groupPolicy"`
DmPolicy *string `json:"dmPolicy"`
Enabled *bool `json:"enabled"`
BotToken *string `json:"botToken"`
AllowFrom []string `json:"allowFrom"`
GroupPolicy *string `json:"groupPolicy"`
DmPolicy *string `json:"dmPolicy"`
Enabled *bool `json:"enabled"`
UseMarkdownV2 *bool `json:"useMarkdownV2"`
}
type OpenClawDiscordConfig struct {
@@ -645,10 +646,11 @@ type WhatsAppConfig struct {
}
type TelegramConfig struct {
Enabled bool `json:"enabled"`
Token string `json:"token"`
Proxy string `json:"proxy"`
AllowFrom []string `json:"allow_from"`
Enabled bool `json:"enabled"`
Token string `json:"token"`
Proxy string `json:"proxy"`
AllowFrom []string `json:"allow_from"`
UseMarkdownV2 bool `json:"use_markdown_v2"`
}
type FeishuConfig struct {
@@ -777,9 +779,11 @@ func (c *OpenClawConfig) convertChannels(warnings *[]string) ChannelsConfig {
if c.Channels.Telegram != nil {
enabled := c.Channels.Telegram.Enabled == nil || *c.Channels.Telegram.Enabled
useMarkdownV2 := c.Channels.Telegram.UseMarkdownV2 != nil && *c.Channels.Telegram.UseMarkdownV2
channels.Telegram = TelegramConfig{
Enabled: enabled,
AllowFrom: c.Channels.Telegram.AllowFrom,
Enabled: enabled,
AllowFrom: c.Channels.Telegram.AllowFrom,
UseMarkdownV2: useMarkdownV2,
}
if c.Channels.Telegram.BotToken != nil {
channels.Telegram.Token = *c.Channels.Telegram.BotToken
@@ -10,6 +10,11 @@ import (
"github.com/sipeed/picoclaw/pkg/migrate/internal"
)
// OpenclawHomeEnvVar is the environment variable that overrides the source
// openclaw home directory when migrating from openclaw to picoclaw.
// Default: ~/.openclaw
const OpenclawHomeEnvVar = "OPENCLAW_HOME"
var providerMapping = map[string]string{
"anthropic": "anthropic",
"claude": "anthropic",
@@ -112,7 +117,7 @@ func resolveSourceHome(override string) (string, error) {
if override != "" {
return internal.ExpandHome(override), nil
}
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
if envHome := os.Getenv(OpenclawHomeEnvVar); envHome != "" {
return internal.ExpandHome(envHome), nil
}
home, err := os.UserHomeDir()
+4
View File
@@ -180,6 +180,10 @@ func buildParams(
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
// Skip tool calls with empty names to avoid API errors
if tc.Name == "" {
continue
}
args := tc.Arguments
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
+10 -2
View File
@@ -50,10 +50,18 @@ func (p *ClaudeCliProvider) Chat(
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if stderrStr := stderr.String(); stderrStr != "" {
stderrStr := strings.TrimSpace(stderr.String())
stdoutStr := strings.TrimSpace(stdout.String())
switch {
case stderrStr != "" && stdoutStr != "":
return nil, fmt.Errorf("claude cli error: %w\nstderr: %s\nstdout: %s", err, stderrStr, stdoutStr)
case stderrStr != "":
return nil, fmt.Errorf("claude cli error: %s", stderrStr)
case stdoutStr != "":
return nil, fmt.Errorf("claude cli error: %w\noutput: %s", err, stdoutStr)
default:
return nil, fmt.Errorf("claude cli error: %w", err)
}
return nil, fmt.Errorf("claude cli error: %w", err)
}
return p.parseClaudeCliResponse(stdout.String())
+6 -1
View File
@@ -8,6 +8,11 @@ import (
"time"
)
// CodexHomeEnvVar is the environment variable that overrides the Codex CLI
// home directory when resolving the codex auth.json credentials file.
// Default: ~/.codex
const CodexHomeEnvVar = "CODEX_HOME"
// CodexCliAuth represents the ~/.codex/auth.json file structure.
type CodexCliAuth struct {
Tokens struct {
@@ -69,7 +74,7 @@ func CreateCodexCliTokenSource() func() (string, string, error) {
}
func resolveCodexAuthPath() (string, error) {
codexHome := os.Getenv("CODEX_HOME")
codexHome := os.Getenv(CodexHomeEnvVar)
if codexHome == "" {
home, err := os.UserHomeDir()
if err != nil {
+8 -1
View File
@@ -95,7 +95,10 @@ func (p *CodexProvider) Chat(
)
}
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
// Respect tools.web.prefer_native: only inject native search when the agent
// loop requested it (options["native_search"]), so prefer_native: false
// keeps the built-in search out of the request even if web search is enabled.
useNativeSearch := p.enableWebSearch && (options["native_search"] == true)
params := buildCodexParams(messages, tools, resolvedModel, options, useNativeSearch)
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
defer stream.Close()
@@ -157,6 +160,10 @@ func (p *CodexProvider) GetDefaultModel() string {
return codexDefaultModel
}
// SupportsNativeSearch reports whether this provider was configured with web
// search enabled (its enableWebSearch field).
func (p *CodexProvider) SupportsNativeSearch() bool {
	return p.enableWebSearch
}
func resolveCodexModel(model string) (string, string) {
m := strings.ToLower(strings.TrimSpace(model))
if m == "" {
+3 -1
View File
@@ -355,7 +355,9 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"max_tokens": 1024})
// Pass native_search so Codex injects built-in web search (mirrors agent loop when prefer_native is true).
opts := map[string]any{"max_tokens": 1024, "native_search": true}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", opts)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
+5 -3
View File
@@ -55,8 +55,8 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocols: openai, litellm, anthropic, anthropic-messages, antigravity,
// claude-cli, codex-cli, github-copilot
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
// antigravity, claude-cli, codex-cli, github-copilot
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil {
@@ -116,7 +116,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
"minimax", "longcat", "modelscope":
"minimax", "longcat", "modelscope", "novita":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -219,6 +219,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://openrouter.ai/api/v1"
case "litellm":
return "http://localhost:4000/v1"
case "novita":
return "https://api.novita.ai/openai"
case "groq":
return "https://api.groq.com/openai/v1"
case "zhipu":
+29
View File
@@ -112,6 +112,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
}{
{"openai", "openai"},
{"groq", "groq"},
{"novita", "novita"},
{"openrouter", "openrouter"},
{"cerebras", "cerebras"},
{"vivgrid", "vivgrid"},
@@ -222,6 +223,34 @@ func TestGetDefaultAPIBase_ModelScope(t *testing.T) {
}
}
func TestCreateProviderFromConfig_Novita(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-novita",
Model: "novita/deepseek/deepseek-v3.2",
APIKey: "test-key",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "deepseek/deepseek-v3.2" {
t.Errorf("modelID = %q, want %q", modelID, "deepseek/deepseek-v3.2")
}
if _, ok := provider.(*HTTPProvider); !ok {
t.Fatalf("expected *HTTPProvider, got %T", provider)
}
}
// TestGetDefaultAPIBase_Novita pins the default API base URL returned for the
// "novita" protocol.
func TestGetDefaultAPIBase_Novita(t *testing.T) {
	const (
		protocol = "novita"
		want     = "https://api.novita.ai/openai"
	)

	got := getDefaultAPIBase(protocol)
	if got != want {
		t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", protocol, got, want)
	}
}
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-anthropic",
+4
View File
@@ -55,3 +55,7 @@ func (p *HTTPProvider) Chat(
// GetDefaultModel always returns the empty string; this provider does not
// name a default model.
func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
// SupportsNativeSearch forwards to the wrapped delegate and returns whatever
// the delegate's SupportsNativeSearch reports.
func (p *HTTPProvider) SupportsNativeSearch() bool {
return p.delegate.SupportsNativeSearch()
}
+33 -3
View File
@@ -103,8 +103,11 @@ func (p *Provider) Chat(
"messages": common.SerializeMessages(messages),
}
if len(tools) > 0 {
requestBody["tools"] = tools
// When fallback uses a different provider (e.g. DeepSeek), that provider must not inject web_search_preview.
nativeSearch, _ := options["native_search"].(bool)
nativeSearch = nativeSearch && isNativeSearchHost(p.apiBase)
if len(tools) > 0 || nativeSearch {
requestBody["tools"] = buildToolsList(tools, nativeSearch)
requestBody["tool_choice"] = "auto"
}
@@ -188,13 +191,40 @@ func normalizeModel(model, apiBase string) string {
prefix := strings.ToLower(before)
switch prefix {
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google",
"openrouter", "zhipu", "mistral", "vivgrid", "minimax":
"openrouter", "zhipu", "mistral", "vivgrid", "minimax", "novita":
return after
default:
return model
}
}
// buildToolsList assembles the value stored under the request's "tools" key.
//
// Without native search every definition is passed through in its original
// order. With native search, any tool whose function name equals "web_search"
// (case-insensitively) is dropped and a {"type": "web_search_preview"} entry
// is appended as the last element.
func buildToolsList(tools []ToolDefinition, nativeSearch bool) []any {
	// One extra slot leaves room for the appended native search entry.
	out := make([]any, 0, len(tools)+1)
	for i := range tools {
		isClientSearch := strings.EqualFold(tools[i].Function.Name, "web_search")
		if !nativeSearch || !isClientSearch {
			out = append(out, tools[i])
		}
	}
	if !nativeSearch {
		return out
	}
	return append(out, map[string]any{"type": "web_search_preview"})
}
// SupportsNativeSearch reports whether this provider's apiBase passes
// isNativeSearchHost, i.e. whether its host is one of the hosts that function
// recognizes.
func (p *Provider) SupportsNativeSearch() bool {
return isNativeSearchHost(p.apiBase)
}
// isNativeSearchHost reports whether apiBase points at a host treated as
// supporting native web search: exactly api.openai.com, or any host under
// .openai.azure.com.
//
// It returns false when apiBase cannot be parsed as a URL or carries no host
// (for example an empty string or a value without a scheme).
func isNativeSearchHost(apiBase string) bool {
	u, err := url.Parse(apiBase)
	if err != nil {
		return false
	}
	// DNS names are case-insensitive, but url.URL.Hostname returns the host
	// exactly as written, so normalize before comparing.
	host := strings.ToLower(u.Hostname())
	return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
}
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API and Azure OpenAI support this. All other OpenAI-compatible providers
+269 -22
View File
@@ -432,7 +432,28 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
}
}
func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
func TestProviderChat_StripsGroqOllamaDeepseekVivgridNovitaPrefixes(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
tests := []struct {
name string
input string
@@ -463,31 +484,25 @@ func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
input: "vivgrid/auto",
wantModel: "auto",
},
{
name: "strips novita prefix deepseek model",
input: "novita/deepseek/deepseek-v3.2",
wantModel: "deepseek/deepseek-v3.2",
},
{
name: "strips novita prefix zai model",
input: "novita/zai-org/glm-5",
wantModel: "zai-org/glm-5",
},
{
name: "strips novita prefix minimax model",
input: "novita/minimax/minimax-m2.5",
wantModel: "minimax/minimax-m2.5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -573,6 +588,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" {
t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto")
}
if got := normalizeModel(
"novita/deepseek/deepseek-v3.2",
"https://api.novita.ai/openai",
); got != "deepseek/deepseek-v3.2" {
t.Fatalf("normalizeModel(novita) = %q, want %q", got, "deepseek/deepseek-v3.2")
}
}
func TestProvider_RequestTimeoutDefault(t *testing.T) {
@@ -824,6 +845,232 @@ func TestSupportsPromptCacheKey(t *testing.T) {
}
}
func TestBuildToolsList_NativeSearchAddsWebSearchPreview(t *testing.T) {
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
}
result := buildToolsList(tools, true)
if len(result) != 2 {
t.Fatalf("len(result) = %d, want 2", len(result))
}
wsEntry, ok := result[1].(map[string]any)
if !ok {
t.Fatalf("web search entry is %T, want map[string]any", result[1])
}
if wsEntry["type"] != "web_search_preview" {
t.Fatalf("type = %v, want web_search_preview", wsEntry["type"])
}
}
// TestBuildToolsList_NativeSearchFiltersClientWebSearch verifies that the
// client-side web_search tool is removed when native search is enabled, leaving
// the remaining tool plus the appended web_search_preview entry.
func TestBuildToolsList_NativeSearchFiltersClientWebSearch(t *testing.T) {
	input := []ToolDefinition{
		{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
		{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
	}

	got := buildToolsList(input, true)
	for i := range got {
		def, isDef := got[i].(ToolDefinition)
		if isDef && strings.EqualFold(def.Function.Name, "web_search") {
			t.Fatal("client-side web_search should be filtered out when native search is enabled")
		}
	}

	// Expected survivors: read_file + web_search_preview.
	if n := len(got); n != 2 {
		t.Fatalf("len(result) = %d, want 2 (read_file + web_search_preview)", n)
	}
}
// TestBuildToolsList_NoNativeSearchPassesThrough verifies that with native
// search disabled every tool, including web_search, is kept.
func TestBuildToolsList_NoNativeSearchPassesThrough(t *testing.T) {
	input := []ToolDefinition{
		{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
		{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
	}

	if n := len(buildToolsList(input, false)); n != 2 {
		t.Fatalf("len(result) = %d, want 2", n)
	}
}
// TestIsNativeSearchHost runs isNativeSearchHost over a table of API base URLs:
// OpenAI and Azure OpenAI hosts are expected to match, everything else
// (including the empty string) is not.
func TestIsNativeSearchHost(t *testing.T) {
	type hostCase struct {
		apiBase string
		want    bool
	}
	cases := []hostCase{
		{apiBase: "https://api.openai.com/v1", want: true},
		{apiBase: "https://myresource.openai.azure.com/openai/deployments/gpt-4", want: true},
		{apiBase: "https://api.mistral.ai/v1", want: false},
		{apiBase: "https://api.deepseek.com/v1", want: false},
		{apiBase: "https://api.groq.com/openai/v1", want: false},
		{apiBase: "http://localhost:11434/v1", want: false},
		{apiBase: "", want: false},
	}

	for i := range cases {
		c := cases[i]
		got := isNativeSearchHost(c.apiBase)
		if got == c.want {
			continue
		}
		t.Errorf("isNativeSearchHost(%q) = %v, want %v", c.apiBase, got, c.want)
	}
}
// TestSupportsNativeSearch_OpenAI verifies that a provider pointed at
// api.openai.com reports native search support.
func TestSupportsNativeSearch_OpenAI(t *testing.T) {
	provider := NewProvider("key", "https://api.openai.com/v1", "")
	if supported := provider.SupportsNativeSearch(); !supported {
		t.Fatal("OpenAI provider should support native search")
	}
}
// TestSupportsNativeSearch_NonOpenAI verifies that a provider pointed at a
// non-OpenAI host (DeepSeek) does not report native search support.
func TestSupportsNativeSearch_NonOpenAI(t *testing.T) {
	provider := NewProvider("key", "https://api.deepseek.com/v1", "")
	if supported := provider.SupportsNativeSearch(); supported {
		t.Fatal("DeepSeek provider should not support native search")
	}
}
// TestProviderChat_NativeSearchToolInjected verifies that Chat appends a
// web_search_preview tool to the request when native_search is requested and
// the provider's apiBase is an OpenAI host.
//
// The provider keeps an api.openai.com apiBase so the host check passes, while
// a custom transport reroutes the actual request to the local test server.
func TestProviderChat_NativeSearchToolInjected(t *testing.T) {
	var captured map[string]any
	handler := func(w http.ResponseWriter, r *http.Request) {
		if decodeErr := json.NewDecoder(r.Body).Decode(&captured); decodeErr != nil {
			http.Error(w, decodeErr.Error(), http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		})
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	p := NewProvider("key", server.URL, "")
	p.apiBase = "https://api.openai.com/v1"
	reroute := func(r *http.Request) (*http.Response, error) {
		// Keep the request path but send it to the test server.
		r.URL, _ = url.Parse(server.URL + r.URL.Path)
		return http.DefaultTransport.RoundTrip(r)
	}
	p.httpClient = &http.Client{Transport: roundTripperFunc(reroute)}

	tools := []ToolDefinition{
		{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
	}
	messages := []Message{{Role: "user", Content: "hi"}}
	options := map[string]any{"native_search": true}
	if _, err := p.Chat(t.Context(), messages, tools, "gpt-5.4", options); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	sentTools, isList := captured["tools"].([]any)
	if !isList {
		t.Fatalf("tools is %T, want []any", captured["tools"])
	}
	if n := len(sentTools); n != 2 {
		t.Fatalf("len(tools) = %d, want 2 (read_file + web_search_preview)", n)
	}
	last, isMap := sentTools[1].(map[string]any)
	if !isMap {
		t.Fatalf("last tool is %T, want map[string]any", sentTools[1])
	}
	if kind := last["type"]; kind != "web_search_preview" {
		t.Fatalf("last tool type = %v, want web_search_preview", kind)
	}
}
// TestProviderChat_NativeSearchNotInjectedWithoutOption verifies that, when
// the native_search option is absent, Chat sends the caller's tools unchanged
// (the client-side web_search tool stays and nothing is appended).
func TestProviderChat_NativeSearchNotInjectedWithoutOption(t *testing.T) {
	var captured map[string]any
	handler := func(w http.ResponseWriter, r *http.Request) {
		if decodeErr := json.NewDecoder(r.Body).Decode(&captured); decodeErr != nil {
			http.Error(w, decodeErr.Error(), http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		})
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	p := NewProvider("key", server.URL, "")
	tools := []ToolDefinition{
		{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
	}
	messages := []Message{{Role: "user", Content: "hi"}}
	if _, err := p.Chat(t.Context(), messages, tools, "gpt-5.4", map[string]any{}); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	sentTools, isList := captured["tools"].([]any)
	if !isList {
		t.Fatalf("tools is %T, want []any", captured["tools"])
	}
	if n := len(sentTools); n != 1 {
		t.Fatalf("len(tools) = %d, want 1 (web_search only)", n)
	}
}
// TestProviderChat_NativeSearchIgnoredOnNonOpenAI covers the fallback case:
// native_search is set in the options, but the provider's apiBase is not an
// OpenAI host (as when falling back to DeepSeek). Chat must then leave
// web_search_preview out so the other API does not reject the request.
func TestProviderChat_NativeSearchIgnoredOnNonOpenAI(t *testing.T) {
	var captured map[string]any
	handler := func(w http.ResponseWriter, r *http.Request) {
		if decodeErr := json.NewDecoder(r.Body).Decode(&captured); decodeErr != nil {
			http.Error(w, decodeErr.Error(), http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		})
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	// server.URL has a non-OpenAI host, which stands in for DeepSeek or any
	// other provider.
	p := NewProvider("key", server.URL, "")
	messages := []Message{{Role: "user", Content: "hi"}}
	options := map[string]any{"native_search": true}
	if _, err := p.Chat(t.Context(), messages, nil, "deepseek-chat", options); err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	// No tools were passed and none may be injected, so the key must be absent.
	sentTools, present := captured["tools"]
	if present {
		t.Fatalf("tools should be omitted for non-OpenAI when only native_search was requested, got %v", sentTools)
	}
}
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
messages := []protocoltypes.Message{
{
+9
View File
@@ -44,6 +44,15 @@ type ThinkingCapable interface {
SupportsThinking() bool
}
// NativeSearchCapable is an optional interface for providers that support
// built-in web search during LLM inference (e.g. OpenAI web_search_preview,
// xAI Grok search). When the active provider implements this interface and
// returns true, the agent loop can hide the client-side web_search tool to
// avoid duplicate search surfaces and use the provider's native search instead.
type NativeSearchCapable interface {
// SupportsNativeSearch reports whether the provider can currently perform
// web search natively. The answer is a runtime value, so a provider that
// implements the interface may still return false.
SupportsNativeSearch() bool
}
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string
+6 -3
View File
@@ -226,9 +226,12 @@ func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
if !ok {
t.Fatal("expected outbound message")
var msg bus.OutboundMessage
select {
case msg = <-tool.msgBus.OutboundChan():
// got message
case <-ctx.Done():
t.Fatal("timeout waiting for outbound message")
}
if !strings.Contains(msg.Content, "command execution is disabled") {
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
+174 -25
View File
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
@@ -15,6 +16,7 @@ import (
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -776,22 +778,49 @@ type WebFetchTool struct {
maxChars int
proxy string
client *http.Client
format string
fetchLimitBytes int64
whitelist *privateHostWhitelist
}
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
// privateHostWhitelist is the set of private or restricted addresses that web
// fetch may still reach. An address matches when it is listed exactly or falls
// inside one of the listed networks.
type privateHostWhitelist struct {
// exact holds individually whitelisted IPs, keyed by the string form of the
// normalized address.
exact map[string]struct{}
// cidrs holds the networks parsed from CIDR entries.
cidrs []*net.IPNet
}
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
// createHTTPClient cannot fail with an empty proxy string.
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
return NewWebFetchToolWithConfig(maxChars, "", format, fetchLimitBytes, nil)
}
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily.
var allowPrivateWebFetchHosts atomic.Bool
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
// NewWebFetchToolWithProxy builds a WebFetchTool by forwarding every argument,
// unchanged and in the same order, to NewWebFetchToolWithConfig. It returns
// whatever that constructor returns.
func NewWebFetchToolWithProxy(
maxChars int,
proxy string,
format string,
fetchLimitBytes int64,
privateHostWhitelist []string,
) (*WebFetchTool, error) {
return NewWebFetchToolWithConfig(maxChars, proxy, format, fetchLimitBytes, privateHostWhitelist)
}
func NewWebFetchToolWithConfig(
maxChars int,
proxy string,
format string,
fetchLimitBytes int64,
privateHostWhitelist []string,
) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = defaultMaxChars
}
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
if err != nil {
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
}
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
@@ -801,13 +830,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
}
transport.DialContext = newSafeDialContext(dialer)
transport.DialContext = newSafeDialContext(dialer, whitelist)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
if isObviousPrivateHost(req.URL.Hostname()) {
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
return fmt.Errorf("redirect target is private or local network host")
}
return nil
@@ -819,7 +848,9 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
maxChars: maxChars,
proxy: proxy,
client: client,
format: format,
fetchLimitBytes: fetchLimitBytes,
whitelist: whitelist,
}, nil
}
@@ -871,7 +902,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
// The real SSRF guard is newSafeDialContext at connect time.
hostname := parsedURL.Hostname()
if isObviousPrivateHost(hostname) {
if isObviousPrivateHost(hostname, t.whitelist) {
return ErrorResult("fetching private or local network hosts is not allowed")
}
@@ -906,26 +937,68 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
}
bodyStr := string(body)
contentType := resp.Header.Get("Content-Type")
mediaType, params, err := mime.ParseMediaType(contentType)
if err != nil {
// The most common error here is "mime: no media type" if the header is empty.
logger.WarnCF("tool", "Failed to parse Content-Type", map[string]any{
"raw_header": contentType,
"error": err.Error(),
})
// security fallback
mediaType = "application/octet-stream"
}
charset, hasCharset := params["charset"]
if hasCharset {
// If the charset is not utf-8, we might have to convert the bodyStr
// before passing it to the HTML/Markdown parser
if strings.ToLower(charset) != "utf-8" {
logger.WarnCF("tool", "Note: the content is not in UTF-8", map[string]any{"charset": charset})
}
}
var text, extractor string
if strings.Contains(contentType, "application/json") {
switch {
case mediaType == "application/json":
var jsonData any
if err := json.Unmarshal(body, &jsonData); err == nil {
formatted, _ := json.MarshalIndent(jsonData, "", " ")
text = string(formatted)
extractor = "json"
} else {
text = string(body)
if err := json.Unmarshal(body, &jsonData); err != nil {
text = bodyStr
extractor = "raw"
break
}
} else if strings.Contains(contentType, "text/html") || len(body) > 0 &&
(strings.HasPrefix(string(body), "<!DOCTYPE") || strings.HasPrefix(strings.ToLower(string(body)), "<html")) {
text = t.extractText(string(body))
extractor = "text"
} else {
text = string(body)
formatted, err := json.MarshalIndent(jsonData, "", " ")
if err != nil {
text = bodyStr
extractor = "raw"
break
}
text = string(formatted)
extractor = "json"
case mediaType == "text/html" || looksLikeHTML(bodyStr):
switch strings.ToLower(t.format) {
case "markdown":
var err error
text, err = utils.HtmlToMarkdown(bodyStr)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to HTML to markdown: %v", err))
}
extractor = "markdown"
default:
text = t.extractText(bodyStr)
extractor = "text"
}
default:
text = bodyStr
extractor = "raw"
}
@@ -957,6 +1030,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
}
}
// looksLikeHTML reports whether body starts with an HTML doctype declaration
// ("<!doctype") or an opening "<html" tag. The comparison is case-insensitive,
// so "<!DOCTYPE html>" and "<HTML>" both match.
//
// Only the very start of body is examined: leading whitespace is not skipped,
// and an empty body is never HTML.
func looksLikeHTML(body string) bool {
	// Compare just the leading bytes instead of lower-casing the whole body,
	// which can be as large as the fetch limit.
	for _, prefix := range []string{"<!doctype", "<html"} {
		if len(body) >= len(prefix) && strings.EqualFold(body[:len(prefix)], prefix) {
			return true
		}
	}
	return false
}
func (t *WebFetchTool) extractText(htmlContent string) string {
result := reScript.ReplaceAllLiteralString(htmlContent, "")
result = reStyle.ReplaceAllLiteralString(result, "")
@@ -981,7 +1065,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
func newSafeDialContext(
dialer *net.Dialer,
whitelist *privateHostWhitelist,
) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
if allowPrivateWebFetchHosts.Load() {
return dialer.DialContext(ctx, network, address)
@@ -996,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if ip := net.ParseIP(host); ip != nil {
if isPrivateOrRestrictedIP(ip) {
if shouldBlockPrivateIP(ip, whitelist) {
return nil, fmt.Errorf("blocked private or local target: %s", host)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
@@ -1010,7 +1097,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
attempted := 0
var lastErr error
for _, ipAddr := range ipAddrs {
if isPrivateOrRestrictedIP(ipAddr.IP) {
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
continue
}
attempted++
@@ -1022,7 +1109,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
if attempted == 0 {
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
}
if lastErr != nil {
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
@@ -1031,10 +1118,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
}
}
// newPrivateHostWhitelist parses entries into a whitelist. Each entry is
// trimmed; blank entries are skipped. An entry that parses as an IP is stored
// as an exact match under its normalized string form; anything else must parse
// as a CIDR and is stored as a network.
//
// It returns (nil, nil) when entries is empty or contains only blanks, and an
// error naming the offending entry when one is neither an IP nor a CIDR.
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
	if len(entries) == 0 {
		return nil, nil
	}

	exactIPs := make(map[string]struct{})
	networks := make([]*net.IPNet, 0, len(entries))
	for _, raw := range entries {
		item := strings.TrimSpace(raw)
		if item == "" {
			continue
		}

		if addr := net.ParseIP(item); addr != nil {
			exactIPs[normalizeWhitelistIP(addr).String()] = struct{}{}
			continue
		}

		_, ipNet, err := net.ParseCIDR(item)
		if err != nil {
			return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", item)
		}
		networks = append(networks, ipNet)
	}

	// Nothing usable was listed: behave as if no whitelist was configured.
	if len(exactIPs)+len(networks) == 0 {
		return nil, nil
	}
	return &privateHostWhitelist{exact: exactIPs, cidrs: networks}, nil
}
// Contains reports whether ip is whitelisted, either as an exact entry or as a
// member of one of the whitelisted networks. A nil receiver or a nil ip is
// never contained, so callers may invoke it on an unset whitelist.
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
	if w == nil || ip == nil {
		return false
	}

	candidate := normalizeWhitelistIP(ip)
	_, listed := w.exact[candidate.String()]
	// Fall back to the CIDR list only when there was no exact hit; stop at the
	// first network that contains the address.
	for i := 0; !listed && i < len(w.cidrs); i++ {
		listed = w.cidrs[i].Contains(candidate)
	}
	return listed
}
// normalizeWhitelistIP returns the canonical form of ip used for whitelist
// lookups: the 4-byte form when ip is an IPv4 address (including IPv4-mapped
// IPv6), otherwise ip itself. A nil ip yields nil.
func normalizeWhitelistIP(ip net.IP) net.IP {
	switch v4 := ip.To4(); {
	case ip == nil:
		return nil
	case v4 != nil:
		return v4
	default:
		return ip
	}
}
// shouldBlockPrivateIP reports whether ip must be refused: it is private or
// restricted and the whitelist does not contain it. The whitelist is consulted
// only for private or restricted addresses.
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
	if !isPrivateOrRestrictedIP(ip) {
		return false
	}
	return !whitelist.Contains(ip)
}
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
func isObviousPrivateHost(host string) bool {
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
if allowPrivateWebFetchHosts.Load() {
return false
}
@@ -1050,7 +1199,7 @@ func isObviousPrivateHost(host string) bool {
}
if ip := net.ParseIP(h); ip != nil {
return isPrivateOrRestrictedIP(ip)
return shouldBlockPrivateIP(ip, whitelist)
}
return false
+170 -20
View File
@@ -10,11 +10,15 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
const testFetchLimit = int64(10 * 1024 * 1024)
const (
// testFetchLimit is the fetch byte limit (10 MiB) passed to the web fetch
// tool constructors in these tests.
testFetchLimit = int64(10 * 1024 * 1024)
// format is the output format passed to the web fetch tool constructors in
// these tests.
format = "plaintext"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
@@ -27,7 +31,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -69,7 +73,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -94,7 +98,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -119,7 +123,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -144,7 +148,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -178,7 +182,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
tool, err := NewWebFetchTool(1000, format, testFetchLimit) // Limit to 1000 chars
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -228,7 +232,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
defer ts.Close()
// Initialize the tool
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -311,7 +315,7 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -423,8 +427,31 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
})
}
// serverHostAndPort strips a leading "http://" and then a leading "https://"
// from rawURL and splits what remains into host and port. The test fails
// immediately if the remainder is not a valid host:port pair.
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
	t.Helper()

	addr := rawURL
	for _, scheme := range []string{"http://", "https://"} {
		addr = strings.TrimPrefix(addr, scheme)
	}

	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
	}
	return host, port
}
// singleHostCIDR returns the CIDR that covers exactly the given IP: a /32 for
// IPv4 addresses and a /128 otherwise. The test fails immediately if host is
// not a valid IP literal.
func singleHostCIDR(t *testing.T, host string) string {
	t.Helper()

	addr := net.ParseIP(host)
	if addr == nil {
		t.Fatalf("failed to parse IP %q", host)
	}

	prefixLen := "/128"
	if addr.To4() != nil {
		prefixLen = "/32"
	}
	return addr.String() + prefixLen
}
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -441,6 +468,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
}
}
// TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist verifies that a fetch
// from the local test server succeeds when the server's own IP is listed as an
// exact whitelist entry.
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("exact whitelist ok"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	host, _ := serverHostAndPort(t, server.URL)
	allowed := []string{host}
	tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, allowed)
	if err != nil {
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	args := map[string]any{"url": server.URL}
	result := tool.Execute(context.Background(), args)
	if result.IsError {
		t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
	}
	if body := result.ForLLM; !strings.Contains(body, "exact whitelist ok") {
		t.Fatalf("expected fetched content, got %q", body)
	}
}
// TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist verifies that a fetch
// from the local test server succeeds when the whitelist holds a single-host
// CIDR covering the server's IP.
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("cidr whitelist ok"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	host, _ := serverHostAndPort(t, server.URL)
	allowed := []string{singleHostCIDR(t, host)}
	tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, allowed)
	if err != nil {
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	args := map[string]any{"url": server.URL}
	result := tool.Execute(context.Background(), args)
	if result.IsError {
		t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
	}
	if body := result.ForLLM; !strings.Contains(body, "cidr whitelist ok") {
		t.Fatalf("expected fetched content, got %q", body)
	}
}
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
withPrivateWebFetchHostsAllowed(t)
@@ -451,7 +528,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
}))
defer server.Close()
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -466,7 +543,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked
func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -481,7 +558,7 @@ func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked
func TestWebFetch_BlocksMetadataIP(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -496,7 +573,7 @@ func TestWebFetch_BlocksMetadataIP(t *testing.T) {
// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked
func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -511,7 +588,7 @@ func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked
func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -527,7 +604,7 @@ func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked
func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -557,7 +634,7 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
allowPrivateWebFetchHosts.Store(false)
defer allowPrivateWebFetchHosts.Store(true)
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
t.Fatalf("Failed to create web fetch tool: %v", err)
}
@@ -570,6 +647,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
}
}
// TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist dials
// "localhost" on the port of a loopback listener through a safe dial context
// built with a nil whitelist and expects the dial to be rejected with an error
// mentioning "private" or "whitelisted".
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to listen on loopback: %v", err)
	}
	defer ln.Close()

	// Only the port is needed; the host is replaced by the name "localhost".
	_, port, err := net.SplitHostPort(ln.Addr().String())
	if err != nil {
		t.Fatalf("failed to split listener address: %v", err)
	}
	target := net.JoinHostPort("localhost", port)

	dial := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
	if _, err = dial(context.Background(), "tcp", target); err == nil {
		t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
	}

	msg := err.Error()
	if !(strings.Contains(msg, "private") || strings.Contains(msg, "whitelisted")) {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution dials
// "localhost" on the port of a loopback listener through a safe dial context
// whose whitelist contains 127.0.0.0/8, and expects both the dial to succeed
// and the listener to observe the incoming connection.
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to listen on loopback: %v", err)
	}
	defer ln.Close()

	// Capacity of one lets the accept goroutine signal without waiting for
	// the receiver below.
	accepted := make(chan struct{}, 1)
	go func() {
		peer, acceptErr := ln.Accept()
		if acceptErr != nil {
			return
		}
		peer.Close()
		accepted <- struct{}{}
	}()

	_, port, err := net.SplitHostPort(ln.Addr().String())
	if err != nil {
		t.Fatalf("failed to split listener address: %v", err)
	}

	allowed, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
	if err != nil {
		t.Fatalf("failed to parse whitelist: %v", err)
	}

	dial := newSafeDialContext(&net.Dialer{Timeout: time.Second}, allowed)
	target := net.JoinHostPort("localhost", port)
	conn, err := dial(context.Background(), "tcp", target)
	if err != nil {
		t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
	}
	conn.Close()

	timeout := time.After(time.Second)
	select {
	case <-accepted:
		// The listener saw the connection.
	case <-timeout:
		t.Fatal("expected localhost listener to accept a connection")
	}
}
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
tests := []struct {
@@ -615,7 +755,7 @@ func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool, err := NewWebFetchTool(50000, testFetchLimit)
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -639,7 +779,7 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
}
func TestNewWebFetchToolWithProxy(t *testing.T) {
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else if tool.maxChars != 1024 {
@@ -650,7 +790,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
}
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
}
@@ -660,6 +800,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
}
}
// TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist checks that a
// whitelist entry that is neither an IP nor a CIDR makes construction fail
// with an error containing "invalid entry".
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
	badWhitelist := []string{"not-an-ip-or-cidr"}

	if _, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, badWhitelist); err == nil {
		t.Fatal("expected invalid whitelist entry to fail")
	} else if !strings.Contains(err.Error(), "invalid entry") {
		t.Fatalf("unexpected error: %v", err)
	}
}
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
+411
View File
@@ -0,0 +1,411 @@
package utils
import (
"bytes"
"net/url"
"regexp"
"strconv"
"strings"
"golang.org/x/net/html"
)
// Precompiled patterns used by the HTML-to-Markdown converter. They are
// compiled once at package initialization.
var (
	// reSpaces matches a run of spaces and/or tabs. walk replaces each run
	// with a single space in text that is not inside <pre>.
	reSpaces = regexp.MustCompile(`[ \t]+`)
	// reNewlines matches three or more consecutive newlines. HtmlToMarkdown
	// collapses each match to exactly two.
	reNewlines = regexp.MustCompile(`\n{3,}`)
	// reEmptyListItem matches a line holding only a "-" or "*" bullet
	// followed by optional whitespace.
	reEmptyListItem = regexp.MustCompile(`(?m)^[-*]\s*$`)
	// reImageOnlyLink matches a link whose only content is an image with an
	// empty alt text, where both URLs are wrapped in angle brackets.
	// NOTE(review): walk writes URLs without angle brackets, so this pattern
	// may never match its output — confirm whether it is still needed.
	reImageOnlyLink = regexp.MustCompile(`\[!\[\]\(<[^>]*>\)\]\(<[^>]*>\)`)
	// reEmptyHeader matches a line holding one to six "#" characters followed
	// only by optional whitespace.
	reEmptyHeader = regexp.MustCompile(`(?m)^#{1,6}\s*$`)
	// reLeadingLineSpace matches exactly one space or tab at the start of a
	// line when the next character is not a space, tab or newline. The two
	// characters are captured separately so the second can be kept.
	reLeadingLineSpace = regexp.MustCompile(`(?m)^([ \t])([^ \t\n])`)
)
// skipTags lists the element names whose whole subtree walk drops without
// emitting any output: non-rendered elements (script, style, head, noscript,
// template) and page chrome (nav, footer, aside, header, form, dialog).
var skipTags = map[string]bool{
	"script": true, "style": true, "head": true,
	"noscript": true, "template": true,
	"nav": true, "footer": true, "aside": true, "header": true, "form": true, "dialog": true,
}
// isSafeHref reports whether href may be emitted as a Markdown link target.
//
// Surrounding whitespace is ignored. The value is rejected when it starts
// (case-insensitively) with "javascript:", "vbscript:" or "data:", or when it
// does not parse as a URL. Otherwise it is accepted only if it has no scheme
// (relative and fragment references) or its scheme is http, https or mailto.
func isSafeHref(href string) bool {
	trimmed := strings.TrimSpace(href)

	// Textual rejection of script-capable schemes before any parsing.
	folded := strings.ToLower(trimmed)
	for _, blocked := range []string{"javascript:", "vbscript:", "data:"} {
		if strings.HasPrefix(folded, blocked) {
			return false
		}
	}

	parsed, err := url.Parse(trimmed)
	if err != nil {
		return false
	}

	switch strings.ToLower(parsed.Scheme) {
	case "", "http", "https", "mailto":
		return true
	default:
		return false
	}
}
// isSafeImageSrc reports whether src may be emitted as a Markdown image
// source. Inline "data:image/..." URIs are accepted (matched
// case-insensitively after trimming whitespace); every other value is judged
// by isSafeHref.
func isSafeImageSrc(src string) bool {
	inlineImage := strings.HasPrefix(strings.ToLower(strings.TrimSpace(src)), "data:image/")
	return inlineImage || isSafeHref(src)
}
// escapeMdAlt backslash-escapes the characters that would terminate or
// corrupt a Markdown image alt text: backslash, "[" and "]". All other bytes
// are copied unchanged.
func escapeMdAlt(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	// The three escaped characters are ASCII, so scanning byte by byte never
	// splits a multi-byte UTF-8 sequence.
	for i := 0; i < len(s); i++ {
		switch ch := s[i]; ch {
		case '\\', '[', ']':
			out.WriteByte('\\')
			out.WriteByte(ch)
		default:
			out.WriteByte(ch)
		}
	}
	return out.String()
}
// getAttr returns the value of the first attribute of n whose key equals key,
// or the empty string when n has no such attribute.
func getAttr(n *html.Node, key string) string {
	for i := range n.Attr {
		if attr := &n.Attr[i]; attr.Key == key {
			return attr.Val
		}
	}
	return ""
}
// normalizeAttr removes every newline, carriage return and tab from val and
// then trims the remaining leading and trailing whitespace.
func normalizeAttr(val string) string {
	// A single pass is equivalent to three successive deletions because
	// removing one of these characters can never create another.
	stripped := strings.NewReplacer("\n", "", "\r", "", "\t", "").Replace(val)
	return strings.TrimSpace(stripped)
}
// isUnlikelyNode reports whether element n looks like page boilerplate
// (menus, banners, share widgets, ...) based on its class and id attributes.
//
// The check is a case-insensitive substring match over "class id":
//   - non-element nodes and elements with neither attribute are never unlikely;
//   - the document containers <html> and <body> are never unlikely, because
//     themes commonly put state classes such as "has-sidebar" or "modal-open"
//     on them and pruning them would discard the entire page;
//   - a value containing "article", "main" or "content" is never unlikely;
//   - otherwise the node is unlikely when the value contains any boilerplate
//     keyword.
func isUnlikelyNode(n *html.Node) bool {
	if n.Type != html.ElementNode {
		return false
	}
	// Root containers hold all of the content; never prune them.
	if n.Data == "html" || n.Data == "body" {
		return false
	}

	classID := strings.ToLower(getAttr(n, "class") + " " + getAttr(n, "id"))
	if classID == " " {
		// Both attributes are absent or empty.
		return false
	}

	// Positive hints win over any boilerplate keyword.
	for _, keep := range []string{"article", "main", "content"} {
		if strings.Contains(classID, keep) {
			return false
		}
	}

	unlikelyKeywords := []string{
		"menu",
		"nav",
		"footer",
		"sidebar",
		"cookie",
		"banner",
		"sponsor",
		"advert",
		"popup",
		"modal",
		"newsletter",
		"share",
		"social",
	}
	for _, keyword := range unlikelyKeywords {
		if strings.Contains(classID, keyword) {
			return true
		}
	}
	return false
}
// converter carries the state of one HTML-to-Markdown traversal.
type converter struct {
	// stack holds the output buffers. Index 0 is the final document; deeper
	// entries temporarily capture the content of an element so it can be
	// post-processed before being written to the buffer below.
	stack []*bytes.Buffer
	// linkHrefs holds the href of every open <a> that has one.
	linkHrefs []string
	// linkStates holds one entry per open <a>: true when it has an href.
	linkStates []bool
	emphStack  []string // Markers ("**", "*", "~~") of the open, buffered emphasis elements
	// olCounters holds the next item number of every open <ol>.
	olCounters []int
	// inPre is true while the traversal is inside a <pre> element.
	inPre bool
	// listDepth is the number of currently open <ol>/<ul> elements.
	listDepth int
}

// newConverter returns a converter whose buffer stack contains only the
// empty root buffer.
func newConverter() *converter {
	c := &converter{}
	c.stack = append(c.stack, &bytes.Buffer{})
	return c
}

// write appends s to the buffer on top of the stack.
func (c *converter) write(s string) {
	top := c.stack[len(c.stack)-1]
	top.WriteString(s)
}

// pushBuf starts capturing output into a fresh buffer on top of the stack.
func (c *converter) pushBuf() {
	c.stack = append(c.stack, new(bytes.Buffer))
}

// popBuf removes the top buffer from the stack and returns its content.
func (c *converter) popBuf() string {
	last := len(c.stack) - 1
	captured := c.stack[last].String()
	c.stack = c.stack[:last]
	return captured
}
// walk renders the subtree rooted at n as Markdown into the converter's
// current output buffer.
//
// Elements listed in skipTags or flagged by isUnlikelyNode are dropped with
// their whole subtree. Text is whitespace-collapsed unless it is inside <pre>.
// Every other element is handled in three steps: opening markup, recursive
// traversal of the children, closing markup. Emphasis and links capture their
// children in a temporary buffer so the inner text can be trimmed before the
// markers are added; blockquotes and images are rendered entirely in the
// opening step and return early.
func (c *converter) walk(n *html.Node) {
	// One nesting level of list indentation. Four spaces are used because
	// HtmlToMarkdown strips a single leading space from a line, which would
	// flatten a one-space indent.
	const listIndent = "    "

	switch n.Type {
	case html.TextNode:
		text := n.Data
		if !c.inPre {
			text = strings.ReplaceAll(text, "\n", " ")
			text = reSpaces.ReplaceAllString(text, " ")
		}
		if text != "" {
			c.write(text)
		}
		return
	case html.ElementNode:
		if skipTags[n.Data] || isUnlikelyNode(n) {
			return
		}
	default:
		// Document, doctype, comment, ...: nothing to emit for the node
		// itself, only for its children.
		for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
			c.walk(ch)
		}
		return
	}

	// Opening tags.
	switch n.Data {
	// Emphasis content is buffered so the inner text can be trimmed before it
	// is wrapped ("** text **" does not render as bold).
	case "b", "strong":
		c.emphStack = append(c.emphStack, "**")
		c.pushBuf()
	case "i", "em":
		c.emphStack = append(c.emphStack, "*")
		c.pushBuf()
	case "del", "s":
		c.emphStack = append(c.emphStack, "~~")
		c.pushBuf()
	case "a":
		href := normalizeAttr(getAttr(n, "href"))
		if href != "" && !isSafeHref(href) {
			// Keep the link text but neutralize the dangerous target.
			href = "#"
		}
		hasHref := href != ""
		c.linkStates = append(c.linkStates, hasHref)
		if hasHref {
			c.linkHrefs = append(c.linkHrefs, href)
			c.pushBuf()
		}
	case "h1":
		c.write("\n\n# ")
	case "h2":
		c.write("\n\n## ")
	case "h3":
		c.write("\n\n### ")
	case "h4":
		c.write("\n\n#### ")
	case "h5":
		c.write("\n\n##### ")
	case "h6":
		c.write("\n\n###### ")
	case "p":
		c.write("\n\n")
	case "br":
		c.write("\n")
	case "hr":
		c.write("\n\n---\n\n")
	case "ol":
		c.olCounters = append(c.olCounters, 1)
		// Only a top-level list is separated from the preceding content.
		if c.listDepth == 0 {
			c.write("\n")
		}
		c.listDepth++
	case "ul":
		if c.listDepth == 0 {
			c.write("\n")
		}
		c.listDepth++
	case "li":
		c.write("\n")
		if c.listDepth > 1 {
			c.write(strings.Repeat(listIndent, c.listDepth-1))
		}
		if n.Parent != nil && n.Parent.Data == "ol" && len(c.olCounters) > 0 {
			idx := c.olCounters[len(c.olCounters)-1]
			c.write(strconv.Itoa(idx) + ". ")
			c.olCounters[len(c.olCounters)-1]++
		} else {
			c.write("- ")
		}
	case "pre":
		c.inPre = true
		c.write("\n\n```\n")
	case "code":
		// Inside <pre> the fence already marks the code.
		if !c.inPre {
			c.write("`")
		}
	case "blockquote":
		c.pushBuf()
		for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
			c.walk(ch)
		}
		inner := strings.TrimSpace(c.popBuf())
		if inner == "" {
			// An empty quote would otherwise emit a lone ">".
			return
		}
		var quoted []string
		for _, l := range strings.Split(inner, "\n") {
			line := ">"
			if strings.TrimSpace(l) != "" {
				line = "> " + l
			}
			// Collapse runs of empty quote lines into one.
			if line == ">" && len(quoted) > 0 && quoted[len(quoted)-1] == ">" {
				continue
			}
			quoted = append(quoted, line)
		}
		c.write("\n\n" + strings.Join(quoted, "\n") + "\n\n")
		return
	case "img":
		src := normalizeAttr(getAttr(n, "src"))
		if src == "" {
			// Lazy-loaded images keep the real URL in data-src.
			src = normalizeAttr(getAttr(n, "data-src"))
		}
		if src == "" {
			return
		}
		alt := escapeMdAlt(normalizeAttr(getAttr(n, "alt")))
		if isSafeImageSrc(src) {
			c.write("![" + alt + "](" + src + ")")
		}
		return
	}

	// Children.
	for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
		c.walk(ch)
	}

	// Closing tags.
	switch n.Data {
	case "b", "strong", "i", "em", "del", "s":
		if len(c.emphStack) == 0 {
			break
		}
		marker := c.emphStack[len(c.emphStack)-1]
		c.emphStack = c.emphStack[:len(c.emphStack)-1]
		inner := strings.TrimSpace(c.popBuf())
		if inner != "" {
			c.write(marker + inner + marker)
		}
	case "a":
		if len(c.linkStates) == 0 {
			break
		}
		hasHref := c.linkStates[len(c.linkStates)-1]
		c.linkStates = c.linkStates[:len(c.linkStates)-1]
		if !hasHref {
			// Nothing was buffered: the children were written as plain text.
			break
		}
		href := c.linkHrefs[len(c.linkHrefs)-1]
		c.linkHrefs = c.linkHrefs[:len(c.linkHrefs)-1]
		inner := strings.TrimSpace(c.popBuf())
		if inner == "" {
			// A link without visible content would render as "[](href)".
			break
		}
		if !strings.Contains(inner, "\n") {
			c.write("[" + inner + "](" + href + ")")
			break
		}
		// Multi-line content cannot be wrapped as a whole; link only the
		// first line that has text and is not an image.
		lines := strings.Split(inner, "\n")
		for i, l := range lines {
			cleanLine := strings.TrimSpace(l)
			if cleanLine != "" && !strings.HasPrefix(cleanLine, "![") {
				lines[i] = "[" + cleanLine + "](" + href + ")"
				break
			}
		}
		c.write(strings.Join(lines, "\n"))
	// header, footer, aside and nav are absent here on purpose: they are in
	// skipTags, so the traversal never reaches their closing step.
	case "h1", "h2", "h3", "h4", "h5", "h6",
		"p", "div", "section", "article", "figure":
		c.write("\n")
	case "ol":
		c.listDepth--
		if len(c.olCounters) > 0 {
			c.olCounters = c.olCounters[:len(c.olCounters)-1]
		}
		if c.listDepth == 0 {
			c.write("\n")
		}
	case "ul":
		c.listDepth--
		if c.listDepth == 0 {
			c.write("\n")
		}
	case "pre":
		c.inPre = false
		c.write("\n```\n\n")
	case "code":
		if !c.inPre {
			c.write("`")
		}
	}
}
// HtmlToMarkdown converts an HTML document or fragment to Markdown.
//
// The input is parsed with html.Parse, rendered by a converter, and then
// cleaned up: empty list items, empty headings and empty links are removed,
// trailing whitespace is trimmed from every line, runs of three or more
// newlines are collapsed to one blank line, and a stray single leading space
// or tab is removed from lines outside fenced code blocks.
//
// It returns the Markdown text, or an empty string and the parser's error
// when the input cannot be parsed.
func HtmlToMarkdown(htmlStr string) (string, error) {
	doc, err := html.Parse(strings.NewReader(htmlStr))
	if err != nil {
		return "", err
	}

	c := newConverter()
	c.walk(doc)
	res := c.stack[0].String()

	// Remove constructs that carry no content.
	res = reImageOnlyLink.ReplaceAllString(res, "")
	res = reEmptyListItem.ReplaceAllString(res, "")
	res = reEmptyHeader.ReplaceAllString(res, "")

	lines := strings.Split(res, "\n")
	cleanLines := make([]string, 0, len(lines))
	for _, line := range lines {
		line = strings.TrimRight(line, " \t")
		// Lines that consist solely of an empty link or a bare bullet are
		// blanked rather than dropped so paragraph spacing is preserved.
		switch strings.TrimSpace(line) {
		case "[](</>)", "[](#)", "-":
			line = ""
		}
		cleanLines = append(cleanLines, line)
	}

	res = strings.Join(cleanLines, "\n")
	res = strings.TrimSpace(res)
	res = reNewlines.ReplaceAllString(res, "\n\n")

	// Strip a single leading space from lines that are NOT list indentation.
	// "(?m)^([ \t])([^ \t\n])" matches exactly one space/tab at line start
	// followed by a non-whitespace char, so "    - nested" (4 spaces) is left
	// untouched. Lines inside a fenced code block are skipped: there a single
	// leading tab or space is real indentation and must survive. Fences are
	// recognized as lines equal to "```", which is how walk emits <pre>.
	lines = strings.Split(res, "\n")
	inFence := false
	for i, line := range lines {
		if line == "```" {
			inFence = !inFence
			continue
		}
		if !inFence {
			lines[i] = reLeadingLineSpace.ReplaceAllString(line, "$2")
		}
	}
	return strings.Join(lines, "\n"), nil
}
+245
View File
@@ -0,0 +1,245 @@
package utils
import (
"testing"
"github.com/sipeed/picoclaw/pkg/logger"
)
// TestHtmlToMarkdown is a table-driven test of HtmlToMarkdown. Each case
// feeds one HTML snippet to the converter and compares the result with the
// exact expected Markdown.
//
// Whitespace inside the fixtures is significant: nested list items are
// indented by four spaces per level, and the code-block fixture uses a
// four-space indent so it is not affected by the converter's removal of a
// single leading space.
func TestHtmlToMarkdown(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "Removes scripts and styles",
			input:    `<script>alert("hello");</script><style>body { color: red; }</style><p>Clean text</p>`,
			expected: "Clean text",
		},
		{
			name:     "Extracts links correctly",
			input:    `Visit my <a href="https://example.com">website</a> for info.`,
			expected: "Visit my [website](https://example.com) for info.",
		},
		{
			name:     "Converts headers (H1, H2, H3)",
			input:    `<h1>Main Title</h1><h2>Subtitle</h2><h3>Section</h3>`,
			expected: "# Main Title\n\n## Subtitle\n\n### Section",
		},
		{
			name:     "Handles bold and italics",
			input:    `Text <b>bold</b> and <strong>strong</strong>, then <i>italic</i> and <em>em</em>.`,
			expected: "Text **bold** and **strong**, then *italic* and *em*.",
		},
		{
			name:     "Converts lists",
			input:    `<ul><li>First element</li><li>Second element</li></ul>`,
			expected: "- First element\n- Second element",
		},
		{
			name:     "Handles paragraphs and line breaks (<br>)",
			input:    `<p>First paragraph</p><p>Second paragraph with<br>a line break.</p>`,
			expected: "First paragraph\n\nSecond paragraph with\na line break.",
		},
		{
			name:     "Decodes HTML entities",
			input:    `Math: 5 &gt; 3 &amp; 2 &lt; 4. A &quot;quote&quot;.`,
			expected: "Math: 5 > 3 & 2 < 4. A \"quote\".",
		},
		{
			name:     "Cleans up residual HTML tags",
			input:    `<div><span>Text inside div and span</span></div>`,
			expected: "Text inside div and span",
		},
		{
			name: "Removes multiple spaces and excessive empty lines",
			// Runs of spaces must collapse to one and the four <br> must
			// collapse to a single blank line.
			input:    `This    text   has  too   many    spaces.   <br><br><br><br>   And too many newlines.`,
			expected: "This text has too many spaces.\n\nAnd too many newlines.",
		},
		{
			name:  "Nested lists with indentation",
			input: "<ul><li>One<ul><li>Two</li></ul></li></ul>",
			// Expect the sub-element to have 4 spaces of indentation
			expected: "- One\n    - Two",
		},
		{
			name:  "Image support",
			input: `<img src="image.jpg" alt="alternative text">`,
			// Correct Markdown syntax for images
			expected: "![alternative text](image.jpg)",
		},
		{
			name:  "Image support without alt-text",
			input: `<img src="image.jpg">`,
			// If alt is missing, square brackets remain empty
			expected: "![](image.jpg)",
		},
		{
			name: "XSS Bypass on Links (Obfuscated HTML entities)",
			// The HTML parser decodes &#x09; to a tab; once tabs are removed
			// from the attribute the href reads "javascript:alert(1)".
			input: `<a href="jav&#x09;ascript:alert(1)">Click here</a>`,
			// The unsafe target must be neutralized to "#"
			expected: "[Click here](#)",
		},
		{
			name:  "Empty link or used as anchor",
			input: `<a name="top"></a>`,
			// With no text or href, it shouldn't print anything (not even empty brackets)
			expected: "",
		},
		{
			name:  "Link without href but with text (Textual anchor)",
			input: `<a id="top">Back to top</a>`,
			// Should extract only plain text, without generating a broken Markdown link like [Back to top](#) or [Back to top]()
			expected: "Back to top",
		},
		{
			name:  "Badly spaced bold and italics (Edge Case)",
			input: `<b> Text </b>`,
			// In Markdown `** Text **` is often not formatted correctly. The ideal is `**Text**`
			expected: "**Text**",
		},
		{
			name: "Complex Test - Real Article",
			input: `
				<h1>Article Title</h1>
				<p>This is an <strong>introductory text</strong> with a <a href="http://link.com">link</a>.</p>
				<h2>Subtitle</h2>
				<ul>
					<li>Point one</li>
					<li>Point two</li>
				</ul>
				<script>console.log("do not show me")</script>
			`,
			// Note: the indentation of the HTML source produces whitespace-only
			// text between elements, which the post-processing removes.
			expected: "# Article Title\n\nThis is an **introductory text** with a [link](http://link.com).\n\n## Subtitle\n\n- Point one\n- Point two",
		},
		{
			name:     "Ordered list (OL)",
			input:    `<ol><li>First</li><li>Second</li><li>Third</li></ol>`,
			expected: "1. First\n2. Second\n3. Third",
		},
		{
			name:  "Ordered list nested in unordered list",
			input: `<ul><li>Fruits<ol><li>Apples</li><li>Pears</li></ol></li><li>Vegetables</li></ul>`,
			// Nested items are indented by 4 spaces, like nested bullets
			expected: "- Fruits\n    1. Apples\n    2. Pears\n- Vegetables",
		},
		{
			name: "Code block (pre/code)",
			// Indentation inside <pre> must be reproduced verbatim
			input:    "<pre><code>func main() {\n    fmt.Println(\"hello\")\n}</code></pre>",
			expected: "```\nfunc main() {\n    fmt.Println(\"hello\")\n}\n```",
		},
		{
			name:     "Inline code",
			input:    `<p>Use the command <code>go test ./...</code> to run the tests.</p>`,
			expected: "Use the command `go test ./...` to run the tests.",
		},
		{
			name:     "Simple blockquote",
			input:    `<blockquote><p>An important quote.</p></blockquote>`,
			expected: "> An important quote.",
		},
		{
			name:     "Multiline blockquote",
			input:    `<blockquote><p>First line of the quote.</p><p>Second line of the quote.</p></blockquote>`,
			expected: "> First line of the quote.\n>\n> Second line of the quote.",
		},
		{
			name:     "Strikethrough text (del/s)",
			input:    `This text is <del>deleted</del> and this is <s>crossed out</s>.`,
			expected: "This text is ~~deleted~~ and this is ~~crossed out~~.",
		},
		{
			name:     "Horizontal separator (HR)",
			input:    `<p>Above the line</p><hr><p>Below the line</p>`,
			expected: "Above the line\n\n---\n\nBelow the line",
		},
		{
			name:     "Bold nested in link",
			input:    `<a href="https://example.com"><strong>Linked bold text</strong></a>`,
			expected: "[**Linked bold text**](https://example.com)",
		},
		{
			name:     "data-src Image (lazy loading)",
			input:    `<img data-src="lazy.jpg" alt="Lazy image">`,
			expected: "![Lazy image](lazy.jpg)",
		},
		{
			name:  "Image with javascript: src blocked",
			input: `<img src="javascript:alert(1)" alt="XSS">`,
			// src is not safe, so the image is not emitted
			expected: "",
		},
		{
			name:     "Link with data: href blocked",
			input:    `<a href="data:text/html,<script>alert(1)</script>">Click</a>`,
			expected: "[Click](#)",
		},
		{
			name:     "Deeply nested divs",
			input:    `<div><div><div><div><p>Deeply nested text</p></div></div></div></div>`,
			expected: "Deeply nested text",
		},
		{
			name:     "Non-consecutive headers (H1, H3, H5)",
			input:    `<h1>Title</h1><h3>Subsection</h3><h5>Sub-subsection</h5>`,
			expected: "# Title\n\n### Subsection\n\n##### Sub-subsection",
		},
		{
			name:     "Paragraph with mixed multiple emphasis",
			input:    `<p><strong>Important:</strong> read the <strong><em>critical instructions</em></strong> <em>carefully</em>.</p>`,
			expected: "**Important:** read the ***critical instructions*** *carefully*.",
		},
		{
			name: "Article with nav and aside sections (noise to filter)",
			input: `
				<nav><a href="/home">Home</a><a href="/about-us">About us</a></nav>
				<article>
					<h2>Article title</h2>
					<p>This is the body of the article.</p>
				</article>
				<aside><p>Advertisement</p></aside>
			`,
			expected: "## Article title\n\nThis is the body of the article.",
		},
		{
			name:     "Text with mixed special HTML entities",
			input:    `Copyright &copy; 2024 &mdash; All rights reserved &reg;`,
			expected: "Copyright © 2024 — All rights reserved ®",
		},
		{
			name:     "Mailto link",
			input:    `Write to us at <a href="mailto:info@example.com">info@example.com</a>`,
			expected: "Write to us at [info@example.com](mailto:info@example.com)",
		},
		{
			name:  "Image inside a link (clickable figure)",
			input: `<a href="https://example.com"><img src="photo.jpg" alt="Photo"></a>`,
			// The image-link without text must not generate broken markup
			expected: "[![Photo](photo.jpg)](https://example.com)",
		},
		{
			name:     "Empty content or only whitespace",
			input:    ` <p> </p> <div> </div> `,
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := HtmlToMarkdown(tt.input)
			if err != nil {
				// Keep the structured log entry, but also fail the sub-test:
				// comparing the output of a failed conversion is meaningless.
				logger.ErrorCF("tool", "Failed to parse html to markdown", map[string]any{"error": err.Error()})
				t.Fatalf("HtmlToMarkdown(%q) returned an unexpected error: %v", tt.input, err)
			}
			if got != tt.expected {
				t.Errorf("\nTest case failed: %s\nInput: %q\nGot: %q\nExpected: %q",
					tt.name, tt.input, got, tt.expected)
			}
		})
	}
}