mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'upstream-main' into feat/subturn-poc
This commit is contained in:
+24
-5
@@ -52,7 +52,7 @@ func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuil
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
@@ -65,7 +65,7 @@ func getGlobalConfigDir() string {
|
||||
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
// builtin skills: skills directory in current project
|
||||
// Use the skills/ directory under the current working directory
|
||||
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
|
||||
builtinSkillsDir := strings.TrimSpace(os.Getenv(config.EnvBuiltinSkills))
|
||||
if builtinSkillsDir == "" {
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir = filepath.Join(wd, "skills")
|
||||
@@ -458,7 +458,23 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
//
|
||||
// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
||||
// See: https://platform.openai.com/docs/guides/prompt-caching
|
||||
func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
|
||||
func formatCurrentSenderLine(senderID, senderDisplayName string) string {
|
||||
senderID = strings.TrimSpace(senderID)
|
||||
senderDisplayName = strings.TrimSpace(senderDisplayName)
|
||||
|
||||
switch {
|
||||
case senderDisplayName != "" && senderID != "":
|
||||
return fmt.Sprintf("Current sender: %s (ID: %s)", senderDisplayName, senderID)
|
||||
case senderDisplayName != "":
|
||||
return fmt.Sprintf("Current sender: %s", senderDisplayName)
|
||||
case senderID != "":
|
||||
return fmt.Sprintf("Current sender: %s", senderID)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) buildDynamicContext(channel, chatID, senderID, senderDisplayName string) string {
|
||||
now := time.Now().Format("2006-01-02 15:04 (Monday)")
|
||||
rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
|
||||
|
||||
@@ -468,6 +484,9 @@ func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
|
||||
if channel != "" && chatID != "" {
|
||||
fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
|
||||
}
|
||||
if senderLine := formatCurrentSenderLine(senderID, senderDisplayName); senderLine != "" {
|
||||
fmt.Fprintf(&sb, "\n\n## Current Sender\n%s", senderLine)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -477,7 +496,7 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
summary string,
|
||||
currentMessage string,
|
||||
media []string,
|
||||
channel, chatID string,
|
||||
channel, chatID, senderID, senderDisplayName string,
|
||||
) []providers.Message {
|
||||
messages := []providers.Message{}
|
||||
|
||||
@@ -493,7 +512,7 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
staticPrompt := cb.BuildSystemPromptWithCache()
|
||||
|
||||
// Build short dynamic context (time, runtime, session) — changes per request
|
||||
dynamicCtx := cb.buildDynamicContext(channel, chatID)
|
||||
dynamicCtx := cb.buildDynamicContext(channel, chatID, senderID, senderDisplayName)
|
||||
|
||||
// Compose a single system message: static (cached) + dynamic + optional summary.
|
||||
// Keeping all system content in one message ensures every provider adapter can
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestSingleSystemMessage(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1")
|
||||
msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1", "", "")
|
||||
|
||||
systemCount := 0
|
||||
for _, m := range msgs {
|
||||
@@ -126,6 +126,68 @@ func TestSingleSystemMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessages_CurrentSenderDynamicContext(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"IDENTITY.md": "# Identity\nTest agent.",
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
senderID string
|
||||
senderDisplayName string
|
||||
wantLine string
|
||||
wantSection bool
|
||||
}{
|
||||
{
|
||||
name: "both id and display name",
|
||||
senderID: "feishu:ou_xxx",
|
||||
senderDisplayName: "Zhang San",
|
||||
wantLine: "Current sender: Zhang San (ID: feishu:ou_xxx)",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "display name only",
|
||||
senderDisplayName: "Alice",
|
||||
wantLine: "Current sender: Alice",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "id only",
|
||||
senderID: "discord:123",
|
||||
wantLine: "Current sender: discord:123",
|
||||
wantSection: true,
|
||||
},
|
||||
{
|
||||
name: "no sender info",
|
||||
wantSection: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "discord", "chat1", tt.senderID, tt.senderDisplayName)
|
||||
sys := msgs[0].Content
|
||||
|
||||
if tt.wantSection {
|
||||
if !strings.Contains(sys, "## Current Sender") {
|
||||
t.Fatalf("system prompt missing Current Sender section:\n%s", sys)
|
||||
}
|
||||
if !strings.Contains(sys, tt.wantLine) {
|
||||
t.Fatalf("system prompt missing sender line %q:\n%s", tt.wantLine, sys)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(sys, "## Current Sender") {
|
||||
t.Fatalf("system prompt should omit Current Sender section:\n%s", sys)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMtimeAutoInvalidation verifies that the cache detects source file changes
|
||||
// via mtime without requiring explicit InvalidateCache().
|
||||
// Fix: original implementation had no auto-invalidation — edits to bootstrap files,
|
||||
@@ -576,7 +638,7 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// Also exercise BuildMessages concurrently
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat")
|
||||
msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat", "", "")
|
||||
if len(msgs) < 2 {
|
||||
errs <- "BuildMessages returned fewer than 2 messages"
|
||||
return
|
||||
@@ -664,6 +726,6 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test")
|
||||
_ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test", "", "")
|
||||
}
|
||||
}
|
||||
|
||||
+72
-15
@@ -61,6 +61,8 @@ type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
SenderID string // Current sender ID for dynamic context
|
||||
SenderDisplayName string // Current sender display name for dynamic context
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
@@ -166,7 +168,12 @@ func registerSharedTools(
|
||||
}
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("web_fetch") {
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(
|
||||
50000,
|
||||
cfg.Tools.Web.Proxy,
|
||||
cfg.Tools.Web.Format,
|
||||
cfg.Tools.Web.FetchLimitBytes,
|
||||
cfg.Tools.Web.PrivateHostWhitelist)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
@@ -338,10 +345,9 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
case msg, ok := <-al.bus.InboundChan():
|
||||
if !ok {
|
||||
continue
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start a goroutine that drains the bus while processMessage is
|
||||
@@ -408,6 +414,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
}()
|
||||
default:
|
||||
time.Sleep(time.Microsecond * 200)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,9 +427,15 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// is active and stops when drainCtx is canceled (i.e., processMessage returns).
|
||||
func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
|
||||
for {
|
||||
msg, ok := al.bus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
var msg bus.InboundMessage
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case m, ok := <-al.bus.InboundChan():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
msg = m
|
||||
}
|
||||
|
||||
// Transcribe audio if needed before steering, so the agent sees text.
|
||||
@@ -861,14 +875,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
})
|
||||
|
||||
opts := processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
SessionKey: sessionKey,
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
SenderID: msg.SenderID,
|
||||
SenderDisplayName: msg.Sender.DisplayName,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
}
|
||||
|
||||
// context-dependent commands check their own Runtime fields and report
|
||||
@@ -1039,6 +1055,8 @@ func (al *AgentLoop) runAgentLoop(
|
||||
opts.Media,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
opts.SenderID,
|
||||
opts.SenderDisplayName,
|
||||
)
|
||||
|
||||
// Resolve media:// refs: images→base64 data URLs, non-images→local paths in content
|
||||
@@ -1256,6 +1274,19 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// Build tool definitions
|
||||
providerToolDefs := agent.Tools.ToProviderDefs()
|
||||
|
||||
// Determine whether the provider's native web search should replace
|
||||
// the client-side web_search tool for this request. Only enable when web
|
||||
// search is actually enabled and registered (so users who disabled web
|
||||
// access do not get provider-side search or billing).
|
||||
_, hasWebSearch := agent.Tools.Get("web_search")
|
||||
useNativeSearch := al.cfg.Tools.Web.PreferNative &&
|
||||
isNativeSearchProvider(agent.Provider) &&
|
||||
hasWebSearch
|
||||
|
||||
if useNativeSearch {
|
||||
providerToolDefs = filterClientWebSearch(providerToolDefs)
|
||||
}
|
||||
|
||||
// Log LLM request details
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
@@ -1264,6 +1295,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"model": activeModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"native_search": useNativeSearch,
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"system_prompt_len": len(messages[0].Content),
|
||||
@@ -1286,6 +1318,9 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
}
|
||||
if useNativeSearch {
|
||||
llmOpts["native_search"] = true
|
||||
}
|
||||
// parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
|
||||
// so checking != ThinkingOff is sufficient.
|
||||
if agent.ThinkingLevel != ThinkingOff {
|
||||
@@ -1387,7 +1422,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
|
||||
messages = agent.ContextBuilder.BuildMessages(
|
||||
newHistory, newSummary, "",
|
||||
nil, opts.Channel, opts.ChatID,
|
||||
nil, opts.Channel, opts.ChatID, opts.SenderID, opts.SenderDisplayName,
|
||||
)
|
||||
continue
|
||||
}
|
||||
@@ -2246,6 +2281,28 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
|
||||
}
|
||||
|
||||
// isNativeSearchProvider reports whether the given LLM provider implements
|
||||
// NativeSearchCapable and returns true for SupportsNativeSearch.
|
||||
func isNativeSearchProvider(p providers.LLMProvider) bool {
|
||||
if ns, ok := p.(providers.NativeSearchCapable); ok {
|
||||
return ns.SupportsNativeSearch()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// filterClientWebSearch returns a copy of tools with the client-side
|
||||
// web_search tool removed. Used when native provider search is preferred.
|
||||
func filterClientWebSearch(tools []providers.ToolDefinition) []providers.ToolDefinition {
|
||||
result := make([]providers.ToolDefinition, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
result = append(result, t)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper to extract provider from registry for cleanup
|
||||
func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) {
|
||||
if registry == nil {
|
||||
|
||||
+228
-39
@@ -30,6 +30,28 @@ func (f *fakeChannel) IsAllowed(string) bool {
|
||||
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
|
||||
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
|
||||
|
||||
type recordingProvider struct {
|
||||
lastMessages []providers.Message
|
||||
}
|
||||
|
||||
func (r *recordingProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
r.lastMessages = append([]providers.Message(nil), messages...)
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *recordingProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func newTestAgentLoop(
|
||||
t *testing.T,
|
||||
) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
|
||||
@@ -54,6 +76,59 @@ func newTestAgentLoop(
|
||||
return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "discord",
|
||||
SenderID: "discord:123",
|
||||
Sender: bus.SenderInfo{
|
||||
DisplayName: "Alice",
|
||||
},
|
||||
ChatID: "group-1",
|
||||
Content: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
if response != "Mock response" {
|
||||
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
|
||||
}
|
||||
if len(provider.lastMessages) == 0 {
|
||||
t.Fatal("provider did not receive any messages")
|
||||
}
|
||||
|
||||
systemPrompt := provider.lastMessages[0].Content
|
||||
wantSender := "## Current Sender\nCurrent sender: Alice (ID: discord:123)"
|
||||
if !strings.Contains(systemPrompt, wantSender) {
|
||||
t.Fatalf("system prompt missing sender context %q:\n%s", wantSender, systemPrompt)
|
||||
}
|
||||
|
||||
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
|
||||
if lastMessage.Role != "user" || lastMessage.Content != "hello" {
|
||||
t.Fatalf("last provider message = %+v, want unchanged user message", lastMessage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
@@ -922,10 +997,25 @@ func TestHandleReasoning(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
}
|
||||
if msg.Content == "reasoning" {
|
||||
t.Fatalf("expected no message for empty chatID, got %+v", msg)
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
t.Log("expected an outbound message, got none within timeout")
|
||||
return
|
||||
default:
|
||||
// Continue to check for message
|
||||
time.Sleep(5 * time.Millisecond) // Avoid busy loop
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -933,9 +1023,7 @@ func TestHandleReasoning(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
msg, ok := <-msgBus.OutboundChan()
|
||||
if !ok {
|
||||
t.Fatal("expected an outbound message")
|
||||
}
|
||||
@@ -949,35 +1037,52 @@ func TestHandleReasoning(t *testing.T) {
|
||||
reasoning := "hello telegram reasoning"
|
||||
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("expected an outbound message, got none within timeout")
|
||||
return
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
|
||||
if msg.Channel != "telegram" {
|
||||
t.Fatalf("expected telegram channel message, got %+v", msg)
|
||||
}
|
||||
if msg.ChatID != "tg-chat" {
|
||||
t.Fatalf("expected chatID tg-chat, got %+v", msg)
|
||||
}
|
||||
if msg.Content != reasoning {
|
||||
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
|
||||
if msg.Channel != "telegram" {
|
||||
t.Fatalf("expected telegram channel message, got %+v", msg)
|
||||
}
|
||||
if msg.ChatID != "tg-chat" {
|
||||
t.Fatalf("expected chatID tg-chat, got %+v", msg)
|
||||
}
|
||||
if msg.Content != reasoning {
|
||||
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("expired ctx", func(t *testing.T) {
|
||||
al, msgBus := newLoop(t)
|
||||
reasoning := "hello telegram reasoning"
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
msg, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if ok {
|
||||
t.Fatalf("expected no outbound message, got %+v", msg)
|
||||
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
|
||||
|
||||
consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer consumeCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
t.Fatalf("expected no outbound message, but received: %+v", msg)
|
||||
}
|
||||
t.Logf("Received unexpected outbound message: %+v", msg)
|
||||
return
|
||||
case <-consumeCtx.Done():
|
||||
t.Fatalf("failed: no message received within timeout")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1017,20 +1122,23 @@ func TestHandleReasoning(t *testing.T) {
|
||||
|
||||
// Drain the bus and verify the reasoning message was NOT published
|
||||
// (it should have been dropped due to timeout).
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer drainCancel()
|
||||
foundReasoning := false
|
||||
timeer := time.After(1 * time.Second)
|
||||
for {
|
||||
msg, ok := msgBus.SubscribeOutbound(drainCtx)
|
||||
if !ok {
|
||||
break
|
||||
select {
|
||||
case <-timeer:
|
||||
t.Logf(
|
||||
"no reasoning message received after draining bus for 1s, as expected,length=%d",
|
||||
len(msgBus.OutboundChan()),
|
||||
)
|
||||
return
|
||||
case msg, ok := <-msgBus.OutboundChan():
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if msg.Content == "should timeout" {
|
||||
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
|
||||
}
|
||||
}
|
||||
if msg.Content == "should timeout" {
|
||||
foundReasoning = true
|
||||
}
|
||||
}
|
||||
if foundReasoning {
|
||||
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1318,3 +1426,84 @@ func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) {
|
||||
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Native search helper tests ---
|
||||
|
||||
type nativeSearchProvider struct {
|
||||
supported bool
|
||||
}
|
||||
|
||||
func (p *nativeSearchProvider) Chat(
|
||||
ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
|
||||
model string, opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{Content: "ok"}, nil
|
||||
}
|
||||
|
||||
func (p *nativeSearchProvider) GetDefaultModel() string { return "test-model" }
|
||||
|
||||
func (p *nativeSearchProvider) SupportsNativeSearch() bool { return p.supported }
|
||||
|
||||
type plainProvider struct{}
|
||||
|
||||
func (p *plainProvider) Chat(
|
||||
ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition,
|
||||
model string, opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{Content: "ok"}, nil
|
||||
}
|
||||
|
||||
func (p *plainProvider) GetDefaultModel() string { return "test-model" }
|
||||
|
||||
func TestIsNativeSearchProvider_Supported(t *testing.T) {
|
||||
if !isNativeSearchProvider(&nativeSearchProvider{supported: true}) {
|
||||
t.Fatal("expected true for provider that supports native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNativeSearchProvider_NotSupported(t *testing.T) {
|
||||
if isNativeSearchProvider(&nativeSearchProvider{supported: false}) {
|
||||
t.Fatal("expected false for provider that does not support native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNativeSearchProvider_NoInterface(t *testing.T) {
|
||||
if isNativeSearchProvider(&plainProvider{}) {
|
||||
t.Fatal("expected false for provider that does not implement NativeSearchCapable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterClientWebSearch_RemovesWebSearch(t *testing.T) {
|
||||
defs := []providers.ToolDefinition{
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "web_search"}},
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}},
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}},
|
||||
}
|
||||
result := filterClientWebSearch(defs)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
for _, td := range result {
|
||||
if td.Function.Name == "web_search" {
|
||||
t.Fatal("web_search should be filtered out")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterClientWebSearch_NoWebSearch(t *testing.T) {
|
||||
defs := []providers.ToolDefinition{
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}},
|
||||
{Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}},
|
||||
}
|
||||
result := filterClientWebSearch(defs)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterClientWebSearch_EmptyInput(t *testing.T) {
|
||||
result := filterClientWebSearch(nil)
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("len(result) = %d, want 0", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
+2
-1
@@ -6,6 +6,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
@@ -39,7 +40,7 @@ func (c *AuthCredential) NeedsRefresh() bool {
|
||||
}
|
||||
|
||||
func authFilePath() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return filepath.Join(home, "auth.json")
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
|
||||
+60
-93
@@ -3,6 +3,7 @@ package bus
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -17,8 +18,11 @@ type MessageBus struct {
|
||||
inbound chan InboundMessage
|
||||
outbound chan OutboundMessage
|
||||
outboundMedia chan OutboundMediaMessage
|
||||
done chan struct{}
|
||||
closed atomic.Bool
|
||||
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
closed atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewMessageBus() *MessageBus {
|
||||
@@ -30,128 +34,91 @@ func NewMessageBus() *MessageBus {
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error {
|
||||
// check bus closed before acquiring wg, to avoid unnecessary wg.Add and potential deadlock
|
||||
if mb.closed.Load() {
|
||||
return ErrBusClosed
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check again,before sending message, to avoid sending to closed channel
|
||||
select {
|
||||
case mb.inbound <- msg:
|
||||
return nil
|
||||
case <-mb.done:
|
||||
return ErrBusClosed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-mb.done:
|
||||
return ErrBusClosed
|
||||
default:
|
||||
}
|
||||
|
||||
mb.wg.Add(1)
|
||||
defer mb.wg.Done()
|
||||
|
||||
select {
|
||||
case ch <- msg:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-mb.done:
|
||||
return ErrBusClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) {
|
||||
select {
|
||||
case msg, ok := <-mb.inbound:
|
||||
return msg, ok
|
||||
case <-mb.done:
|
||||
return InboundMessage{}, false
|
||||
case <-ctx.Done():
|
||||
return InboundMessage{}, false
|
||||
}
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
return publish(ctx, mb, mb.inbound, msg)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
return mb.inbound
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
|
||||
if mb.closed.Load() {
|
||||
return ErrBusClosed
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case mb.outbound <- msg:
|
||||
return nil
|
||||
case <-mb.done:
|
||||
return ErrBusClosed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
return publish(ctx, mb, mb.outbound, msg)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) {
|
||||
select {
|
||||
case msg, ok := <-mb.outbound:
|
||||
return msg, ok
|
||||
case <-mb.done:
|
||||
return OutboundMessage{}, false
|
||||
case <-ctx.Done():
|
||||
return OutboundMessage{}, false
|
||||
}
|
||||
func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
return mb.outbound
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
|
||||
if mb.closed.Load() {
|
||||
return ErrBusClosed
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case mb.outboundMedia <- msg:
|
||||
return nil
|
||||
case <-mb.done:
|
||||
return ErrBusClosed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
return publish(ctx, mb, mb.outboundMedia, msg)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) {
|
||||
select {
|
||||
case msg, ok := <-mb.outboundMedia:
|
||||
return msg, ok
|
||||
case <-mb.done:
|
||||
return OutboundMediaMessage{}, false
|
||||
case <-ctx.Done():
|
||||
return OutboundMediaMessage{}, false
|
||||
}
|
||||
func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
|
||||
return mb.outboundMedia
|
||||
}
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
if mb.closed.CompareAndSwap(false, true) {
|
||||
mb.closeOnce.Do(func() {
|
||||
// notify all blocked publishers to exit
|
||||
close(mb.done)
|
||||
|
||||
// Drain buffered channels so messages aren't silently lost.
|
||||
// Channels are NOT closed to avoid send-on-closed panics from concurrent publishers.
|
||||
// because every publisher will check mb.closed before acquiring wg
|
||||
// so we can be sure that new publishers will not be added new messages after this point
|
||||
mb.closed.Store(true)
|
||||
|
||||
// wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited
|
||||
mb.wg.Wait()
|
||||
|
||||
// close channels safely
|
||||
close(mb.inbound)
|
||||
close(mb.outbound)
|
||||
close(mb.outboundMedia)
|
||||
|
||||
// clean up any remaining messages in channels
|
||||
drained := 0
|
||||
for {
|
||||
select {
|
||||
case <-mb.inbound:
|
||||
drained++
|
||||
default:
|
||||
goto doneInbound
|
||||
}
|
||||
for range mb.inbound {
|
||||
drained++
|
||||
}
|
||||
doneInbound:
|
||||
for {
|
||||
select {
|
||||
case <-mb.outbound:
|
||||
drained++
|
||||
default:
|
||||
goto doneOutbound
|
||||
}
|
||||
for range mb.outbound {
|
||||
drained++
|
||||
}
|
||||
doneOutbound:
|
||||
for {
|
||||
select {
|
||||
case <-mb.outboundMedia:
|
||||
drained++
|
||||
default:
|
||||
goto doneMedia
|
||||
}
|
||||
for range mb.outboundMedia {
|
||||
drained++
|
||||
}
|
||||
doneMedia:
|
||||
|
||||
if drained > 0 {
|
||||
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
|
||||
"count": drained,
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+35
-17
@@ -24,7 +24,7 @@ func TestPublishConsume(t *testing.T) {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := mb.ConsumeInbound(ctx)
|
||||
got, ok := <-mb.InboundChan()
|
||||
if !ok {
|
||||
t.Fatal("ConsumeInbound returned ok=false")
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := mb.SubscribeOutbound(ctx)
|
||||
got, ok := <-mb.OutboundChan()
|
||||
if !ok {
|
||||
t.Fatal("SubscribeOutbound returned ok=false")
|
||||
}
|
||||
@@ -108,27 +108,48 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
|
||||
|
||||
func TestConsumeInbound_ContextCancel(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
|
||||
defer mb.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
_, ok := mb.ConsumeInbound(ctx)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false when context is canceled")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Log("context canceled, as expected")
|
||||
|
||||
case msg, ok := <-mb.InboundChan():
|
||||
if !ok {
|
||||
t.Fatal("expected ok=false when context is canceled")
|
||||
}
|
||||
if msg.Content == "ContextCancel" {
|
||||
t.Fatalf("expected content 'ContextCancel', got %q", msg.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeInbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
timer := time.AfterFunc(100*time.Millisecond, func() {
|
||||
mb.Close()
|
||||
})
|
||||
|
||||
_, ok := mb.ConsumeInbound(ctx)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false when bus is closed")
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Log("context canceled, as expected")
|
||||
|
||||
case _, ok := <-mb.InboundChan():
|
||||
if ok {
|
||||
t.Fatal("expected ok=false when context is canceled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,10 +157,7 @@ func TestSubscribeOutbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, ok := mb.SubscribeOutbound(ctx)
|
||||
_, ok := <-mb.OutboundChan()
|
||||
if ok {
|
||||
t.Fatal("expected ok=false when bus is closed")
|
||||
}
|
||||
|
||||
@@ -29,11 +29,17 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// errCodeTenantTokenInvalid is the Feishu API error code for an expired/revoked
|
||||
// tenant_access_token. The Lark SDK's built-in retry does not clear its cache
|
||||
// on this error, so we do it ourselves.
|
||||
const errCodeTenantTokenInvalid = 99991663
|
||||
|
||||
type FeishuChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.FeishuConfig
|
||||
client *lark.Client
|
||||
wsClient *larkws.Client
|
||||
config config.FeishuConfig
|
||||
client *lark.Client
|
||||
wsClient *larkws.Client
|
||||
tokenCache *tokenCache // custom cache that supports invalidation
|
||||
|
||||
botOpenID atomic.Value // stores string; populated lazily for @mention detection
|
||||
|
||||
@@ -47,10 +53,12 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
tc := newTokenCache()
|
||||
ch := &FeishuChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret),
|
||||
tokenCache: tc,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret, lark.WithTokenCache(tc)),
|
||||
}
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
@@ -147,6 +155,7 @@ func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, cont
|
||||
return fmt.Errorf("feishu edit: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
@@ -186,6 +195,7 @@ func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (str
|
||||
return "", fmt.Errorf("feishu placeholder send: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
|
||||
@@ -226,6 +236,7 @@ func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID st
|
||||
return func() {}, fmt.Errorf("feishu react: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
logger.ErrorCF("feishu", "Reaction API error", map[string]any{
|
||||
"emoji": chosenEmoji,
|
||||
"message_id": messageID,
|
||||
@@ -451,6 +462,7 @@ func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error {
|
||||
return fmt.Errorf("bot info parse: %w", err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
c.invalidateTokenOnAuthError(result.Code)
|
||||
return fmt.Errorf("bot info api error (code=%d)", result.Code)
|
||||
}
|
||||
if result.Bot.OpenID == "" {
|
||||
@@ -593,6 +605,7 @@ func (c *FeishuChannel) downloadResource(
|
||||
return ""
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
logger.ErrorCF("feishu", "Resource download api error", map[string]any{
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
@@ -705,6 +718,7 @@ func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
}
|
||||
|
||||
@@ -730,6 +744,7 @@ func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.F
|
||||
return fmt.Errorf("feishu image upload: %w", err)
|
||||
}
|
||||
if !uploadResp.Success() {
|
||||
c.invalidateTokenOnAuthError(uploadResp.Code)
|
||||
return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
|
||||
}
|
||||
if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil {
|
||||
@@ -754,6 +769,7 @@ func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.F
|
||||
return fmt.Errorf("feishu image send: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
@@ -784,6 +800,7 @@ func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.Fi
|
||||
return fmt.Errorf("feishu file upload: %w", err)
|
||||
}
|
||||
if !uploadResp.Success() {
|
||||
c.invalidateTokenOnAuthError(uploadResp.Code)
|
||||
return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
|
||||
}
|
||||
if uploadResp.Data == nil || uploadResp.Data.FileKey == nil {
|
||||
@@ -808,6 +825,7 @@ func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.Fi
|
||||
return fmt.Errorf("feishu file send: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
@@ -830,3 +848,14 @@ func extractFeishuSenderID(sender *larkim.EventSender) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// invalidateTokenOnAuthError clears the cached tenant_access_token when the
|
||||
// Feishu API reports it as invalid (99991663), so the next request fetches a
|
||||
// fresh one. The Lark SDK's built-in retry does not clear the cache, causing
|
||||
// all API calls to fail until the token naturally expires (~2 hours).
|
||||
func (c *FeishuChannel) invalidateTokenOnAuthError(code int) {
|
||||
if code == errCodeTenantTokenInvalid {
|
||||
c.tokenCache.InvalidateAll()
|
||||
logger.WarnCF("feishu", "Invalidated cached token due to auth error", nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// tokenCache implements larkcore.Cache with an extra InvalidateAll method.
|
||||
// This works around a bug in the Lark SDK v3 where the built-in token retry
|
||||
// loop does not clear stale tokens from cache on auth errors.
|
||||
type tokenCache struct {
|
||||
mu sync.RWMutex
|
||||
store map[string]*tokenEntry
|
||||
}
|
||||
|
||||
type tokenEntry struct {
|
||||
value string
|
||||
expireAt time.Time
|
||||
}
|
||||
|
||||
func newTokenCache() *tokenCache {
|
||||
return &tokenCache{store: make(map[string]*tokenEntry)}
|
||||
}
|
||||
|
||||
func (c *tokenCache) Set(_ context.Context, key, value string, ttl time.Duration) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.store[key] = &tokenEntry{value: value, expireAt: time.Now().Add(ttl)}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tokenCache) Get(_ context.Context, key string) (string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
e, ok := c.store[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
if e.expireAt.Before(time.Now()) {
|
||||
delete(c.store, key)
|
||||
return "", nil
|
||||
}
|
||||
return e.value, nil
|
||||
}
|
||||
|
||||
// InvalidateAll removes all cached tokens, forcing fresh acquisition.
|
||||
func (c *tokenCache) InvalidateAll() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
clear(c.store)
|
||||
}
|
||||
+33
-27
@@ -585,7 +585,7 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
|
||||
func dispatchLoop[M any](
|
||||
ctx context.Context,
|
||||
m *Manager,
|
||||
subscribe func(context.Context) (M, bool),
|
||||
ch <-chan M,
|
||||
getChannel func(M) string,
|
||||
enqueue func(context.Context, *channelWorker, M) bool,
|
||||
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
|
||||
@@ -593,35 +593,41 @@ func dispatchLoop[M any](
|
||||
logger.InfoC("channels", startMsg)
|
||||
|
||||
for {
|
||||
msg, ok := subscribe(ctx)
|
||||
if !ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.InfoC("channels", stopMsg)
|
||||
return
|
||||
}
|
||||
|
||||
channel := getChannel(msg)
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[channel]
|
||||
w, wExists := m.workers[channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
if !enqueue(ctx, w, msg) {
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
logger.InfoC("channels", stopMsg)
|
||||
return
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
|
||||
|
||||
channel := getChannel(msg)
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[channel]
|
||||
w, wExists := m.workers[channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
|
||||
continue
|
||||
}
|
||||
|
||||
if wExists && w != nil {
|
||||
if !enqueue(ctx, w, msg) {
|
||||
return
|
||||
}
|
||||
} else if exists {
|
||||
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -629,7 +635,7 @@ func dispatchLoop[M any](
|
||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutbound,
|
||||
m.bus.OutboundChan(),
|
||||
func(msg bus.OutboundMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
@@ -649,7 +655,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.SubscribeOutboundMedia,
|
||||
m.bus.OutboundMediaChan(),
|
||||
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
|
||||
@@ -34,11 +34,19 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message")
|
||||
}
|
||||
if inbound.Metadata["account_id"] != "7750283E123456" {
|
||||
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for inbound message")
|
||||
return
|
||||
case inbound, ok := <-messageBus.InboundChan():
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message")
|
||||
}
|
||||
if inbound.Metadata["account_id"] != "7750283E123456" {
|
||||
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// mdV2SpecialChars are all characters that must be escaped in Telegram MarkdownV2
|
||||
var mdV2SpecialChars = map[rune]bool{
|
||||
'*': true,
|
||||
'_': true,
|
||||
'[': true,
|
||||
']': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'~': true,
|
||||
'`': true,
|
||||
'>': true,
|
||||
'<': true,
|
||||
'#': true,
|
||||
'+': true,
|
||||
'-': true,
|
||||
'=': true,
|
||||
'|': true,
|
||||
'{': true,
|
||||
'}': true,
|
||||
'.': true,
|
||||
'!': true,
|
||||
'\\': true,
|
||||
}
|
||||
|
||||
// entityPattern describes one Telegram MarkdownV2 inline entity type.
|
||||
type entityPattern struct {
|
||||
re *regexp.Regexp
|
||||
open string
|
||||
close string
|
||||
}
|
||||
|
||||
// allEntityPatterns lists every recognized entity in priority order
|
||||
// (longer / more-specific delimiters first so they win over shorter ones).
|
||||
// Each entry's regex is anchored to find the first occurrence in a string.
|
||||
var allEntityPatterns = []entityPattern{
|
||||
// fenced code block — content is completely verbatim
|
||||
{re: regexp.MustCompile("(?s)```(?:[\\w]*\\n)?[\\s\\S]*?```"), open: "```", close: "```"},
|
||||
// inline code — content is completely verbatim
|
||||
{re: regexp.MustCompile("`(?:[^`\\\n]|\\\\.)*`"), open: "`", close: "`"},
|
||||
// expandable block-quote opener **>…
|
||||
{re: regexp.MustCompile(`(?m)\*\*>(?:[^\n]*)`), open: "**>", close: ""},
|
||||
// block-quote line >…
|
||||
{re: regexp.MustCompile(`(?m)^>(?:[^\n]*)`), open: ">", close: ""},
|
||||
// custom emoji / timestamp  — must come before plain link
|
||||
{re: regexp.MustCompile(`!\[[^\]]*\]\([^)]*\)`), open: "!", close: ""},
|
||||
// inline URL / user mention […](…)
|
||||
{re: regexp.MustCompile(`\[[^\]]*\]\([^)]*\)`), open: "[", close: ""},
|
||||
// spoiler ||…|| — before single | so it wins
|
||||
{re: regexp.MustCompile(`\|\|(?:[^|\\\n]|\\.)*\|\|`), open: "||", close: "||"},
|
||||
// underline __…__ — before single _ so it wins
|
||||
{re: regexp.MustCompile(`__(?:[^_\\\n]|\\.)*__`), open: "__", close: "__"},
|
||||
// bold *…*
|
||||
{re: regexp.MustCompile(`\*(?:[^*\\\n]|\\.)*\*`), open: "*", close: "*"},
|
||||
// italic _…_
|
||||
{re: regexp.MustCompile(`_(?:[^_\\\n]|\\.)*_`), open: "_", close: "_"},
|
||||
// strikethrough ~…~
|
||||
{re: regexp.MustCompile(`~(?:[^~\\\n]|\\.)*~`), open: "~", close: "~"},
|
||||
}
|
||||
|
||||
// verbatimEntities are entity types whose inner content must never be
|
||||
// touched (code blocks, URLs, quotes, custom emoji).
|
||||
// Their content is passed through completely unchanged.
|
||||
var verbatimEntities = map[string]bool{
|
||||
"```": true,
|
||||
"`": true,
|
||||
"**>": true,
|
||||
">": true,
|
||||
"!": true,
|
||||
"[": true,
|
||||
}
|
||||
|
||||
// markdownToTelegramMarkdownV2 converts a Markdown string into a string safe
|
||||
// for sending with Telegram's MarkdownV2 parse mode.
|
||||
//
|
||||
// Rules:
|
||||
// - Markdown headings (# … ######) are converted to *bold*.
|
||||
// - **bold** Markdown syntax is converted to *bold*.
|
||||
// - Recognized Telegram MarkdownV2 entity spans are preserved; their inner
|
||||
// content is processed recursively so that nested valid entities are kept
|
||||
// intact while stray special characters are escaped.
|
||||
// - All plain-text segments have their MarkdownV2 special characters escaped.
|
||||
//
|
||||
// Reference: https://core.telegram.org/bots/api#formatting-options
|
||||
func markdownToTelegramMarkdownV2(text string) string {
|
||||
// 1. Convert Markdown headings → *escaped heading text*
|
||||
text = reHeading.ReplaceAllStringFunc(text, func(match string) string {
|
||||
sub := reHeading.FindStringSubmatch(match)
|
||||
if len(sub) < 2 {
|
||||
return match
|
||||
}
|
||||
// The heading content is fresh plain text — escape everything
|
||||
// including * so the resulting *…* bold span stays valid.
|
||||
return "*" + escapeMarkdownV2(sub[1]) + "*"
|
||||
})
|
||||
|
||||
// 2. Convert **bold** → *bold*
|
||||
text = reBoldStar.ReplaceAllString(text, "*$1*")
|
||||
|
||||
// 3. Recursively escape the full string.
|
||||
return processText(text)
|
||||
}
|
||||
|
||||
// processText walks `text`, finds the leftmost / longest matching entity,
|
||||
// escapes the gap before it, processes the entity (recursing into its inner
|
||||
// content when appropriate), then continues with the remainder.
|
||||
func processText(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the leftmost match among all entity patterns.
|
||||
bestStart := -1
|
||||
bestEnd := -1
|
||||
var bestPat *entityPattern
|
||||
|
||||
for i := range allEntityPatterns {
|
||||
p := &allEntityPatterns[i]
|
||||
loc := p.re.FindStringIndex(text)
|
||||
if loc == nil {
|
||||
continue
|
||||
}
|
||||
if bestStart == -1 || loc[0] < bestStart ||
|
||||
(loc[0] == bestStart && (loc[1]-loc[0]) > (bestEnd-bestStart)) {
|
||||
bestStart = loc[0]
|
||||
bestEnd = loc[1]
|
||||
bestPat = p
|
||||
}
|
||||
}
|
||||
|
||||
if bestPat == nil {
|
||||
// No entity found — escape everything.
|
||||
return escapeMarkdownV2(text)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Plain text before the entity.
|
||||
if bestStart > 0 {
|
||||
b.WriteString(escapeMarkdownV2(text[:bestStart]))
|
||||
}
|
||||
|
||||
// The matched entity span.
|
||||
matched := text[bestStart:bestEnd]
|
||||
|
||||
if verbatimEntities[bestPat.open] {
|
||||
// Code blocks, URLs, quotes: pass through completely untouched.
|
||||
b.WriteString(matched)
|
||||
} else {
|
||||
// Inline formatting (bold, italic, underline, strikethrough, spoiler):
|
||||
// keep the delimiters and recursively process the inner content so that
|
||||
// nested entities survive but stray specials get escaped.
|
||||
openLen := len(bestPat.open)
|
||||
closeLen := len(bestPat.close)
|
||||
inner := matched[openLen : len(matched)-closeLen]
|
||||
|
||||
b.WriteString(bestPat.open)
|
||||
b.WriteString(processText(inner))
|
||||
b.WriteString(bestPat.close)
|
||||
}
|
||||
|
||||
// Continue with the remainder of the string.
|
||||
b.WriteString(processText(text[bestEnd:]))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// escapeMarkdownV2 escapes every MarkdownV2 special character in a plain-text
|
||||
// segment (i.e. a segment that is not part of any recognized entity).
|
||||
// Already-escaped sequences (backslash + char) are forwarded verbatim to avoid
|
||||
// double-escaping.
|
||||
func escapeMarkdownV2(s string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s) + 8)
|
||||
runes := []rune(s)
|
||||
for i := 0; i < len(runes); i++ {
|
||||
ch := runes[i]
|
||||
// Forward an existing escape sequence verbatim.
|
||||
if ch == '\\' && i+1 < len(runes) {
|
||||
b.WriteRune(ch)
|
||||
b.WriteRune(runes[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if mdV2SpecialChars[ch] {
|
||||
b.WriteByte('\\')
|
||||
}
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//go:embed testdata/md2_all_formats.txt
|
||||
var md2AllFormats string
|
||||
|
||||
func Test_markdownToTelegramMarkdownV2(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "heading -> bolding",
|
||||
input: `## HeadingH2 #`,
|
||||
expected: "*HeadingH2 \\#*",
|
||||
},
|
||||
{
|
||||
name: "strikethrough",
|
||||
input: "~strikethroughMD~",
|
||||
expected: "~strikethroughMD~",
|
||||
},
|
||||
{
|
||||
name: "inline URL",
|
||||
input: "[inline URL](http://www.example.com/)",
|
||||
expected: "[inline URL](http://www.example.com/)",
|
||||
},
|
||||
{
|
||||
name: "all telegram formats",
|
||||
input: md2AllFormats,
|
||||
expected: md2AllFormats,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "one letter",
|
||||
input: "o",
|
||||
expected: "o",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
input: "*Last update: ~10 24h*",
|
||||
expected: "*Last update: \\~10 24h*",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
input: "<Market Capitalization>",
|
||||
expected: "\\<Market Capitalization\\>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := markdownToTelegramMarkdownV2(tc.input)
|
||||
|
||||
require.EqualValues(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func markdownToTelegramHTML(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
codeBlocks := extractCodeBlocks(text)
|
||||
text = codeBlocks.text
|
||||
|
||||
inlineCodes := extractInlineCodes(text)
|
||||
text = inlineCodes.text
|
||||
|
||||
text = reHeading.ReplaceAllString(text, "$1")
|
||||
|
||||
text = reBlockquote.ReplaceAllString(text, "$1")
|
||||
|
||||
text = escapeHTML(text)
|
||||
|
||||
text = reLink.ReplaceAllString(text, `<a href="$2">$1</a>`)
|
||||
|
||||
text = reBoldStar.ReplaceAllString(text, "<b>$1</b>")
|
||||
|
||||
text = reBoldUnder.ReplaceAllString(text, "<b>$1</b>")
|
||||
|
||||
text = reItalic.ReplaceAllStringFunc(text, func(s string) string {
|
||||
match := reItalic.FindStringSubmatch(s)
|
||||
if len(match) < 2 {
|
||||
return s
|
||||
}
|
||||
return "<i>" + match[1] + "</i>"
|
||||
})
|
||||
|
||||
text = reStrike.ReplaceAllString(text, "<s>$1</s>")
|
||||
|
||||
text = reListItem.ReplaceAllString(text, "• ")
|
||||
|
||||
for i, code := range inlineCodes.codes {
|
||||
escaped := escapeHTML(code)
|
||||
text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("<code>%s</code>", escaped))
|
||||
}
|
||||
|
||||
for i, code := range codeBlocks.codes {
|
||||
escaped := escapeHTML(code)
|
||||
text = strings.ReplaceAll(
|
||||
text,
|
||||
fmt.Sprintf("\x00CB%d\x00", i),
|
||||
fmt.Sprintf("<pre><code>%s</code></pre>", escaped),
|
||||
)
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
type codeBlockMatch struct {
|
||||
text string
|
||||
codes []string
|
||||
}
|
||||
|
||||
func extractCodeBlocks(text string) codeBlockMatch {
|
||||
matches := reCodeBlock.FindAllStringSubmatch(text, -1)
|
||||
|
||||
codes := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = reCodeBlock.ReplaceAllStringFunc(text, func(m string) string {
|
||||
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return codeBlockMatch{text: text, codes: codes}
|
||||
}
|
||||
|
||||
type inlineCodeMatch struct {
|
||||
text string
|
||||
codes []string
|
||||
}
|
||||
|
||||
func extractInlineCodes(text string) inlineCodeMatch {
|
||||
matches := reInlineCode.FindAllStringSubmatch(text, -1)
|
||||
|
||||
codes := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = reInlineCode.ReplaceAllStringFunc(text, func(m string) string {
|
||||
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return inlineCodeMatch{text: text, codes: codes}
|
||||
}
|
||||
|
||||
func escapeHTML(text string) string {
|
||||
text = strings.ReplaceAll(text, "&", "&")
|
||||
text = strings.ReplaceAll(text, "<", "<")
|
||||
text = strings.ReplaceAll(text, ">", ">")
|
||||
return text
|
||||
}
|
||||
+129
-128
@@ -3,6 +3,7 @@ package telegram
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -26,7 +27,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
reHeading = regexp.MustCompile(`^#{1,6}\s+(.+)$`)
|
||||
reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+([^\n]+)`)
|
||||
reBlockquote = regexp.MustCompile(`^>\s*(.*)$`)
|
||||
reLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
|
||||
reBoldStar = regexp.MustCompile(`\*\*(.+?)\*\*`)
|
||||
@@ -169,6 +170,8 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
|
||||
|
||||
chatID, threadID, err := parseTelegramChatID(msg.ChatID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
@@ -187,22 +190,65 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
chunk := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
htmlContent := markdownToTelegramHTML(chunk)
|
||||
content := parseContent(chunk, useMarkdownV2)
|
||||
|
||||
if len([]rune(htmlContent)) > 4096 {
|
||||
ratio := float64(len([]rune(chunk))) / float64(len([]rune(htmlContent)))
|
||||
if len([]rune(content)) > 4096 {
|
||||
runeChunk := []rune(chunk)
|
||||
ratio := float64(len(runeChunk)) / float64(len([]rune(content)))
|
||||
smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin
|
||||
if smallerLen < 100 {
|
||||
smallerLen = 100
|
||||
|
||||
// Guarantee progress: if estimated length is >= chunk length, force it smaller
|
||||
if smallerLen >= len(runeChunk) {
|
||||
smallerLen = len(runeChunk) - 1
|
||||
}
|
||||
// Push sub-chunks back to the front of the queue for
|
||||
// re-validation instead of sending them blindly.
|
||||
|
||||
if smallerLen <= 0 {
|
||||
if err := c.sendChunk(ctx, sendChunkParams{
|
||||
chatID: chatID,
|
||||
threadID: threadID,
|
||||
content: content,
|
||||
replyToID: replyToID,
|
||||
mdFallback: chunk,
|
||||
useMarkdownV2: useMarkdownV2,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
replyToID = ""
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the estimated smaller length as a guide for SplitMessage.
|
||||
// SplitMessage will find natural break points (newlines/spaces) and respect code blocks.
|
||||
subChunks := channels.SplitMessage(chunk, smallerLen)
|
||||
queue = append(subChunks, queue...)
|
||||
|
||||
// Safety fallback: If SplitMessage failed to shorten the chunk, force a manual hard split.
|
||||
if len(subChunks) == 1 && subChunks[0] == chunk {
|
||||
part1 := string(runeChunk[:smallerLen])
|
||||
part2 := string(runeChunk[smallerLen:])
|
||||
subChunks = []string{part1, part2}
|
||||
}
|
||||
|
||||
// Filter out empty chunks to avoid sending empty messages to Telegram.
|
||||
nonEmpty := make([]string, 0, len(subChunks))
|
||||
for _, s := range subChunks {
|
||||
if s != "" {
|
||||
nonEmpty = append(nonEmpty, s)
|
||||
}
|
||||
}
|
||||
|
||||
// Push sub-chunks back to the front of the queue
|
||||
queue = append(nonEmpty, queue...)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.sendHTMLChunk(ctx, chatID, threadID, htmlContent, chunk, replyToID); err != nil {
|
||||
if err := c.sendChunk(ctx, sendChunkParams{
|
||||
chatID: chatID,
|
||||
threadID: threadID,
|
||||
content: content,
|
||||
replyToID: replyToID,
|
||||
mdFallback: chunk,
|
||||
useMarkdownV2: useMarkdownV2,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
// Only the first chunk should be a reply; subsequent chunks are normal messages.
|
||||
@@ -212,17 +258,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendHTMLChunk sends a single HTML message, falling back to the original
|
||||
// markdown as plain text on parse failure so users never see raw HTML tags.
|
||||
func (c *TelegramChannel) sendHTMLChunk(
|
||||
ctx context.Context, chatID int64, threadID int, htmlContent, mdFallback string, replyToID string,
|
||||
) error {
|
||||
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
|
||||
tgMsg.ParseMode = telego.ModeHTML
|
||||
tgMsg.MessageThreadID = threadID
|
||||
type sendChunkParams struct {
|
||||
chatID int64
|
||||
threadID int
|
||||
content string
|
||||
replyToID string
|
||||
mdFallback string
|
||||
useMarkdownV2 bool
|
||||
}
|
||||
|
||||
if replyToID != "" {
|
||||
if mid, parseErr := strconv.Atoi(replyToID); parseErr == nil {
|
||||
// sendChunk sends a single HTML/MarkdownV2 message, falling back to the original
|
||||
// markdown as plain text on parse failure so users never see raw HTML/MarkdownV2 tags.
|
||||
func (c *TelegramChannel) sendChunk(
|
||||
ctx context.Context,
|
||||
params sendChunkParams,
|
||||
) error {
|
||||
tgMsg := tu.Message(tu.ID(params.chatID), params.content)
|
||||
tgMsg.MessageThreadID = params.threadID
|
||||
if params.useMarkdownV2 {
|
||||
tgMsg.WithParseMode(telego.ModeMarkdownV2)
|
||||
} else {
|
||||
tgMsg.WithParseMode(telego.ModeHTML)
|
||||
}
|
||||
|
||||
if params.replyToID != "" {
|
||||
if mid, parseErr := strconv.Atoi(params.replyToID); parseErr == nil {
|
||||
tgMsg.ReplyParameters = &telego.ReplyParameters{
|
||||
MessageID: mid,
|
||||
}
|
||||
@@ -230,15 +290,15 @@ func (c *TelegramChannel) sendHTMLChunk(
|
||||
}
|
||||
|
||||
if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
tgMsg.Text = mdFallback
|
||||
logParseFailed(err, params.useMarkdownV2)
|
||||
|
||||
tgMsg.Text = params.mdFallback
|
||||
tgMsg.ParseMode = ""
|
||||
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
return fmt.Errorf("telegram send: %w", channels.ErrTemporary)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -279,6 +339,7 @@ func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(
|
||||
|
||||
// EditMessage implements channels.MessageEditor.
|
||||
func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
|
||||
useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2
|
||||
cid, _, err := parseTelegramChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -287,10 +348,19 @@ func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messag
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
htmlContent := markdownToTelegramHTML(content)
|
||||
editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent)
|
||||
editMsg.ParseMode = telego.ModeHTML
|
||||
parsedContent := parseContent(content, useMarkdownV2)
|
||||
editMsg := tu.EditMessageText(tu.ID(cid), mid, parsedContent)
|
||||
if useMarkdownV2 {
|
||||
editMsg.WithParseMode(telego.ModeMarkdownV2)
|
||||
} else {
|
||||
editMsg.WithParseMode(telego.ModeHTML)
|
||||
}
|
||||
_, err = c.bot.EditMessageText(ctx, editMsg)
|
||||
if err != nil {
|
||||
logParseFailed(err, useMarkdownV2)
|
||||
_, err = c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(cid), mid, content))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -367,6 +437,20 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe
|
||||
Caption: part.Caption,
|
||||
}
|
||||
_, err = c.bot.SendPhoto(ctx, params)
|
||||
if err != nil && strings.Contains(err.Error(), "PHOTO_INVALID_DIMENSIONS") {
|
||||
if _, seekErr := file.Seek(0, io.SeekStart); seekErr != nil {
|
||||
file.Close()
|
||||
return fmt.Errorf("telegram rewind media after photo failure: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
docParams := &telego.SendDocumentParams{
|
||||
ChatID: tu.ID(chatID),
|
||||
MessageThreadID: threadID,
|
||||
Document: telego.InputFile{File: file},
|
||||
Caption: part.Caption,
|
||||
}
|
||||
_, err = c.bot.SendDocument(ctx, docParams)
|
||||
}
|
||||
case "audio":
|
||||
params := &telego.SendAudioParams{
|
||||
ChatID: tu.ID(chatID),
|
||||
@@ -624,6 +708,14 @@ func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string)
|
||||
return c.downloadFileWithInfo(file, ext)
|
||||
}
|
||||
|
||||
func parseContent(text string, useMarkdownV2 bool) string {
|
||||
if useMarkdownV2 {
|
||||
return markdownToTelegramMarkdownV2(text)
|
||||
}
|
||||
|
||||
return markdownToTelegramHTML(text)
|
||||
}
|
||||
|
||||
// parseTelegramChatID splits "chatID/threadID" into its components.
|
||||
// Returns threadID=0 when no "/" is present (non-forum messages).
|
||||
func parseTelegramChatID(chatID string) (int64, int, error) {
|
||||
@@ -643,109 +735,18 @@ func parseTelegramChatID(chatID string) (int64, int, error) {
|
||||
return cid, tid, nil
|
||||
}
|
||||
|
||||
func markdownToTelegramHTML(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
func logParseFailed(err error, useMarkdownV2 bool) {
|
||||
parsingName := "HTML"
|
||||
if useMarkdownV2 {
|
||||
parsingName = "MarkdownV2"
|
||||
}
|
||||
|
||||
codeBlocks := extractCodeBlocks(text)
|
||||
text = codeBlocks.text
|
||||
|
||||
inlineCodes := extractInlineCodes(text)
|
||||
text = inlineCodes.text
|
||||
|
||||
text = reHeading.ReplaceAllString(text, "$1")
|
||||
|
||||
text = reBlockquote.ReplaceAllString(text, "$1")
|
||||
|
||||
text = escapeHTML(text)
|
||||
|
||||
text = reLink.ReplaceAllString(text, `<a href="$2">$1</a>`)
|
||||
|
||||
text = reBoldStar.ReplaceAllString(text, "<b>$1</b>")
|
||||
|
||||
text = reBoldUnder.ReplaceAllString(text, "<b>$1</b>")
|
||||
|
||||
text = reItalic.ReplaceAllStringFunc(text, func(s string) string {
|
||||
match := reItalic.FindStringSubmatch(s)
|
||||
if len(match) < 2 {
|
||||
return s
|
||||
}
|
||||
return "<i>" + match[1] + "</i>"
|
||||
})
|
||||
|
||||
text = reStrike.ReplaceAllString(text, "<s>$1</s>")
|
||||
|
||||
text = reListItem.ReplaceAllString(text, "• ")
|
||||
|
||||
for i, code := range inlineCodes.codes {
|
||||
escaped := escapeHTML(code)
|
||||
text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("<code>%s</code>", escaped))
|
||||
}
|
||||
|
||||
for i, code := range codeBlocks.codes {
|
||||
escaped := escapeHTML(code)
|
||||
text = strings.ReplaceAll(
|
||||
text,
|
||||
fmt.Sprintf("\x00CB%d\x00", i),
|
||||
fmt.Sprintf("<pre><code>%s</code></pre>", escaped),
|
||||
)
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
type codeBlockMatch struct {
|
||||
text string
|
||||
codes []string
|
||||
}
|
||||
|
||||
func extractCodeBlocks(text string) codeBlockMatch {
|
||||
matches := reCodeBlock.FindAllStringSubmatch(text, -1)
|
||||
|
||||
codes := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = reCodeBlock.ReplaceAllStringFunc(text, func(m string) string {
|
||||
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return codeBlockMatch{text: text, codes: codes}
|
||||
}
|
||||
|
||||
type inlineCodeMatch struct {
|
||||
text string
|
||||
codes []string
|
||||
}
|
||||
|
||||
func extractInlineCodes(text string) inlineCodeMatch {
|
||||
matches := reInlineCode.FindAllStringSubmatch(text, -1)
|
||||
|
||||
codes := make([]string, 0, len(matches))
|
||||
for _, match := range matches {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = reInlineCode.ReplaceAllStringFunc(text, func(m string) string {
|
||||
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return inlineCodeMatch{text: text, codes: codes}
|
||||
}
|
||||
|
||||
func escapeHTML(text string) string {
|
||||
text = strings.ReplaceAll(text, "&", "&")
|
||||
text = strings.ReplaceAll(text, "<", "<")
|
||||
text = strings.ReplaceAll(text, ">", ">")
|
||||
return text
|
||||
logger.ErrorCF("telegram",
|
||||
fmt.Sprintf("%s parse failed, falling back to plain text", parsingName),
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message via entities.
|
||||
|
||||
@@ -3,7 +3,6 @@ package telegram
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
@@ -36,10 +35,7 @@ func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
t.Fatalf("handleMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
|
||||
@@ -108,22 +108,24 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
|
||||
t.Fatalf("handleMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if tc.wantForwarded {
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if tc.wantForwarded {
|
||||
t.Fatal("timeout waiting for message to be forwarded")
|
||||
return
|
||||
}
|
||||
if inbound.Content != tc.wantContent {
|
||||
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
|
||||
case inbound, ok := <-messageBus.InboundChan():
|
||||
if tc.wantForwarded {
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Content != tc.wantContent {
|
||||
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ok {
|
||||
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
ta "github.com/mymmrac/telego/telegoapi"
|
||||
@@ -15,6 +17,8 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
const testToken = "1234567890:aaaabbbbaaaabbbbaaaabbbbaaaabbbbccc"
|
||||
@@ -38,8 +42,20 @@ func (s *stubCaller) Call(ctx context.Context, url string, data *ta.RequestData)
|
||||
// stubConstructor implements ta.RequestConstructor for testing.
|
||||
type stubConstructor struct{}
|
||||
|
||||
type multipartCall struct {
|
||||
Parameters map[string]string
|
||||
FileSizes map[string]int
|
||||
}
|
||||
|
||||
func (s *stubConstructor) JSONRequest(parameters any) (*ta.RequestData, error) {
|
||||
return &ta.RequestData{}, nil
|
||||
b, err := json.Marshal(parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ta.RequestData{
|
||||
ContentType: "application/json",
|
||||
BodyRaw: b,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubConstructor) MultipartRequest(
|
||||
@@ -49,6 +65,36 @@ func (s *stubConstructor) MultipartRequest(
|
||||
return &ta.RequestData{}, nil
|
||||
}
|
||||
|
||||
type multipartRecordingConstructor struct {
|
||||
stubConstructor
|
||||
calls []multipartCall
|
||||
}
|
||||
|
||||
func (s *multipartRecordingConstructor) MultipartRequest(
|
||||
parameters map[string]string,
|
||||
files map[string]ta.NamedReader,
|
||||
) (*ta.RequestData, error) {
|
||||
call := multipartCall{
|
||||
Parameters: make(map[string]string, len(parameters)),
|
||||
FileSizes: make(map[string]int, len(files)),
|
||||
}
|
||||
for k, v := range parameters {
|
||||
call.Parameters[k] = v
|
||||
}
|
||||
for field, file := range files {
|
||||
if file == nil {
|
||||
continue
|
||||
}
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
call.FileSizes[field] = len(data)
|
||||
}
|
||||
s.calls = append(s.calls, call)
|
||||
return &ta.RequestData{}, nil
|
||||
}
|
||||
|
||||
// successResponse returns a ta.Response that telego will treat as a successful SendMessage.
|
||||
func successResponse(t *testing.T) *ta.Response {
|
||||
t.Helper()
|
||||
@@ -60,11 +106,19 @@ func successResponse(t *testing.T) *ta.Response {
|
||||
|
||||
// newTestChannel creates a TelegramChannel with a mocked bot for unit testing.
|
||||
func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel {
|
||||
return newTestChannelWithConstructor(t, caller, &stubConstructor{})
|
||||
}
|
||||
|
||||
func newTestChannelWithConstructor(
|
||||
t *testing.T,
|
||||
caller *stubCaller,
|
||||
constructor ta.RequestConstructor,
|
||||
) *TelegramChannel {
|
||||
t.Helper()
|
||||
|
||||
bot, err := telego.NewBot(testToken,
|
||||
telego.WithAPICaller(caller),
|
||||
telego.WithRequestConstructor(&stubConstructor{}),
|
||||
telego.WithRequestConstructor(constructor),
|
||||
telego.WithDiscardLogger(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -78,9 +132,96 @@ func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel {
|
||||
BaseChannel: base,
|
||||
bot: bot,
|
||||
chatIDs: make(map[string]int64),
|
||||
config: config.DefaultConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_ImageFallbacksToDocumentOnInvalidDimensions(t *testing.T) {
|
||||
constructor := &multipartRecordingConstructor{}
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
switch {
|
||||
case strings.Contains(url, "sendPhoto"):
|
||||
return nil, errors.New(`api: 400 "Bad Request: PHOTO_INVALID_DIMENSIONS"`)
|
||||
case strings.Contains(url, "sendDocument"):
|
||||
return successResponse(t), nil
|
||||
default:
|
||||
t.Fatalf("unexpected API call: %s", url)
|
||||
return nil, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
ch := newTestChannelWithConstructor(t, caller, constructor)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
localPath := filepath.Join(tmpDir, "woodstock-en-10s.png")
|
||||
content := []byte("fake-png-content")
|
||||
require.NoError(t, os.WriteFile(localPath, content, 0o644))
|
||||
|
||||
ref, err := store.Store(
|
||||
localPath,
|
||||
media.MediaMeta{Filename: "woodstock-en-10s.png", ContentType: "image/png"},
|
||||
"scope-1",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "12345",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "image",
|
||||
Ref: ref,
|
||||
Caption: "caption",
|
||||
}},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, caller.calls, 2)
|
||||
assert.Contains(t, caller.calls[0].URL, "sendPhoto")
|
||||
assert.Contains(t, caller.calls[1].URL, "sendDocument")
|
||||
require.Len(t, constructor.calls, 2)
|
||||
assert.Equal(t, len(content), constructor.calls[0].FileSizes["photo"])
|
||||
assert.Equal(t, len(content), constructor.calls[1].FileSizes["document"])
|
||||
assert.Equal(t, "caption", constructor.calls[1].Parameters["caption"])
|
||||
}
|
||||
|
||||
func TestSendMedia_ImageNonDimensionErrorDoesNotFallback(t *testing.T) {
|
||||
constructor := &multipartRecordingConstructor{}
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return nil, errors.New("api: 500 \"server exploded\"")
|
||||
},
|
||||
}
|
||||
ch := newTestChannelWithConstructor(t, caller, constructor)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
ch.SetMediaStore(store)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
localPath := filepath.Join(tmpDir, "image.png")
|
||||
require.NoError(t, os.WriteFile(localPath, []byte("fake-png-content"), 0o644))
|
||||
|
||||
ref, err := store.Store(localPath, media.MediaMeta{Filename: "image.png", ContentType: "image/png"}, "scope-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ch.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
ChatID: "12345",
|
||||
Parts: []bus.MediaPart{{
|
||||
Type: "image",
|
||||
Ref: ref,
|
||||
}},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, channels.ErrTemporary)
|
||||
require.Len(t, caller.calls, 1)
|
||||
assert.Contains(t, caller.calls[0].URL, "sendPhoto")
|
||||
require.Len(t, constructor.calls, 1)
|
||||
assert.NotContains(t, caller.calls[0].URL, "sendDocument")
|
||||
}
|
||||
|
||||
func TestSend_EmptyContent(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
@@ -235,6 +376,55 @@ func TestSend_MarkdownShortButHTMLLong_MultipleCalls(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
func TestSend_HTMLOverflow_WordBoundary(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
// We want to force a split near index ~2600 while keeping markdown length <= 4000.
|
||||
// Prefix of 430 bold units (6 chars each) = 2580 chars.
|
||||
// Expansion per unit is +3 chars when converted to HTML, so 2580 + 430*3 = 3870.
|
||||
prefix := strings.Repeat("**a** ", 430)
|
||||
targetWord := "TARGETWORDTHATSTAYSTOGETHER"
|
||||
// Suffix of 230 bold units (6 chars each) = 1380 chars.
|
||||
// Total markdown length: 2580 (prefix) + 27 (target word) + 1380 (suffix) = 3987 <= 4000.
|
||||
// HTML expansion adds ~3 chars per bold unit: (430 + 230)*3 = 1980 extra chars,
|
||||
// so total HTML length comfortably exceeds 4096.
|
||||
suffix := strings.Repeat(" **b**", 230)
|
||||
content := prefix + targetWord + suffix
|
||||
|
||||
// Ensure the test content matches the intended boundary conditions.
|
||||
assert.LessOrEqual(t, len([]rune(content)), 4000, "markdown content must not exceed chunk size for this test")
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "123456",
|
||||
Content: content,
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundFullWord := false
|
||||
for i, call := range caller.calls {
|
||||
var params map[string]any
|
||||
err := json.Unmarshal(call.Data.BodyRaw, ¶ms)
|
||||
require.NoError(t, err)
|
||||
text, _ := params["text"].(string)
|
||||
|
||||
hasWord := strings.Contains(text, targetWord)
|
||||
t.Logf("Chunk %d length: %d, contains target word: %v", i, len(text), hasWord)
|
||||
|
||||
if hasWord {
|
||||
foundFullWord = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundFullWord, "The target word should not be split between chunks")
|
||||
}
|
||||
|
||||
func TestSend_NotRunning(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
@@ -355,10 +545,7 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
|
||||
err := ch.handleMessage(context.Background(), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok, "expected inbound message")
|
||||
|
||||
// Composite chatID should include thread ID
|
||||
@@ -397,10 +584,7 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
|
||||
err := ch.handleMessage(context.Background(), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok)
|
||||
|
||||
// Plain chatID without thread suffix
|
||||
@@ -443,10 +627,7 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
|
||||
err := ch.handleMessage(context.Background(), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
require.True(t, ok)
|
||||
|
||||
// chatID should NOT include thread suffix for non-forum groups
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
*bold \*text*
|
||||
_italic \*text_
|
||||
__underline__
|
||||
~strikethrough~
|
||||
||spoiler||
|
||||
*bold _italic bold ~italic bold strikethrough ||italic bold strikethrough spoiler||~ __underline italic bold___ bold*
|
||||
[inline URL](http://www.example.com/)
|
||||
[inline mention of a user](tg://user?id=123456789)
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
`inline fixed-width code`
|
||||
```
|
||||
pre-formatted fixed-width code block
|
||||
```
|
||||
```python
|
||||
pre-formatted fixed-width code block written in the Python programming language
|
||||
```
|
||||
>Block quotation started
|
||||
>Block quotation continued
|
||||
>Block quotation continued
|
||||
>Block quotation continued
|
||||
>The last line of the block quotation
|
||||
**>The expandable block quotation started right after the previous block quotation
|
||||
>It is separated from the previous block quotation by an empty bold entity
|
||||
>Expandable block quotation continued
|
||||
>Hidden by default part of the expandable block quotation started
|
||||
>Expandable block quotation continued
|
||||
>The last line of the expandable block quotation with the expandability mark||
|
||||
@@ -3,7 +3,6 @@ package whatsapp
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -25,10 +24,7 @@ func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T
|
||||
"content": "/help",
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
inbound, ok := <-messageBus.InboundChan()
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
|
||||
@@ -43,14 +43,19 @@ func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "whatsapp_native" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/new" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for message to be forwarded")
|
||||
return
|
||||
case inbound, ok := <-messageBus.InboundChan():
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "whatsapp_native" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/new" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+18
-5
@@ -312,6 +312,7 @@ type TelegramConfig struct {
|
||||
Typing TypingConfig `json:"typing,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_TELEGRAM_REASONING_CHANNEL_ID"`
|
||||
UseMarkdownV2 bool `json:"use_markdown_v2" env:"PICOCLAW_CHANNELS_TELEGRAM_USE_MARKDOWN_V2"`
|
||||
}
|
||||
|
||||
type FeishuConfig struct {
|
||||
@@ -532,6 +533,7 @@ type ProvidersConfig struct {
|
||||
Minimax ProviderConfig `json:"minimax"`
|
||||
LongCat ProviderConfig `json:"longcat"`
|
||||
ModelScope ProviderConfig `json:"modelscope"`
|
||||
Novita ProviderConfig `json:"novita"`
|
||||
}
|
||||
|
||||
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
|
||||
@@ -560,7 +562,8 @@ func (p ProvidersConfig) IsEmpty() bool {
|
||||
p.Avian.APIKey == "" && p.Avian.APIBase == "" &&
|
||||
p.Minimax.APIKey == "" && p.Minimax.APIBase == "" &&
|
||||
p.LongCat.APIKey == "" && p.LongCat.APIBase == "" &&
|
||||
p.ModelScope.APIKey == "" && p.ModelScope.APIBase == ""
|
||||
p.ModelScope.APIKey == "" && p.ModelScope.APIBase == "" &&
|
||||
p.Novita.APIKey == "" && p.Novita.APIBase == ""
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
|
||||
@@ -590,7 +593,9 @@ type OpenAIProviderConfig struct {
|
||||
// ModelConfig represents a model-centric provider configuration.
|
||||
// It allows adding new providers (especially OpenAI-compatible ones) via configuration only.
|
||||
// The model field uses protocol prefix format: [protocol/]model-identifier
|
||||
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Supported protocols include openai, anthropic, antigravity, claude-cli,
|
||||
// codex-cli, github-copilot, and named OpenAI-compatible protocols such as
|
||||
// groq, deepseek, modelscope, and novita.
|
||||
// Default protocol is "openai" if no prefix is specified.
|
||||
type ModelConfig struct {
|
||||
// Required fields
|
||||
@@ -694,10 +699,18 @@ type WebToolsConfig struct {
|
||||
Perplexity PerplexityConfig ` json:"perplexity"`
|
||||
SearXNG SearXNGConfig ` json:"searxng"`
|
||||
GLMSearch GLMSearchConfig ` json:"glm_search"`
|
||||
// PreferNative controls whether to use provider-native web search when
|
||||
// the active LLM supports it (e.g. OpenAI web_search_preview). When true,
|
||||
// the client-side web_search tool is hidden to avoid duplicate search surfaces,
|
||||
// and the provider's built-in search is used instead. Falls back to client-side
|
||||
// search when the provider does not support native search.
|
||||
PreferNative bool `json:"prefer_native" env:"PICOCLAW_TOOLS_WEB_PREFER_NATIVE"`
|
||||
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
|
||||
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
|
||||
Format string `json:"format,omitempty" env:"PICOCLAW_TOOLS_WEB_FORMAT"`
|
||||
PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
@@ -1030,7 +1043,7 @@ func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) {
|
||||
}
|
||||
|
||||
// Multiple configs - use round-robin for load balancing
|
||||
idx := rrCounter.Add(1) % uint64(len(matches))
|
||||
idx := (rrCounter.Add(1) - 1) % uint64(len(matches))
|
||||
return &matches[idx], nil
|
||||
}
|
||||
|
||||
|
||||
@@ -77,6 +77,22 @@ func TestAgentModelConfig_MarshalObject(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersConfig_IsEmpty(t *testing.T) {
|
||||
var empty ProvidersConfig
|
||||
if !empty.IsEmpty() {
|
||||
t.Fatal("empty ProvidersConfig should report empty")
|
||||
}
|
||||
|
||||
novita := ProvidersConfig{
|
||||
Novita: ProviderConfig{
|
||||
APIKey: "test-key",
|
||||
},
|
||||
}
|
||||
if novita.IsEmpty() {
|
||||
t.Fatal("ProvidersConfig with novita settings should not report empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentConfig_FullParse(t *testing.T) {
|
||||
jsonData := `{
|
||||
"agents": {
|
||||
@@ -401,6 +417,45 @@ func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WebPreferNativeEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Tools.Web.PreferNative {
|
||||
t.Fatal("DefaultConfig().Tools.Web.PreferNative should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_WebPreferNativeDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"tools":{"web":{"enabled":true}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if !cfg.Tools.Web.PreferNative {
|
||||
t.Fatal("PreferNative should remain true when unset in config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_WebPreferNativeCanBeDisabled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"tools":{"web":{"prefer_native":false}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Tools.Web.PreferNative {
|
||||
t.Fatal("PreferNative should be false when disabled in config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Tools.Exec.AllowRemote {
|
||||
|
||||
@@ -15,7 +15,7 @@ func DefaultConfig() *Config {
|
||||
// Determine the base path for the workspace.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
var homePath string
|
||||
if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
|
||||
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else {
|
||||
userHome, _ := os.UserHomeDir()
|
||||
@@ -59,6 +59,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
Text: "Thinking... 💭",
|
||||
},
|
||||
UseMarkdownV2: false,
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
Enabled: false,
|
||||
@@ -412,8 +413,10 @@ func DefaultConfig() *Config {
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
PreferNative: true,
|
||||
Proxy: "",
|
||||
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
|
||||
Format: "plaintext",
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package config
|
||||
|
||||
// Runtime environment variable keys for the picoclaw process.
|
||||
// These control the location of files and binaries at runtime and are read
|
||||
// directly via os.Getenv / os.LookupEnv. All picoclaw-specific keys use the
|
||||
// PICOCLAW_ prefix. Reference these constants instead of inline string
|
||||
// literals to keep all supported knobs visible in one place and to prevent
|
||||
// typos.
|
||||
const (
|
||||
// EnvHome overrides the base directory for all picoclaw data
|
||||
// (config, workspace, skills, auth store, …).
|
||||
// Default: ~/.picoclaw
|
||||
EnvHome = "PICOCLAW_HOME"
|
||||
|
||||
// EnvConfig overrides the full path to the JSON config file.
|
||||
// Default: $PICOCLAW_HOME/config.json
|
||||
EnvConfig = "PICOCLAW_CONFIG"
|
||||
|
||||
// EnvBuiltinSkills overrides the directory from which built-in
|
||||
// skills are loaded.
|
||||
// Default: <cwd>/skills
|
||||
EnvBuiltinSkills = "PICOCLAW_BUILTIN_SKILLS"
|
||||
|
||||
// EnvBinary overrides the path to the picoclaw executable.
|
||||
// Used by the web launcher when spawning the gateway subprocess.
|
||||
// Default: resolved from the same directory as the current executable.
|
||||
EnvBinary = "PICOCLAW_BINARY"
|
||||
|
||||
// EnvGatewayHost overrides the host address for the gateway server.
|
||||
// Default: "127.0.0.1"
|
||||
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
|
||||
)
|
||||
@@ -80,6 +80,36 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_RoundRobinStartsFromFirstMatch(t *testing.T) {
|
||||
rrCounter.Store(0)
|
||||
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"},
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-2", APIKey: "key2"},
|
||||
{ModelName: "lb-model", Model: "openai/gpt-4o-3", APIKey: "key3"},
|
||||
},
|
||||
}
|
||||
|
||||
wantOrder := []string{
|
||||
"openai/gpt-4o-1",
|
||||
"openai/gpt-4o-2",
|
||||
"openai/gpt-4o-3",
|
||||
"openai/gpt-4o-1",
|
||||
"openai/gpt-4o-2",
|
||||
}
|
||||
|
||||
for i, want := range wantOrder {
|
||||
result, err := cfg.GetModelConfig("lb-model")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelConfig() call %d error = %v", i, err)
|
||||
}
|
||||
if result.Model != want {
|
||||
t.Fatalf("GetModelConfig() call %d model = %q, want %q", i, result.Model, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfig_Concurrent(t *testing.T) {
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
|
||||
@@ -66,6 +66,14 @@ var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
|
||||
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
|
||||
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
|
||||
|
||||
// SSHKeyPathEnvVar is the environment variable that specifies the path to the
|
||||
// SSH private key used for enc:// credential encryption and decryption.
|
||||
const SSHKeyPathEnvVar = "PICOCLAW_SSH_KEY_PATH"
|
||||
|
||||
// picoclawHome is a package-local copy of config.EnvHome. It is kept here to
|
||||
// avoid a circular import between pkg/credential and pkg/config.
|
||||
const picoclawHome = "PICOCLAW_HOME"
|
||||
|
||||
const (
|
||||
fileScheme = "file://"
|
||||
encScheme = "enc://"
|
||||
@@ -73,7 +81,6 @@ const (
|
||||
saltLen = 16
|
||||
nonceLen = 12
|
||||
keyLen = 32
|
||||
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
|
||||
)
|
||||
|
||||
// Resolver resolves raw credential strings for model_list api_key fields.
|
||||
@@ -248,14 +255,14 @@ func allowedSSHKeyPath(path string) bool {
|
||||
clean := filepath.Clean(path)
|
||||
|
||||
// Exact match with PICOCLAW_SSH_KEY_PATH.
|
||||
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
|
||||
if envPath, ok := os.LookupEnv(SSHKeyPathEnvVar); ok && envPath != "" {
|
||||
if clean == filepath.Clean(envPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within PICOCLAW_HOME.
|
||||
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
|
||||
if picoHome := os.Getenv(picoclawHome); picoHome != "" {
|
||||
if isWithinDir(clean, picoHome) {
|
||||
return true
|
||||
}
|
||||
@@ -316,7 +323,7 @@ func pickSSHKeyPath(override string) string {
|
||||
if override != "" {
|
||||
return override
|
||||
}
|
||||
if p, ok := os.LookupEnv(sshKeyEnv); ok {
|
||||
if p, ok := os.LookupEnv(SSHKeyPathEnvVar); ok {
|
||||
return p // respect explicit setting, even if ""
|
||||
}
|
||||
return findDefaultSSHKey()
|
||||
|
||||
+66
-11
@@ -65,6 +65,7 @@ type CronService struct {
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
wakeChan chan struct{}
|
||||
gronx *gronx.Gronx
|
||||
}
|
||||
|
||||
@@ -73,6 +74,7 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
|
||||
storePath: storePath,
|
||||
onJob: onJob,
|
||||
gronx: gronx.New(),
|
||||
wakeChan: make(chan struct{}),
|
||||
}
|
||||
// Initialize and load store on creation
|
||||
cs.loadStore()
|
||||
@@ -97,6 +99,9 @@ func (cs *CronService) Start() error {
|
||||
}
|
||||
|
||||
cs.stopChan = make(chan struct{})
|
||||
if cs.wakeChan == nil {
|
||||
cs.wakeChan = make(chan struct{})
|
||||
}
|
||||
cs.running = true
|
||||
go cs.runLoop(cs.stopChan)
|
||||
|
||||
@@ -119,14 +124,47 @@ func (cs *CronService) Stop() {
|
||||
}
|
||||
|
||||
func (cs *CronService) runLoop(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
timer := time.NewTimer(time.Hour)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
// every loop, recalculate the next wake time
|
||||
cs.mu.RLock()
|
||||
nextWake := cs.getNextWakeMS()
|
||||
cs.mu.RUnlock()
|
||||
|
||||
var delay time.Duration
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
if nextWake == nil {
|
||||
// no jobs, sleep for a long time (or until a new job is added)
|
||||
delay = time.Hour
|
||||
} else {
|
||||
diff := *nextWake - now
|
||||
if diff <= 0 {
|
||||
delay = 0
|
||||
} else {
|
||||
delay = time.Duration(diff) * time.Millisecond
|
||||
}
|
||||
}
|
||||
|
||||
timer.Reset(delay)
|
||||
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
case <-cs.wakeChan: // wake on new job or update
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
continue
|
||||
case <-timer.C:
|
||||
cs.checkJobs()
|
||||
}
|
||||
}
|
||||
@@ -264,22 +302,19 @@ func (cs *CronService) executeJobByID(jobID string) {
|
||||
}
|
||||
|
||||
func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int64 {
|
||||
if schedule.Kind == "at" {
|
||||
switch schedule.Kind {
|
||||
case "at":
|
||||
if schedule.AtMS != nil && *schedule.AtMS > nowMS {
|
||||
return schedule.AtMS
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if schedule.Kind == "every" {
|
||||
case "every":
|
||||
if schedule.EveryMS == nil || *schedule.EveryMS <= 0 {
|
||||
return nil
|
||||
}
|
||||
next := nowMS + *schedule.EveryMS
|
||||
return &next
|
||||
}
|
||||
|
||||
if schedule.Kind == "cron" {
|
||||
case "cron":
|
||||
if schedule.Expr == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -294,9 +329,19 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6
|
||||
|
||||
nextMS := nextTime.UnixMilli()
|
||||
return &nextMS
|
||||
default:
|
||||
log.Printf("[cron] unknown schedule kind '%s'", schedule.Kind)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// wake up the loop to re-evaluate next wake time immediately (e.g. after add/update/remove jobs)
|
||||
func (cs *CronService) notify() {
|
||||
select {
|
||||
case cs.wakeChan <- struct{}{}:
|
||||
default:
|
||||
// if the channel is full, it means the loop will wake up soon anyway, so we can skip sending
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CronService) recomputeNextRuns() {
|
||||
@@ -400,6 +445,8 @@ func (cs *CronService) AddJob(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cs.notify()
|
||||
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
@@ -411,6 +458,9 @@ func (cs *CronService) UpdateJob(job *CronJob) error {
|
||||
if cs.store.Jobs[i].ID == job.ID {
|
||||
cs.store.Jobs[i] = *job
|
||||
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
|
||||
|
||||
cs.notify()
|
||||
|
||||
return cs.saveStoreUnsafe()
|
||||
}
|
||||
}
|
||||
@@ -441,6 +491,8 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
cs.notify()
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
@@ -463,6 +515,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob {
|
||||
if err := cs.saveStoreUnsafe(); err != nil {
|
||||
log.Printf("[cron] failed to save store after enable: %v", err)
|
||||
}
|
||||
|
||||
cs.notify()
|
||||
|
||||
return job
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSaveStore_FilePermissions(t *testing.T) {
|
||||
@@ -36,3 +39,199 @@ func TestSaveStore_FilePermissions(t *testing.T) {
|
||||
func int64Ptr(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func setupService(handler JobHandler) (*CronService, string) {
|
||||
tmpFile := fmt.Sprintf("test_cron_%d.json", time.Now().UnixNano())
|
||||
cs := NewCronService(tmpFile, handler)
|
||||
return cs, tmpFile
|
||||
}
|
||||
|
||||
func TestCronService_CRUD(t *testing.T) {
|
||||
cs, path := setupService(nil)
|
||||
defer os.Remove(path)
|
||||
|
||||
// Test AddJob
|
||||
at := time.Now().Add(time.Hour).UnixMilli()
|
||||
job, err := cs.AddJob("Task1", CronSchedule{Kind: "at", AtMS: &at}, "msg", true, "ch", "to")
|
||||
if err != nil || job.ID == "" {
|
||||
t.Fatalf("AddJob failed: %v", err)
|
||||
}
|
||||
|
||||
// Test ListJobs
|
||||
if len(cs.ListJobs(true)) != 1 {
|
||||
t.Error("ListJobs should return 1 job")
|
||||
}
|
||||
|
||||
// Test UpdateJob
|
||||
job.Name = "UpdatedName"
|
||||
err = cs.UpdateJob(job)
|
||||
if err != nil || cs.store.Jobs[0].Name != "UpdatedName" {
|
||||
t.Error("UpdateJob failed")
|
||||
}
|
||||
|
||||
// Test EnableJob
|
||||
cs.EnableJob(job.ID, false)
|
||||
if cs.store.Jobs[0].Enabled != false || cs.store.Jobs[0].State.NextRunAtMS != nil {
|
||||
t.Error("EnableJob(false) failed to clear state")
|
||||
}
|
||||
|
||||
// Test RemoveJob
|
||||
removed := cs.RemoveJob(job.ID)
|
||||
if !removed || len(cs.store.Jobs) != 0 {
|
||||
t.Error("RemoveJob failed")
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Test Cron Expression Calculation Logic
|
||||
func TestCronService_ComputeNextRun(t *testing.T) {
|
||||
cs, path := setupService(nil)
|
||||
defer os.Remove(path)
|
||||
|
||||
now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC).UnixMilli()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
schedule CronSchedule
|
||||
wantNil bool
|
||||
}{
|
||||
{"Valid Cron", CronSchedule{Kind: "cron", Expr: "0 * * * *"}, false},
|
||||
{"Invalid Cron", CronSchedule{Kind: "cron", Expr: "invalid"}, true},
|
||||
{"Every MS", CronSchedule{Kind: "every", EveryMS: int64Ptr(5000)}, false},
|
||||
{"At Future", CronSchedule{Kind: "at", AtMS: int64Ptr(now + 1000)}, false},
|
||||
{"At Past", CronSchedule{Kind: "at", AtMS: int64Ptr(now - 1000)}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := cs.computeNextRun(&tt.schedule, now)
|
||||
if (got == nil) != tt.wantNil {
|
||||
t.Errorf("%s: got %v, wantNil %v", tt.name, got, tt.wantNil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Test Execution Flow
|
||||
func TestCronService_ExecutionFlow(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
executedJobs := make(map[string]bool)
|
||||
|
||||
handler := func(job *CronJob) (string, error) {
|
||||
mu.Lock()
|
||||
executedJobs[job.ID] = true
|
||||
mu.Unlock()
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
cs, path := setupService(handler)
|
||||
defer os.Remove(path)
|
||||
|
||||
// Start the service
|
||||
if err := cs.Start(); err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
defer cs.Stop()
|
||||
|
||||
// Add a job then runs 100ms from now
|
||||
target := time.Now().Add(100 * time.Millisecond).UnixMilli()
|
||||
job, _ := cs.AddJob("FastJob", CronSchedule{Kind: "at", AtMS: &target}, "", false, "", "")
|
||||
|
||||
// Check for job execution with a timeout
|
||||
success := false
|
||||
for range 20 {
|
||||
mu.Lock()
|
||||
if executedJobs[job.ID] {
|
||||
success = true
|
||||
mu.Unlock()
|
||||
break
|
||||
}
|
||||
mu.Unlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
if !success {
|
||||
t.Error("Job was not executed in time")
|
||||
}
|
||||
|
||||
// check that the job is removed after execution (DeleteAfterRun = true)
|
||||
status := cs.Status()
|
||||
if status["jobs"].(int) != 0 {
|
||||
t.Errorf("Job should be deleted after run, got count: %v", status["jobs"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronService_PersistenceIntegrity(t *testing.T) {
|
||||
tmpFile := "persist_test.json"
|
||||
defer os.Remove(tmpFile)
|
||||
|
||||
// write a job and persist
|
||||
cs1 := NewCronService(tmpFile, nil)
|
||||
at := int64(2000000000000)
|
||||
cs1.AddJob("PersistMe", CronSchedule{Kind: "at", AtMS: &at}, "payload", true, "ch1", "")
|
||||
|
||||
// check file exists
|
||||
if _, err := os.Stat(tmpFile); os.IsNotExist(err) {
|
||||
t.Fatal("Store file was not created")
|
||||
}
|
||||
|
||||
// reload and check data integrity
|
||||
cs2 := NewCronService(tmpFile, nil)
|
||||
if err := cs2.Load(); err != nil {
|
||||
t.Fatalf("Failed to load store: %v", err)
|
||||
}
|
||||
|
||||
jobs := cs2.ListJobs(true)
|
||||
if len(jobs) != 1 || jobs[0].Name != "PersistMe" {
|
||||
t.Errorf("Data corruption after reload. Got: %+v", jobs)
|
||||
}
|
||||
|
||||
// test loading invalid JSON
|
||||
os.WriteFile(tmpFile, []byte("{invalid json}"), 0o644)
|
||||
cs3 := NewCronService(tmpFile, nil)
|
||||
err := cs3.loadStore()
|
||||
if err == nil {
|
||||
t.Error("Should return error when loading invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronService_ConcurrentAccess(t *testing.T) {
|
||||
cs, path := setupService(nil)
|
||||
defer os.Remove(path)
|
||||
|
||||
cs.Start()
|
||||
defer cs.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
workers := 10
|
||||
iterations := 50
|
||||
|
||||
wg.Add(workers * 2)
|
||||
|
||||
// add jobs concurrently
|
||||
for i := range workers {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := range iterations {
|
||||
at := time.Now().Add(time.Hour).UnixMilli()
|
||||
cs.AddJob(fmt.Sprintf("Job-%d-%d", id, j), CronSchedule{Kind: "at", AtMS: &at}, "", false, "", "")
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// read and update jobs concurrently
|
||||
for range workers {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := range iterations {
|
||||
jobs := cs.ListJobs(true)
|
||||
if len(jobs) > 0 {
|
||||
cs.EnableJob(jobs[0].ID, j%2 == 0)
|
||||
}
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
+17
-13
@@ -51,7 +51,7 @@ func init() {
|
||||
FormatFieldValue: formatFieldValue,
|
||||
}
|
||||
|
||||
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
|
||||
logger = zerolog.New(consoleWriter).With().Timestamp().Caller().Logger()
|
||||
fileLogger = zerolog.Logger{}
|
||||
})
|
||||
}
|
||||
@@ -94,6 +94,12 @@ func SetLevel(level LogLevel) {
|
||||
zerolog.SetGlobalLevel(level)
|
||||
}
|
||||
|
||||
func SetConsoleLevel(level LogLevel) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
logger = logger.Level(level)
|
||||
}
|
||||
|
||||
func GetLevel() LogLevel {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
@@ -134,9 +140,9 @@ func DisableFileLogging() {
|
||||
fileLogger = zerolog.Logger{}
|
||||
}
|
||||
|
||||
func getCallerInfo() (string, int, string) {
|
||||
func getCallerSkip() int {
|
||||
for i := 2; i < 15; i++ {
|
||||
pc, file, line, ok := runtime.Caller(i)
|
||||
pc, file, _, ok := runtime.Caller(i)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
@@ -158,10 +164,10 @@ func getCallerInfo() (string, int, string) {
|
||||
continue
|
||||
}
|
||||
|
||||
return filepath.Base(file), line, filepath.Base(funcName)
|
||||
return i - 1
|
||||
}
|
||||
|
||||
return "???", 0, "???"
|
||||
return 3
|
||||
}
|
||||
|
||||
//nolint:zerologlint
|
||||
@@ -187,19 +193,16 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
return
|
||||
}
|
||||
|
||||
callerFile, callerLine, callerFunc := getCallerInfo()
|
||||
skip := getCallerSkip()
|
||||
|
||||
event := getEvent(logger, level)
|
||||
|
||||
// Build combined field with component and caller
|
||||
if component != "" {
|
||||
event.Str("caller", fmt.Sprintf("%-6s %s:%d (%s)", component, callerFile, callerLine, callerFunc))
|
||||
} else {
|
||||
event.Str("caller", fmt.Sprintf("<none> %s:%d (%s)", callerFile, callerLine, callerFunc))
|
||||
event.Str("component", component)
|
||||
}
|
||||
|
||||
appendFields(event, fields)
|
||||
event.Msg(message)
|
||||
event.CallerSkipFrame(skip).Msg(message)
|
||||
|
||||
// Also log to file if enabled
|
||||
if fileLogger.GetLevel() != zerolog.NoLevel {
|
||||
@@ -208,9 +211,10 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
if component != "" {
|
||||
fileEvent.Str("component", component)
|
||||
}
|
||||
// fileEvent.Str("caller", fmt.Sprintf("%s:%d (%s)", callerFile, callerLine, callerFunc))
|
||||
|
||||
appendFields(event, fields)
|
||||
fileEvent.Msg(message)
|
||||
appendFields(fileEvent, fields)
|
||||
fileEvent.CallerSkipFrame(skip).Msg(message)
|
||||
}
|
||||
|
||||
if level == FATAL {
|
||||
|
||||
@@ -5,13 +5,15 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func ResolveTargetHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return ExpandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
|
||||
if envHome := os.Getenv(config.EnvHome); envHome != "" {
|
||||
return ExpandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
|
||||
@@ -132,11 +132,12 @@ type OpenClawChannels struct {
|
||||
}
|
||||
|
||||
type OpenClawTelegramConfig struct {
|
||||
BotToken *string `json:"botToken"`
|
||||
AllowFrom []string `json:"allowFrom"`
|
||||
GroupPolicy *string `json:"groupPolicy"`
|
||||
DmPolicy *string `json:"dmPolicy"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
BotToken *string `json:"botToken"`
|
||||
AllowFrom []string `json:"allowFrom"`
|
||||
GroupPolicy *string `json:"groupPolicy"`
|
||||
DmPolicy *string `json:"dmPolicy"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
UseMarkdownV2 *bool `json:"useMarkdownV2"`
|
||||
}
|
||||
|
||||
type OpenClawDiscordConfig struct {
|
||||
@@ -645,10 +646,11 @@ type WhatsAppConfig struct {
|
||||
}
|
||||
|
||||
type TelegramConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Token string `json:"token"`
|
||||
Proxy string `json:"proxy"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Token string `json:"token"`
|
||||
Proxy string `json:"proxy"`
|
||||
AllowFrom []string `json:"allow_from"`
|
||||
UseMarkdownV2 bool `json:"use_markdown_v2"`
|
||||
}
|
||||
|
||||
type FeishuConfig struct {
|
||||
@@ -777,9 +779,11 @@ func (c *OpenClawConfig) convertChannels(warnings *[]string) ChannelsConfig {
|
||||
|
||||
if c.Channels.Telegram != nil {
|
||||
enabled := c.Channels.Telegram.Enabled == nil || *c.Channels.Telegram.Enabled
|
||||
useMarkdownV2 := c.Channels.Telegram.UseMarkdownV2 != nil && *c.Channels.Telegram.UseMarkdownV2
|
||||
channels.Telegram = TelegramConfig{
|
||||
Enabled: enabled,
|
||||
AllowFrom: c.Channels.Telegram.AllowFrom,
|
||||
Enabled: enabled,
|
||||
AllowFrom: c.Channels.Telegram.AllowFrom,
|
||||
UseMarkdownV2: useMarkdownV2,
|
||||
}
|
||||
if c.Channels.Telegram.BotToken != nil {
|
||||
channels.Telegram.Token = *c.Channels.Telegram.BotToken
|
||||
|
||||
@@ -10,6 +10,11 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/migrate/internal"
|
||||
)
|
||||
|
||||
// OpenclawHomeEnvVar is the environment variable that overrides the source
|
||||
// openclaw home directory when migrating from openclaw to picoclaw.
|
||||
// Default: ~/.openclaw
|
||||
const OpenclawHomeEnvVar = "OPENCLAW_HOME"
|
||||
|
||||
var providerMapping = map[string]string{
|
||||
"anthropic": "anthropic",
|
||||
"claude": "anthropic",
|
||||
@@ -112,7 +117,7 @@ func resolveSourceHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return internal.ExpandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
|
||||
if envHome := os.Getenv(OpenclawHomeEnvVar); envHome != "" {
|
||||
return internal.ExpandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
|
||||
@@ -180,6 +180,10 @@ func buildParams(
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Skip tool calls with empty names to avoid API errors
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
args := tc.Arguments
|
||||
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
|
||||
@@ -50,10 +50,18 @@ func (p *ClaudeCliProvider) Chat(
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if stderrStr := stderr.String(); stderrStr != "" {
|
||||
stderrStr := strings.TrimSpace(stderr.String())
|
||||
stdoutStr := strings.TrimSpace(stdout.String())
|
||||
switch {
|
||||
case stderrStr != "" && stdoutStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %w\nstderr: %s\nstdout: %s", err, stderrStr, stdoutStr)
|
||||
case stderrStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %s", stderrStr)
|
||||
case stdoutStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %w\noutput: %s", err, stdoutStr)
|
||||
default:
|
||||
return nil, fmt.Errorf("claude cli error: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("claude cli error: %w", err)
|
||||
}
|
||||
|
||||
return p.parseClaudeCliResponse(stdout.String())
|
||||
|
||||
@@ -8,6 +8,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CodexHomeEnvVar is the environment variable that overrides the Codex CLI
|
||||
// home directory when resolving the codex auth.json credentials file.
|
||||
// Default: ~/.codex
|
||||
const CodexHomeEnvVar = "CODEX_HOME"
|
||||
|
||||
// CodexCliAuth represents the ~/.codex/auth.json file structure.
|
||||
type CodexCliAuth struct {
|
||||
Tokens struct {
|
||||
@@ -69,7 +74,7 @@ func CreateCodexCliTokenSource() func() (string, string, error) {
|
||||
}
|
||||
|
||||
func resolveCodexAuthPath() (string, error) {
|
||||
codexHome := os.Getenv("CODEX_HOME")
|
||||
codexHome := os.Getenv(CodexHomeEnvVar)
|
||||
if codexHome == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
||||
@@ -95,7 +95,10 @@ func (p *CodexProvider) Chat(
|
||||
)
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
|
||||
// Respect tools.web.prefer_native: only inject native search when the agent
|
||||
// loop requested it (options["native_search"]), so prefer_native: false
|
||||
useNativeSearch := p.enableWebSearch && (options["native_search"] == true)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, useNativeSearch)
|
||||
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
@@ -157,6 +160,10 @@ func (p *CodexProvider) GetDefaultModel() string {
|
||||
return codexDefaultModel
|
||||
}
|
||||
|
||||
func (p *CodexProvider) SupportsNativeSearch() bool {
|
||||
return p.enableWebSearch
|
||||
}
|
||||
|
||||
func resolveCodexModel(model string) (string, string) {
|
||||
m := strings.ToLower(strings.TrimSpace(model))
|
||||
if m == "" {
|
||||
|
||||
@@ -355,7 +355,9 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"max_tokens": 1024})
|
||||
// Pass native_search so Codex injects built-in web search (mirrors agent loop when prefer_native is true).
|
||||
opts := map[string]any{"max_tokens": 1024, "native_search": true}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", opts)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
@@ -55,8 +55,8 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, litellm, anthropic, anthropic-messages, antigravity,
|
||||
// claude-cli, codex-cli, github-copilot
|
||||
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
|
||||
// antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
@@ -116,7 +116,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
"minimax", "longcat", "modelscope":
|
||||
"minimax", "longcat", "modelscope", "novita":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -219,6 +219,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://openrouter.ai/api/v1"
|
||||
case "litellm":
|
||||
return "http://localhost:4000/v1"
|
||||
case "novita":
|
||||
return "https://api.novita.ai/openai"
|
||||
case "groq":
|
||||
return "https://api.groq.com/openai/v1"
|
||||
case "zhipu":
|
||||
|
||||
@@ -112,6 +112,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
}{
|
||||
{"openai", "openai"},
|
||||
{"groq", "groq"},
|
||||
{"novita", "novita"},
|
||||
{"openrouter", "openrouter"},
|
||||
{"cerebras", "cerebras"},
|
||||
{"vivgrid", "vivgrid"},
|
||||
@@ -222,6 +223,34 @@ func TestGetDefaultAPIBase_ModelScope(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Novita(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-novita",
|
||||
Model: "novita/deepseek/deepseek-v3.2",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "deepseek/deepseek-v3.2" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "deepseek/deepseek-v3.2")
|
||||
}
|
||||
if _, ok := provider.(*HTTPProvider); !ok {
|
||||
t.Fatalf("expected *HTTPProvider, got %T", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultAPIBase_Novita(t *testing.T) {
|
||||
if got := getDefaultAPIBase("novita"); got != "https://api.novita.ai/openai" {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "novita", got, "https://api.novita.ai/openai")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
|
||||
@@ -55,3 +55,7 @@ func (p *HTTPProvider) Chat(
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) SupportsNativeSearch() bool {
|
||||
return p.delegate.SupportsNativeSearch()
|
||||
}
|
||||
|
||||
@@ -103,8 +103,11 @@ func (p *Provider) Chat(
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
// When fallback uses a different provider (e.g. DeepSeek), that provider must not inject web_search_preview.
|
||||
nativeSearch, _ := options["native_search"].(bool)
|
||||
nativeSearch = nativeSearch && isNativeSearchHost(p.apiBase)
|
||||
if len(tools) > 0 || nativeSearch {
|
||||
requestBody["tools"] = buildToolsList(tools, nativeSearch)
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
@@ -188,13 +191,40 @@ func normalizeModel(model, apiBase string) string {
|
||||
prefix := strings.ToLower(before)
|
||||
switch prefix {
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google",
|
||||
"openrouter", "zhipu", "mistral", "vivgrid", "minimax":
|
||||
"openrouter", "zhipu", "mistral", "vivgrid", "minimax", "novita":
|
||||
return after
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolsList(tools []ToolDefinition, nativeSearch bool) []any {
|
||||
result := make([]any, 0, len(tools)+1)
|
||||
for _, t := range tools {
|
||||
if nativeSearch && strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
result = append(result, t)
|
||||
}
|
||||
if nativeSearch {
|
||||
result = append(result, map[string]any{"type": "web_search_preview"})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *Provider) SupportsNativeSearch() bool {
|
||||
return isNativeSearchHost(p.apiBase)
|
||||
}
|
||||
|
||||
func isNativeSearchHost(apiBase string) bool {
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := u.Hostname()
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
|
||||
@@ -432,7 +432,28 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
|
||||
func TestProviderChat_StripsGroqOllamaDeepseekVivgridNovitaPrefixes(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
@@ -463,31 +484,25 @@ func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
|
||||
input: "vivgrid/auto",
|
||||
wantModel: "auto",
|
||||
},
|
||||
{
|
||||
name: "strips novita prefix deepseek model",
|
||||
input: "novita/deepseek/deepseek-v3.2",
|
||||
wantModel: "deepseek/deepseek-v3.2",
|
||||
},
|
||||
{
|
||||
name: "strips novita prefix zai model",
|
||||
input: "novita/zai-org/glm-5",
|
||||
wantModel: "zai-org/glm-5",
|
||||
},
|
||||
{
|
||||
name: "strips novita prefix minimax model",
|
||||
input: "novita/minimax/minimax-m2.5",
|
||||
wantModel: "minimax/minimax-m2.5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -573,6 +588,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" {
|
||||
t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto")
|
||||
}
|
||||
if got := normalizeModel(
|
||||
"novita/deepseek/deepseek-v3.2",
|
||||
"https://api.novita.ai/openai",
|
||||
); got != "deepseek/deepseek-v3.2" {
|
||||
t.Fatalf("normalizeModel(novita) = %q, want %q", got, "deepseek/deepseek-v3.2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutDefault(t *testing.T) {
|
||||
@@ -824,6 +845,232 @@ func TestSupportsPromptCacheKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsList_NativeSearchAddsWebSearchPreview(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
result := buildToolsList(tools, true)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
wsEntry, ok := result[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("web search entry is %T, want map[string]any", result[1])
|
||||
}
|
||||
if wsEntry["type"] != "web_search_preview" {
|
||||
t.Fatalf("type = %v, want web_search_preview", wsEntry["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsList_NativeSearchFiltersClientWebSearch(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
result := buildToolsList(tools, true)
|
||||
for _, entry := range result {
|
||||
if td, ok := entry.(ToolDefinition); ok && strings.EqualFold(td.Function.Name, "web_search") {
|
||||
t.Fatal("client-side web_search should be filtered out when native search is enabled")
|
||||
}
|
||||
}
|
||||
if len(result) != 2 { // read_file + web_search_preview
|
||||
t.Fatalf("len(result) = %d, want 2 (read_file + web_search_preview)", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsList_NoNativeSearchPassesThrough(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
result := buildToolsList(tools, false)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNativeSearchHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
apiBase string
|
||||
want bool
|
||||
}{
|
||||
{"https://api.openai.com/v1", true},
|
||||
{"https://myresource.openai.azure.com/openai/deployments/gpt-4", true},
|
||||
{"https://api.mistral.ai/v1", false},
|
||||
{"https://api.deepseek.com/v1", false},
|
||||
{"https://api.groq.com/openai/v1", false},
|
||||
{"http://localhost:11434/v1", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isNativeSearchHost(tt.apiBase); got != tt.want {
|
||||
t.Errorf("isNativeSearchHost(%q) = %v, want %v", tt.apiBase, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsNativeSearch_OpenAI(t *testing.T) {
|
||||
p := NewProvider("key", "https://api.openai.com/v1", "")
|
||||
if !p.SupportsNativeSearch() {
|
||||
t.Fatal("OpenAI provider should support native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsNativeSearch_NonOpenAI(t *testing.T) {
|
||||
p := NewProvider("key", "https://api.deepseek.com/v1", "")
|
||||
if p.SupportsNativeSearch() {
|
||||
t.Fatal("DeepSeek provider should not support native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_NativeSearchToolInjected(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
p.apiBase = "https://api.openai.com/v1"
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
r.URL, _ = url.Parse(server.URL + r.URL.Path)
|
||||
return http.DefaultTransport.RoundTrip(r)
|
||||
}),
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
tools,
|
||||
"gpt-5.4",
|
||||
map[string]any{"native_search": true},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
toolsRaw, ok := requestBody["tools"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("tools is %T, want []any", requestBody["tools"])
|
||||
}
|
||||
if len(toolsRaw) != 2 {
|
||||
t.Fatalf("len(tools) = %d, want 2 (read_file + web_search_preview)", len(toolsRaw))
|
||||
}
|
||||
|
||||
lastTool, ok := toolsRaw[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("last tool is %T, want map[string]any", toolsRaw[1])
|
||||
}
|
||||
if lastTool["type"] != "web_search_preview" {
|
||||
t.Fatalf("last tool type = %v, want web_search_preview", lastTool["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_NativeSearchNotInjectedWithoutOption(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
|
||||
}
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
tools,
|
||||
"gpt-5.4",
|
||||
map[string]any{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
toolsRaw, ok := requestBody["tools"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("tools is %T, want []any", requestBody["tools"])
|
||||
}
|
||||
if len(toolsRaw) != 1 {
|
||||
t.Fatalf("len(tools) = %d, want 1 (web_search only)", len(toolsRaw))
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderChat_NativeSearchIgnoredOnNonOpenAI verifies that when native_search
|
||||
// is true in options but the provider's apiBase is not OpenAI (e.g. fallback to DeepSeek),
|
||||
// we do not inject web_search_preview to avoid API errors.
|
||||
func TestProviderChat_NativeSearchIgnoredOnNonOpenAI(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Use server.URL so host is not api.openai.com — simulates DeepSeek/other provider
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"deepseek-chat",
|
||||
map[string]any{"native_search": true},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// Should not have tools at all (no tools passed, and we must not add web_search_preview)
|
||||
if toolsRaw, ok := requestBody["tools"]; ok {
|
||||
t.Fatalf("tools should be omitted for non-OpenAI when only native_search was requested, got %v", toolsRaw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{
|
||||
|
||||
@@ -44,6 +44,15 @@ type ThinkingCapable interface {
|
||||
SupportsThinking() bool
|
||||
}
|
||||
|
||||
// NativeSearchCapable is an optional interface for providers that support
|
||||
// built-in web search during LLM inference (e.g. OpenAI web_search_preview,
|
||||
// xAI Grok search). When the active provider implements this interface and
|
||||
// returns true, the agent loop can hide the client-side web_search tool to
|
||||
// avoid duplicate search surfaces and use the provider's native search instead.
|
||||
type NativeSearchCapable interface {
|
||||
SupportsNativeSearch() bool
|
||||
}
|
||||
|
||||
// FailoverReason classifies why an LLM request failed for fallback decisions.
|
||||
type FailoverReason string
|
||||
|
||||
|
||||
@@ -226,9 +226,12 @@ func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
var msg bus.OutboundMessage
|
||||
select {
|
||||
case msg = <-tool.msgBus.OutboundChan():
|
||||
// got message
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for outbound message")
|
||||
}
|
||||
if !strings.Contains(msg.Content, "command execution is disabled") {
|
||||
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
|
||||
|
||||
+174
-25
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -776,22 +778,49 @@ type WebFetchTool struct {
|
||||
maxChars int
|
||||
proxy string
|
||||
client *http.Client
|
||||
format string
|
||||
fetchLimitBytes int64
|
||||
whitelist *privateHostWhitelist
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
type privateHostWhitelist struct {
|
||||
exact map[string]struct{}
|
||||
cidrs []*net.IPNet
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int, format string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
// createHTTPClient cannot fail with an empty proxy string.
|
||||
return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
|
||||
return NewWebFetchToolWithConfig(maxChars, "", format, fetchLimitBytes, nil)
|
||||
}
|
||||
|
||||
// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed.
|
||||
// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily.
|
||||
var allowPrivateWebFetchHosts atomic.Bool
|
||||
|
||||
func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
|
||||
func NewWebFetchToolWithProxy(
|
||||
maxChars int,
|
||||
proxy string,
|
||||
format string,
|
||||
fetchLimitBytes int64,
|
||||
privateHostWhitelist []string,
|
||||
) (*WebFetchTool, error) {
|
||||
return NewWebFetchToolWithConfig(maxChars, proxy, format, fetchLimitBytes, privateHostWhitelist)
|
||||
}
|
||||
|
||||
func NewWebFetchToolWithConfig(
|
||||
maxChars int,
|
||||
proxy string,
|
||||
format string,
|
||||
fetchLimitBytes int64,
|
||||
privateHostWhitelist []string,
|
||||
) (*WebFetchTool, error) {
|
||||
if maxChars <= 0 {
|
||||
maxChars = defaultMaxChars
|
||||
}
|
||||
whitelist, err := newPrivateHostWhitelist(privateHostWhitelist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse web fetch private host whitelist: %w", err)
|
||||
}
|
||||
client, err := utils.CreateHTTPClient(proxy, fetchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
|
||||
@@ -801,13 +830,13 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
transport.DialContext = newSafeDialContext(dialer)
|
||||
transport.DialContext = newSafeDialContext(dialer, whitelist)
|
||||
}
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
if isObviousPrivateHost(req.URL.Hostname()) {
|
||||
if isObviousPrivateHost(req.URL.Hostname(), whitelist) {
|
||||
return fmt.Errorf("redirect target is private or local network host")
|
||||
}
|
||||
return nil
|
||||
@@ -819,7 +848,9 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64)
|
||||
maxChars: maxChars,
|
||||
proxy: proxy,
|
||||
client: client,
|
||||
format: format,
|
||||
fetchLimitBytes: fetchLimitBytes,
|
||||
whitelist: whitelist,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -871,7 +902,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
// Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution.
|
||||
// The real SSRF guard is newSafeDialContext at connect time.
|
||||
hostname := parsedURL.Hostname()
|
||||
if isObviousPrivateHost(hostname) {
|
||||
if isObviousPrivateHost(hostname, t.whitelist) {
|
||||
return ErrorResult("fetching private or local network hosts is not allowed")
|
||||
}
|
||||
|
||||
@@ -906,26 +937,68 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
mediaType, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
// The most common error here is "mime: no media type" if the header is empty.
|
||||
logger.WarnCF("tool", "Failed to parse Content-Type", map[string]any{
|
||||
"raw_header": contentType,
|
||||
"error": err.Error(),
|
||||
})
|
||||
|
||||
// security fallback
|
||||
mediaType = "application/octet-stream"
|
||||
}
|
||||
|
||||
charset, hasCharset := params["charset"]
|
||||
if hasCharset {
|
||||
// If the charset is not utf-8, we might have to convert the bodyStr
|
||||
// before passing it to the HTML/Markdown parser
|
||||
if strings.ToLower(charset) != "utf-8" {
|
||||
logger.WarnCF("tool", "Note: the content is not in UTF-8", map[string]any{"charset": charset})
|
||||
}
|
||||
}
|
||||
|
||||
var text, extractor string
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
switch {
|
||||
case mediaType == "application/json":
|
||||
var jsonData any
|
||||
if err := json.Unmarshal(body, &jsonData); err == nil {
|
||||
formatted, _ := json.MarshalIndent(jsonData, "", " ")
|
||||
text = string(formatted)
|
||||
extractor = "json"
|
||||
} else {
|
||||
text = string(body)
|
||||
if err := json.Unmarshal(body, &jsonData); err != nil {
|
||||
text = bodyStr
|
||||
extractor = "raw"
|
||||
break
|
||||
}
|
||||
} else if strings.Contains(contentType, "text/html") || len(body) > 0 &&
|
||||
(strings.HasPrefix(string(body), "<!DOCTYPE") || strings.HasPrefix(strings.ToLower(string(body)), "<html")) {
|
||||
text = t.extractText(string(body))
|
||||
extractor = "text"
|
||||
} else {
|
||||
text = string(body)
|
||||
|
||||
formatted, err := json.MarshalIndent(jsonData, "", " ")
|
||||
if err != nil {
|
||||
text = bodyStr
|
||||
extractor = "raw"
|
||||
break
|
||||
}
|
||||
|
||||
text = string(formatted)
|
||||
extractor = "json"
|
||||
|
||||
case mediaType == "text/html" || looksLikeHTML(bodyStr):
|
||||
switch strings.ToLower(t.format) {
|
||||
case "markdown":
|
||||
var err error
|
||||
text, err = utils.HtmlToMarkdown(bodyStr)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to HTML to markdown: %v", err))
|
||||
}
|
||||
extractor = "markdown"
|
||||
|
||||
default:
|
||||
text = t.extractText(bodyStr)
|
||||
extractor = "text"
|
||||
}
|
||||
|
||||
default:
|
||||
text = bodyStr
|
||||
extractor = "raw"
|
||||
}
|
||||
|
||||
@@ -957,6 +1030,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
}
|
||||
}
|
||||
|
||||
func looksLikeHTML(body string) bool {
|
||||
if body == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
lower := strings.ToLower(body)
|
||||
|
||||
return strings.HasPrefix(body, "<!doctype") ||
|
||||
strings.HasPrefix(lower, "<html")
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
result := reScript.ReplaceAllLiteralString(htmlContent, "")
|
||||
result = reStyle.ReplaceAllLiteralString(result, "")
|
||||
@@ -981,7 +1065,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU)
|
||||
// where a hostname resolves to a public IP during pre-flight but a private IP at connect time.
|
||||
func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
|
||||
func newSafeDialContext(
|
||||
dialer *net.Dialer,
|
||||
whitelist *privateHostWhitelist,
|
||||
) func(context.Context, string, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return dialer.DialContext(ctx, network, address)
|
||||
@@ -996,7 +1083,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateOrRestrictedIP(ip) {
|
||||
if shouldBlockPrivateIP(ip, whitelist) {
|
||||
return nil, fmt.Errorf("blocked private or local target: %s", host)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
@@ -1010,7 +1097,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
attempted := 0
|
||||
var lastErr error
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if isPrivateOrRestrictedIP(ipAddr.IP) {
|
||||
if shouldBlockPrivateIP(ipAddr.IP, whitelist) {
|
||||
continue
|
||||
}
|
||||
attempted++
|
||||
@@ -1022,7 +1109,7 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
|
||||
if attempted == 0 {
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host)
|
||||
return nil, fmt.Errorf("all resolved addresses for %s are private, restricted, or not whitelisted", host)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr)
|
||||
@@ -1031,10 +1118,72 @@ func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string
|
||||
}
|
||||
}
|
||||
|
||||
func newPrivateHostWhitelist(entries []string) (*privateHostWhitelist, error) {
|
||||
if len(entries) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
whitelist := &privateHostWhitelist{
|
||||
exact: make(map[string]struct{}),
|
||||
cidrs: make([]*net.IPNet, 0, len(entries)),
|
||||
}
|
||||
for _, entry := range entries {
|
||||
entry = strings.TrimSpace(entry)
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if ip := net.ParseIP(entry); ip != nil {
|
||||
whitelist.exact[normalizeWhitelistIP(ip).String()] = struct{}{}
|
||||
continue
|
||||
}
|
||||
_, network, err := net.ParseCIDR(entry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid entry %q: expected IP or CIDR", entry)
|
||||
}
|
||||
whitelist.cidrs = append(whitelist.cidrs, network)
|
||||
}
|
||||
|
||||
if len(whitelist.exact) == 0 && len(whitelist.cidrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return whitelist, nil
|
||||
}
|
||||
|
||||
func (w *privateHostWhitelist) Contains(ip net.IP) bool {
|
||||
if w == nil || ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
normalized := normalizeWhitelistIP(ip)
|
||||
if _, ok := w.exact[normalized.String()]; ok {
|
||||
return true
|
||||
}
|
||||
for _, network := range w.cidrs {
|
||||
if network.Contains(normalized) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeWhitelistIP(ip net.IP) net.IP {
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func shouldBlockPrivateIP(ip net.IP, whitelist *privateHostWhitelist) bool {
|
||||
return isPrivateOrRestrictedIP(ip) && !whitelist.Contains(ip)
|
||||
}
|
||||
|
||||
// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts.
|
||||
// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS —
|
||||
// the real SSRF guard is newSafeDialContext which checks IPs at connect time.
|
||||
func isObviousPrivateHost(host string) bool {
|
||||
func isObviousPrivateHost(host string, whitelist *privateHostWhitelist) bool {
|
||||
if allowPrivateWebFetchHosts.Load() {
|
||||
return false
|
||||
}
|
||||
@@ -1050,7 +1199,7 @@ func isObviousPrivateHost(host string) bool {
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
return isPrivateOrRestrictedIP(ip)
|
||||
return shouldBlockPrivateIP(ip, whitelist)
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
+170
-20
@@ -10,11 +10,15 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const testFetchLimit = int64(10 * 1024 * 1024)
|
||||
const (
|
||||
testFetchLimit = int64(10 * 1024 * 1024)
|
||||
format = "plaintext"
|
||||
)
|
||||
|
||||
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
||||
func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
@@ -27,7 +31,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -69,7 +73,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -94,7 +98,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -119,7 +123,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -144,7 +148,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -178,7 +182,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
|
||||
tool, err := NewWebFetchTool(1000, format, testFetchLimit) // Limit to 1000 chars
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -228,7 +232,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
defer ts.Close()
|
||||
|
||||
// Initialize the tool
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -311,7 +315,7 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -423,8 +427,31 @@ func withPrivateWebFetchHostsAllowed(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
|
||||
t.Helper()
|
||||
hostPort := strings.TrimPrefix(rawURL, "http://")
|
||||
hostPort = strings.TrimPrefix(hostPort, "https://")
|
||||
host, port, err := net.SplitHostPort(hostPort)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func singleHostCIDR(t *testing.T, host string) string {
|
||||
t.Helper()
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", host)
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
return ip.String() + "/32"
|
||||
}
|
||||
return ip.String() + "/128"
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -441,6 +468,56 @@ func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("exact whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{host})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("cidr whitelist ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
host, _ := serverHostAndPort(t, server.URL)
|
||||
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{singleHostCIDR(t, host)})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"url": server.URL,
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
|
||||
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
withPrivateWebFetchHostsAllowed(t)
|
||||
|
||||
@@ -451,7 +528,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -466,7 +543,7 @@ func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
||||
|
||||
// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked
|
||||
func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -481,7 +558,7 @@ func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
|
||||
|
||||
// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked
|
||||
func TestWebFetch_BlocksMetadataIP(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -496,7 +573,7 @@ func TestWebFetch_BlocksMetadataIP(t *testing.T) {
|
||||
|
||||
// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked
|
||||
func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -511,7 +588,7 @@ func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
|
||||
|
||||
// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked
|
||||
func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -527,7 +604,7 @@ func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
|
||||
|
||||
// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked
|
||||
func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -557,7 +634,7 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
||||
allowPrivateWebFetchHosts.Store(false)
|
||||
defer allowPrivateWebFetchHosts.Store(true)
|
||||
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create web fetch tool: %v", err)
|
||||
}
|
||||
@@ -570,6 +647,69 @@ func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
|
||||
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err == nil {
|
||||
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on loopback: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
accepted := make(chan struct{}, 1)
|
||||
go func() {
|
||||
conn, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
accepted <- struct{}{}
|
||||
}()
|
||||
|
||||
_, port, err := net.SplitHostPort(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to split listener address: %v", err)
|
||||
}
|
||||
|
||||
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse whitelist: %v", err)
|
||||
}
|
||||
|
||||
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
|
||||
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
||||
if err != nil {
|
||||
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
select {
|
||||
case <-accepted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected localhost listener to accept a connection")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
|
||||
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -615,7 +755,7 @@ func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
||||
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -639,7 +779,7 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
|
||||
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else if tool.maxChars != 1024 {
|
||||
@@ -650,7 +790,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
|
||||
}
|
||||
|
||||
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
|
||||
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
}
|
||||
@@ -660,6 +800,16 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
||||
_, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, []string{"not-an-ip-or-cidr"})
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid whitelist entry to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid entry") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
|
||||
@@ -0,0 +1,411 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
)
|
||||
|
||||
var (
|
||||
reSpaces = regexp.MustCompile(`[ \t]+`)
|
||||
reNewlines = regexp.MustCompile(`\n{3,}`)
|
||||
reEmptyListItem = regexp.MustCompile(`(?m)^[-*]\s*$`)
|
||||
reImageOnlyLink = regexp.MustCompile(`\[!\[\]\(<[^>]*>\)\]\(<[^>]*>\)`)
|
||||
reEmptyHeader = regexp.MustCompile(`(?m)^#{1,6}\s*$`)
|
||||
reLeadingLineSpace = regexp.MustCompile(`(?m)^([ \t])([^ \t\n])`)
|
||||
)
|
||||
|
||||
var skipTags = map[string]bool{
|
||||
"script": true, "style": true, "head": true,
|
||||
"noscript": true, "template": true,
|
||||
"nav": true, "footer": true, "aside": true, "header": true, "form": true, "dialog": true,
|
||||
}
|
||||
|
||||
func isSafeHref(href string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(href))
|
||||
if strings.HasPrefix(lower, "javascript:") || strings.HasPrefix(lower, "vbscript:") ||
|
||||
strings.HasPrefix(lower, "data:") {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(href))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
scheme := strings.ToLower(u.Scheme)
|
||||
return scheme == "" || scheme == "http" || scheme == "https" || scheme == "mailto"
|
||||
}
|
||||
|
||||
func isSafeImageSrc(src string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(src))
|
||||
if strings.HasPrefix(lower, "data:image/") {
|
||||
return true
|
||||
}
|
||||
return isSafeHref(src)
|
||||
}
|
||||
|
||||
func escapeMdAlt(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `[`, `\[`)
|
||||
s = strings.ReplaceAll(s, `]`, `\]`)
|
||||
return s
|
||||
}
|
||||
|
||||
func getAttr(n *html.Node, key string) string {
|
||||
for _, a := range n.Attr {
|
||||
if a.Key == key {
|
||||
return a.Val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeAttr(val string) string {
|
||||
val = strings.ReplaceAll(val, "\n", "")
|
||||
val = strings.ReplaceAll(val, "\r", "")
|
||||
val = strings.ReplaceAll(val, "\t", "")
|
||||
return strings.TrimSpace(val)
|
||||
}
|
||||
|
||||
func isUnlikelyNode(n *html.Node) bool {
|
||||
if n.Type != html.ElementNode {
|
||||
return false
|
||||
}
|
||||
classId := strings.ToLower(getAttr(n, "class") + " " + getAttr(n, "id"))
|
||||
if classId == " " {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(classId, "article") || strings.Contains(classId, "main") ||
|
||||
strings.Contains(classId, "content") {
|
||||
return false
|
||||
}
|
||||
unlikelyKeywords := []string{
|
||||
"menu",
|
||||
"nav",
|
||||
"footer",
|
||||
"sidebar",
|
||||
"cookie",
|
||||
"banner",
|
||||
"sponsor",
|
||||
"advert",
|
||||
"popup",
|
||||
"modal",
|
||||
"newsletter",
|
||||
"share",
|
||||
"social",
|
||||
}
|
||||
for _, keyword := range unlikelyKeywords {
|
||||
if strings.Contains(classId, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type converter struct {
|
||||
stack []*bytes.Buffer
|
||||
linkHrefs []string
|
||||
linkStates []bool
|
||||
emphStack []string // Tracks "**", "*", "~~" for buffered emphasis
|
||||
olCounters []int
|
||||
inPre bool
|
||||
listDepth int
|
||||
}
|
||||
|
||||
func newConverter() *converter {
|
||||
return &converter{
|
||||
stack: []*bytes.Buffer{{}},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) write(s string) {
|
||||
c.stack[len(c.stack)-1].WriteString(s)
|
||||
}
|
||||
|
||||
func (c *converter) pushBuf() {
|
||||
c.stack = append(c.stack, &bytes.Buffer{})
|
||||
}
|
||||
|
||||
func (c *converter) popBuf() string {
|
||||
top := c.stack[len(c.stack)-1]
|
||||
c.stack = c.stack[:len(c.stack)-1]
|
||||
return top.String()
|
||||
}
|
||||
|
||||
func (c *converter) walk(n *html.Node) {
|
||||
if n.Type == html.ElementNode {
|
||||
if skipTags[n.Data] {
|
||||
return
|
||||
}
|
||||
if isUnlikelyNode(n) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if n.Type == html.TextNode {
|
||||
text := n.Data
|
||||
if !c.inPre {
|
||||
text = strings.ReplaceAll(text, "\n", " ")
|
||||
text = reSpaces.ReplaceAllString(text, " ")
|
||||
}
|
||||
if text != "" {
|
||||
c.write(text)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n.Type != html.ElementNode {
|
||||
for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
|
||||
c.walk(ch)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Opening Tags
|
||||
switch n.Data {
|
||||
// Buffer emphasis content so we can TrimSpace the inner text,
|
||||
// avoiding the regex-across-boundaries bug.
|
||||
case "b", "strong":
|
||||
c.emphStack = append(c.emphStack, "**")
|
||||
c.pushBuf()
|
||||
case "i", "em":
|
||||
c.emphStack = append(c.emphStack, "*")
|
||||
c.pushBuf()
|
||||
case "del", "s":
|
||||
c.emphStack = append(c.emphStack, "~~")
|
||||
c.pushBuf()
|
||||
|
||||
case "a":
|
||||
href := normalizeAttr(getAttr(n, "href"))
|
||||
if href != "" && !isSafeHref(href) {
|
||||
href = "#"
|
||||
}
|
||||
hasHref := href != ""
|
||||
c.linkStates = append(c.linkStates, hasHref)
|
||||
if hasHref {
|
||||
c.linkHrefs = append(c.linkHrefs, href)
|
||||
c.pushBuf()
|
||||
}
|
||||
|
||||
case "h1":
|
||||
c.write("\n\n# ")
|
||||
case "h2":
|
||||
c.write("\n\n## ")
|
||||
case "h3":
|
||||
c.write("\n\n### ")
|
||||
case "h4":
|
||||
c.write("\n\n#### ")
|
||||
case "h5":
|
||||
c.write("\n\n##### ")
|
||||
case "h6":
|
||||
c.write("\n\n###### ")
|
||||
|
||||
case "p":
|
||||
c.write("\n\n")
|
||||
case "br":
|
||||
c.write("\n")
|
||||
case "hr":
|
||||
c.write("\n\n---\n\n")
|
||||
|
||||
case "ol":
|
||||
c.olCounters = append(c.olCounters, 1)
|
||||
// Only write leading newline for top-level list.
|
||||
if c.listDepth == 0 {
|
||||
c.write("\n")
|
||||
}
|
||||
c.listDepth++
|
||||
case "ul":
|
||||
if c.listDepth == 0 {
|
||||
c.write("\n")
|
||||
}
|
||||
c.listDepth++
|
||||
case "li":
|
||||
c.write("\n")
|
||||
if c.listDepth > 1 {
|
||||
c.write(strings.Repeat(" ", c.listDepth-1))
|
||||
}
|
||||
if n.Parent != nil && n.Parent.Data == "ol" && len(c.olCounters) > 0 {
|
||||
idx := c.olCounters[len(c.olCounters)-1]
|
||||
c.write(strconv.Itoa(idx) + ". ")
|
||||
c.olCounters[len(c.olCounters)-1]++
|
||||
} else {
|
||||
c.write("- ")
|
||||
}
|
||||
|
||||
case "pre":
|
||||
c.inPre = true
|
||||
c.write("\n\n```\n")
|
||||
case "code":
|
||||
if !c.inPre {
|
||||
c.write("`")
|
||||
}
|
||||
|
||||
case "blockquote":
|
||||
c.pushBuf()
|
||||
for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
|
||||
c.walk(ch)
|
||||
}
|
||||
inner := strings.TrimSpace(c.popBuf())
|
||||
lines := strings.Split(inner, "\n")
|
||||
var quoted []string
|
||||
for _, l := range lines {
|
||||
if strings.TrimSpace(l) == "" {
|
||||
quoted = append(quoted, ">")
|
||||
} else {
|
||||
quoted = append(quoted, "> "+l)
|
||||
}
|
||||
}
|
||||
var deduped []string
|
||||
for i, line := range quoted {
|
||||
if line == ">" && i > 0 && deduped[len(deduped)-1] == ">" {
|
||||
continue
|
||||
}
|
||||
deduped = append(deduped, line)
|
||||
}
|
||||
c.write("\n\n" + strings.Join(deduped, "\n") + "\n\n")
|
||||
return
|
||||
|
||||
case "img":
|
||||
src := normalizeAttr(getAttr(n, "src"))
|
||||
if src == "" {
|
||||
src = normalizeAttr(getAttr(n, "data-src"))
|
||||
}
|
||||
if src == "" {
|
||||
return
|
||||
}
|
||||
alt := escapeMdAlt(normalizeAttr(getAttr(n, "alt")))
|
||||
if isSafeImageSrc(src) {
|
||||
c.write("")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Traverse Children
|
||||
for ch := n.FirstChild; ch != nil; ch = ch.NextSibling {
|
||||
c.walk(ch)
|
||||
}
|
||||
|
||||
// Closing Tags
|
||||
switch n.Data {
|
||||
// Pop buffer, trim, wrap with the correct marker.
|
||||
case "b", "strong", "i", "em", "del", "s":
|
||||
if len(c.emphStack) == 0 {
|
||||
break
|
||||
}
|
||||
marker := c.emphStack[len(c.emphStack)-1]
|
||||
c.emphStack = c.emphStack[:len(c.emphStack)-1]
|
||||
inner := strings.TrimSpace(c.popBuf())
|
||||
if inner != "" {
|
||||
c.write(marker + inner + marker)
|
||||
}
|
||||
|
||||
case "a":
|
||||
if len(c.linkStates) == 0 {
|
||||
break
|
||||
}
|
||||
hasHref := c.linkStates[len(c.linkStates)-1]
|
||||
c.linkStates = c.linkStates[:len(c.linkStates)-1]
|
||||
if !hasHref {
|
||||
break
|
||||
}
|
||||
href := c.linkHrefs[len(c.linkHrefs)-1]
|
||||
c.linkHrefs = c.linkHrefs[:len(c.linkHrefs)-1]
|
||||
inner := strings.TrimSpace(c.popBuf())
|
||||
if strings.Contains(inner, "\n") {
|
||||
lines := strings.Split(inner, "\n")
|
||||
linked := false
|
||||
for i, l := range lines {
|
||||
cleanLine := strings.TrimSpace(l)
|
||||
if cleanLine != "" && !strings.HasPrefix(cleanLine, "![") && !linked {
|
||||
lines[i] = "[" + cleanLine + "](" + href + ")"
|
||||
linked = true
|
||||
}
|
||||
}
|
||||
c.write(strings.Join(lines, "\n"))
|
||||
} else {
|
||||
c.write("[" + inner + "](" + href + ")")
|
||||
}
|
||||
|
||||
case "h1",
|
||||
"h2",
|
||||
"h3",
|
||||
"h4",
|
||||
"h5",
|
||||
"h6",
|
||||
"p",
|
||||
"div",
|
||||
"section",
|
||||
"article",
|
||||
"header",
|
||||
"footer",
|
||||
"aside",
|
||||
"nav",
|
||||
"figure":
|
||||
c.write("\n")
|
||||
|
||||
case "ol":
|
||||
c.listDepth--
|
||||
if len(c.olCounters) > 0 {
|
||||
c.olCounters = c.olCounters[:len(c.olCounters)-1]
|
||||
}
|
||||
if c.listDepth == 0 {
|
||||
c.write("\n")
|
||||
}
|
||||
case "ul":
|
||||
c.listDepth--
|
||||
if c.listDepth == 0 {
|
||||
c.write("\n")
|
||||
}
|
||||
|
||||
case "pre":
|
||||
c.inPre = false
|
||||
c.write("\n```\n\n")
|
||||
case "code":
|
||||
if !c.inPre {
|
||||
c.write("`")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func HtmlToMarkdown(htmlStr string) (string, error) {
|
||||
doc, err := html.Parse(strings.NewReader(htmlStr))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
c := newConverter()
|
||||
c.walk(doc)
|
||||
|
||||
res := c.stack[0].String()
|
||||
|
||||
// Post-processing
|
||||
res = reImageOnlyLink.ReplaceAllString(res, "")
|
||||
res = reEmptyListItem.ReplaceAllString(res, "")
|
||||
res = reEmptyHeader.ReplaceAllString(res, "")
|
||||
|
||||
lines := strings.Split(res, "\n")
|
||||
var cleanLines []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, " \t")
|
||||
cleanTest := strings.TrimSpace(line)
|
||||
if cleanTest == "[](</>)" || cleanTest == "[](#)" || cleanTest == "-" {
|
||||
cleanLines = append(cleanLines, "")
|
||||
continue
|
||||
}
|
||||
cleanLines = append(cleanLines, line)
|
||||
}
|
||||
res = strings.Join(cleanLines, "\n")
|
||||
|
||||
res = strings.TrimSpace(res)
|
||||
res = reNewlines.ReplaceAllString(res, "\n\n")
|
||||
|
||||
// Strip a single leading space from lines that are NOT list indentation.
|
||||
// "(?m)^([ \t])([^ \t\n])" matches exactly one space/tab at line start followed
|
||||
// by a non-whitespace char, so " - nested" (4 spaces) is left untouched.
|
||||
res = reLeadingLineSpace.ReplaceAllString(res, "$2")
|
||||
|
||||
return res, nil
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func TestHtmlToMarkdown(t *testing.T) {
|
||||
// Define our test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Removes scripts and styles",
|
||||
input: `<script>alert("hello");</script><style>body { color: red; }</style><p>Clean text</p>`,
|
||||
expected: "Clean text",
|
||||
},
|
||||
{
|
||||
name: "Extracts links correctly",
|
||||
input: `Visit my <a href="https://example.com">website</a> for info.`,
|
||||
expected: "Visit my [website](https://example.com) for info.",
|
||||
},
|
||||
{
|
||||
name: "Converts headers (H1, H2, H3)",
|
||||
input: `<h1>Main Title</h1><h2>Subtitle</h2><h3>Section</h3>`,
|
||||
expected: "# Main Title\n\n## Subtitle\n\n### Section",
|
||||
},
|
||||
{
|
||||
name: "Handles bold and italics",
|
||||
input: `Text <b>bold</b> and <strong>strong</strong>, then <i>italic</i> and <em>em</em>.`,
|
||||
expected: "Text **bold** and **strong**, then *italic* and *em*.",
|
||||
},
|
||||
{
|
||||
name: "Converts lists",
|
||||
input: `<ul><li>First element</li><li>Second element</li></ul>`,
|
||||
expected: "- First element\n- Second element",
|
||||
},
|
||||
{
|
||||
name: "Handles paragraphs and line breaks (<br>)",
|
||||
input: `<p>First paragraph</p><p>Second paragraph with<br>a line break.</p>`,
|
||||
expected: "First paragraph\n\nSecond paragraph with\na line break.",
|
||||
},
|
||||
{
|
||||
name: "Decodes HTML entities",
|
||||
input: `Math: 5 > 3 & 2 < 4. A "quote".`,
|
||||
expected: "Math: 5 > 3 & 2 < 4. A \"quote\".",
|
||||
},
|
||||
{
|
||||
name: "Cleans up residual HTML tags",
|
||||
input: `<div><span>Text inside div and span</span></div>`,
|
||||
expected: "Text inside div and span",
|
||||
},
|
||||
{
|
||||
name: "Removes multiple spaces and excessive empty lines",
|
||||
input: `This text has too many spaces. <br><br><br><br> And too many newlines.`,
|
||||
expected: "This text has too many spaces.\n\nAnd too many newlines.",
|
||||
},
|
||||
{
|
||||
name: "Nested lists with indentation",
|
||||
input: "<ul><li>One<ul><li>Two</li></ul></li></ul>",
|
||||
// Expect the sub-element to have 4 spaces of indentation
|
||||
expected: "- One\n - Two",
|
||||
},
|
||||
{
|
||||
name: "Image support",
|
||||
input: `<img src="image.jpg" alt="alternative text">`,
|
||||
// Correct Markdown syntax for images
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Image support without alt-text",
|
||||
input: `<img src="image.jpg">`,
|
||||
// If alt is missing, square brackets remain empty
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "XSS Bypass on Links (Obfuscated HTML entities)",
|
||||
// The Go HTML parser resolves entities, so this becomes "javascript:alert(1)"
|
||||
input: `<a href="jav	ascript:alert(1)">Click here</a>`,
|
||||
// Our isSafeHref (if updated with net/url) should neutralize it to "#"
|
||||
expected: "[Click here](#)",
|
||||
},
|
||||
{
|
||||
name: "Empty link or used as anchor",
|
||||
input: `<a name="top"></a>`,
|
||||
// With no text or href, it shouldn't print anything (not even empty brackets)
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Link without href but with text (Textual anchor)",
|
||||
input: `<a id="top">Back to top</a>`,
|
||||
// Should extract only plain text, without generating a broken Markdown link like [Back to top](#) or [Back to top]()
|
||||
expected: "Back to top",
|
||||
},
|
||||
{
|
||||
name: "Badly spaced bold and italics (Edge Case)",
|
||||
input: `<b> Text </b>`,
|
||||
// In Markdown `** Text **` is often not formatted correctly. The ideal is `**Text**`
|
||||
expected: "**Text**",
|
||||
},
|
||||
{
|
||||
name: "Complex Test - Real Article",
|
||||
input: `
|
||||
<h1>Article Title</h1>
|
||||
<p>This is an <strong>introductory text</strong> with a <a href="http://link.com">link</a>.</p>
|
||||
<h2>Subtitle</h2>
|
||||
<ul>
|
||||
<li>Point one</li>
|
||||
<li>Point two</li>
|
||||
</ul>
|
||||
<script>console.log("do not show me")</script>
|
||||
`,
|
||||
// Note: The indentation of the real HTML test will generate spaces that
|
||||
// regex will clean up.
|
||||
expected: "# Article Title\n\nThis is an **introductory text** with a [link](http://link.com).\n\n## Subtitle\n\n- Point one\n- Point two",
|
||||
},
|
||||
{
|
||||
name: "Ordered list (OL)",
|
||||
input: `<ol><li>First</li><li>Second</li><li>Third</li></ol>`,
|
||||
expected: "1. First\n2. Second\n3. Third",
|
||||
},
|
||||
{
|
||||
name: "Ordered list nested in unordered list",
|
||||
input: `<ul><li>Fruits<ol><li>Apples</li><li>Pears</li></ol></li><li>Vegetables</li></ul>`,
|
||||
expected: "- Fruits\n 1. Apples\n 2. Pears\n- Vegetables",
|
||||
},
|
||||
{
|
||||
name: "Code block (pre/code)",
|
||||
input: "<pre><code>func main() {\n fmt.Println(\"hello\")\n}</code></pre>",
|
||||
expected: "```\nfunc main() {\n fmt.Println(\"hello\")\n}\n```",
|
||||
},
|
||||
{
|
||||
name: "Inline code",
|
||||
input: `<p>Use the command <code>go test ./...</code> to run the tests.</p>`,
|
||||
expected: "Use the command `go test ./...` to run the tests.",
|
||||
},
|
||||
{
|
||||
name: "Simple blockquote",
|
||||
input: `<blockquote><p>An important quote.</p></blockquote>`,
|
||||
expected: "> An important quote.",
|
||||
},
|
||||
{
|
||||
name: "Multiline blockquote",
|
||||
input: `<blockquote><p>First line of the quote.</p><p>Second line of the quote.</p></blockquote>`,
|
||||
expected: "> First line of the quote.\n>\n> Second line of the quote.",
|
||||
},
|
||||
{
|
||||
name: "Strikethrough text (del/s)",
|
||||
input: `This text is <del>deleted</del> and this is <s>crossed out</s>.`,
|
||||
expected: "This text is ~~deleted~~ and this is ~~crossed out~~.",
|
||||
},
|
||||
{
|
||||
name: "Horizontal separator (HR)",
|
||||
input: `<p>Above the line</p><hr><p>Below the line</p>`,
|
||||
expected: "Above the line\n\n---\n\nBelow the line",
|
||||
},
|
||||
{
|
||||
name: "Bold nested in link",
|
||||
input: `<a href="https://example.com"><strong>Linked bold text</strong></a>`,
|
||||
expected: "[**Linked bold text**](https://example.com)",
|
||||
},
|
||||
{
|
||||
name: "data-src Image (lazy loading)",
|
||||
input: `<img data-src="lazy.jpg" alt="Lazy image">`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Image with javascript: src blocked",
|
||||
input: `<img src="javascript:alert(1)" alt="XSS">`,
|
||||
// src is not safe, so the image is not emitted
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Link with data: href blocked",
|
||||
input: `<a href="data:text/html,<script>alert(1)</script>">Click</a>`,
|
||||
expected: "[Click](#)",
|
||||
},
|
||||
{
|
||||
name: "Deeply nested divs",
|
||||
input: `<div><div><div><div><p>Deeply nested text</p></div></div></div></div>`,
|
||||
expected: "Deeply nested text",
|
||||
},
|
||||
{
|
||||
name: "Non-consecutive headers (H1, H3, H5)",
|
||||
input: `<h1>Title</h1><h3>Subsection</h3><h5>Sub-subsection</h5>`,
|
||||
expected: "# Title\n\n### Subsection\n\n##### Sub-subsection",
|
||||
},
|
||||
{
|
||||
name: "Paragraph with mixed multiple emphasis",
|
||||
input: `<p><strong>Important:</strong> read the <strong><em>critical instructions</em></strong> <em>carefully</em>.</p>`,
|
||||
expected: "**Important:** read the ***critical instructions*** *carefully*.",
|
||||
},
|
||||
{
|
||||
name: "Article with nav and aside sections (noise to filter)",
|
||||
input: `
|
||||
<nav><a href="/home">Home</a><a href="/about-us">About us</a></nav>
|
||||
<article>
|
||||
<h2>Article title</h2>
|
||||
<p>This is the body of the article.</p>
|
||||
</article>
|
||||
<aside><p>Advertisement</p></aside>
|
||||
`,
|
||||
expected: "## Article title\n\nThis is the body of the article.",
|
||||
},
|
||||
{
|
||||
name: "Text with mixed special HTML entities",
|
||||
input: `Copyright © 2024 — All rights reserved ®`,
|
||||
expected: "Copyright © 2024 — All rights reserved ®",
|
||||
},
|
||||
{
|
||||
name: "Mailto link",
|
||||
input: `Write to us at <a href="mailto:info@example.com">info@example.com</a>`,
|
||||
expected: "Write to us at [info@example.com](mailto:info@example.com)",
|
||||
},
|
||||
{
|
||||
name: "Image inside a link (clickable figure)",
|
||||
input: `<a href="https://example.com"><img src="photo.jpg" alt="Photo"></a>`,
|
||||
// The image-link without text must not generate broken markup
|
||||
expected: "[](https://example.com)",
|
||||
},
|
||||
{
|
||||
name: "Empty content or only whitespace",
|
||||
input: ` <p> </p> <div> </div> `,
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
// Iterate over all test cases
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := HtmlToMarkdown(tt.input)
|
||||
if err != nil {
|
||||
logger.ErrorCF("tool", "Failed to parse html to markdown: %s", map[string]any{"error": err.Error()})
|
||||
}
|
||||
|
||||
if got != tt.expected {
|
||||
t.Errorf("\nTest case failed: %s\nInput: %q\nGot: %q\nExpected: %q",
|
||||
tt.name, tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user