Merge upstream/main and resolve conflicts in .env.example

This commit is contained in:
mutezebra
2026-03-08 15:32:11 +08:00
94 changed files with 6006 additions and 482 deletions
+54 -1
View File
@@ -605,7 +605,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
}
}
return sanitized
// Second pass: ensure every assistant message with tool_calls has matching
// tool result messages following it. This is required by strict providers
// like DeepSeek that enforce: "An assistant message with 'tool_calls' must
// be followed by tool messages responding to each 'tool_call_id'."
final := make([]providers.Message, 0, len(sanitized))
for i := 0; i < len(sanitized); i++ {
msg := sanitized[i]
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
// Collect expected tool_call IDs
expected := make(map[string]bool, len(msg.ToolCalls))
for _, tc := range msg.ToolCalls {
expected[tc.ID] = false
}
// Check following messages for matching tool results
toolMsgCount := 0
for j := i + 1; j < len(sanitized); j++ {
if sanitized[j].Role != "tool" {
break
}
toolMsgCount++
if _, exists := expected[sanitized[j].ToolCallID]; exists {
expected[sanitized[j].ToolCallID] = true
}
}
// If any tool_call_id is missing, drop this assistant message and its partial tool messages
allFound := true
for toolCallID, found := range expected {
if !found {
allFound = false
logger.DebugCF(
"agent",
"Dropping assistant message with incomplete tool results",
map[string]any{
"missing_tool_call_id": toolCallID,
"expected_count": len(expected),
"found_count": toolMsgCount,
},
)
break
}
}
if !allFound {
// Skip this assistant message and its tool messages
i += toolMsgCount
continue
}
}
final = append(final, msg)
}
return final
}
func (cb *ContextBuilder) AddToolResult(
+74
View File
@@ -207,3 +207,77 @@ func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) {
}
}
}
// TestSanitizeHistoryForProvider_IncompleteToolResults tests the forward validation
// that ensures assistant messages with tool_calls have ALL matching tool results.
// This fixes the DeepSeek error: "An assistant message with 'tool_calls' must be
// followed by tool messages responding to each 'tool_call_id'."
func TestSanitizeHistoryForProvider_IncompleteToolResults(t *testing.T) {
// Assistant expects tool results for both A and B, but only A is present
history := []providers.Message{
msg("user", "do two things"),
assistantWithTools("A", "B"),
toolResult("A"),
// toolResult("B") is missing - this would cause DeepSeek to fail
msg("user", "next question"),
msg("assistant", "answer"),
}
result := sanitizeHistoryForProvider(history)
// The assistant message with incomplete tool results should be dropped,
// along with its partial tool result. The remaining messages are:
// user ("do two things"), user ("next question"), assistant ("answer")
if len(result) != 3 {
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "user", "assistant")
}
// TestSanitizeHistoryForProvider_MissingAllToolResults tests the case where
// an assistant message has tool_calls but no tool results follow at all.
func TestSanitizeHistoryForProvider_MissingAllToolResults(t *testing.T) {
history := []providers.Message{
msg("user", "do something"),
assistantWithTools("A"),
// No tool results at all
msg("user", "hello"),
msg("assistant", "hi"),
}
result := sanitizeHistoryForProvider(history)
// The assistant message with no tool results should be dropped.
// Remaining: user ("do something"), user ("hello"), assistant ("hi")
if len(result) != 3 {
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "user", "assistant")
}
// TestSanitizeHistoryForProvider_PartialToolResultsInMiddle tests that
// incomplete tool results in the middle of a conversation are properly handled.
func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) {
history := []providers.Message{
msg("user", "first"),
assistantWithTools("A"),
toolResult("A"),
msg("assistant", "done"),
msg("user", "second"),
assistantWithTools("B", "C"),
toolResult("B"),
// toolResult("C") is missing
msg("user", "third"),
assistantWithTools("D"),
toolResult("D"),
msg("assistant", "all done"),
}
result := sanitizeHistoryForProvider(history)
// First round is complete (user, assistant+tools, tool, assistant),
// second round is incomplete and dropped (assistant+tools, partial tool),
// third round is complete (user, assistant+tools, tool, assistant).
// Remaining: user, assistant, tool, assistant, user, user, assistant, tool, assistant
if len(result) != 9 {
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant")
}
+29
View File
@@ -37,6 +37,14 @@ type AgentInstance struct {
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
// Router is non-nil when model routing is configured and the light model
// was successfully resolved. It scores each incoming message and decides
// whether to route to LightCandidates or stay with Candidates.
Router *routing.Router
// LightCandidates holds the resolved provider candidates for the light model.
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
LightCandidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -180,6 +188,25 @@ func NewAgentInstance(
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
var lightCandidates []providers.FallbackCandidate
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
if len(resolved) > 0 {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
Threshold: rc.Threshold,
})
lightCandidates = resolved
} else {
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
rc.LightModel, agentID)
}
}
return &AgentInstance{
ID: agentID,
Name: agentName,
@@ -200,6 +227,8 @@ func NewAgentInstance(
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
Router: router,
LightCandidates: lightCandidates,
}
}
+246 -117
View File
@@ -21,6 +21,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -46,6 +47,7 @@ type AgentLoop struct {
channelManager *channels.Manager
mediaStore media.MediaStore
transcriber voice.Transcriber
cmdRegistry *commands.Registry
}
// processOptions configures how a message is processed
@@ -61,7 +63,15 @@ type processOptions struct {
NoHistory bool // If true, don't load session history (for heartbeat)
}
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
const (
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
sessionKeyAgentPrefix = "agent:"
metadataKeyAccountID = "account_id"
metadataKeyGuildID = "guild_id"
metadataKeyTeamID = "team_id"
metadataKeyParentPeerKind = "parent_peer_kind"
metadataKeyParentPeerID = "parent_peer_id"
)
func NewAgentLoop(
cfg *config.Config,
@@ -84,14 +94,17 @@ func NewAgentLoop(
stateManager = state.NewManager(defaultAgent.Workspace)
}
return &AgentLoop{
al := &AgentLoop{
bus: msgBus,
cfg: cfg,
registry: registry,
state: stateManager,
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
}
return al
}
// registerSharedTools registers tools that are shared across all agents (web, message, spawn).
@@ -170,6 +183,17 @@ func registerSharedTools(
agent.Tools.Register(messageTool)
}
// Send file tool (outbound media via MediaStore — store injected later by SetMediaStore)
if cfg.Tools.IsToolEnabled("send_file") {
sendFileTool := tools.NewSendFileTool(
agent.Workspace,
cfg.Agents.Defaults.RestrictToWorkspace,
cfg.Agents.Defaults.GetMaxMediaSize(),
nil,
)
agent.Tools.Register(sendFileTool)
}
// Skill discovery and installation tools
skills_enabled := cfg.Tools.IsToolEnabled("skills")
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
@@ -196,7 +220,7 @@ func registerSharedTools(
// Spawn tool with allowlist checker
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
@@ -371,6 +395,13 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
// SetMediaStore injects a MediaStore for media lifecycle management.
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
al.mediaStore = s
// Propagate store to send_file tools in all agents.
al.registry.ForEachTool("send_file", func(t tools.Tool) {
if sf, ok := t.(*tools.SendFileTool); ok {
sf.SetMediaStore(s)
}
})
}
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
@@ -549,27 +580,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return al.processSystemMessage(ctx, msg)
}
// Check for commands
if response, handled := al.handleCommand(ctx, msg); handled {
route, agent, routeErr := al.resolveMessageRoute(msg)
// Commands are checked before requiring a successful route.
// Global commands (/help, /show, /switch) work even when routing fails;
// context-dependent commands check their own Runtime fields and report
// "unavailable" when the required capability is nil.
if response, handled := al.handleCommand(ctx, msg, agent); handled {
return response, nil
}
// Route to determine agent and session key
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
AccountID: msg.Metadata["account_id"],
Peer: extractPeer(msg),
ParentPeer: extractParentPeer(msg),
GuildID: msg.Metadata["guild_id"],
TeamID: msg.Metadata["team_id"],
})
agent, ok := al.registry.GetAgent(route.AgentID)
if !ok {
agent = al.registry.GetDefaultAgent()
}
if agent == nil {
return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
if routeErr != nil {
return "", routeErr
}
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
@@ -579,17 +601,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
}
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
sessionKey := route.SessionKey
if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") {
sessionKey = msg.SessionKey
}
// Resolve session key from route, while preserving explicit agent-scoped keys.
scopeKey := resolveScopeKey(route, msg.SessionKey)
sessionKey := scopeKey
logger.InfoCF("agent", "Routed message",
map[string]any{
"agent_id": agent.ID,
"session_key": sessionKey,
"matched_by": route.MatchedBy,
"agent_id": agent.ID,
"scope_key": scopeKey,
"session_key": sessionKey,
"matched_by": route.MatchedBy,
"route_agent": route.AgentID,
"route_channel": route.Channel,
})
return al.runAgentLoop(ctx, agent, processOptions{
@@ -604,6 +627,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
})
}
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
AccountID: inboundMetadata(msg, metadataKeyAccountID),
Peer: extractPeer(msg),
ParentPeer: extractParentPeer(msg),
GuildID: inboundMetadata(msg, metadataKeyGuildID),
TeamID: inboundMetadata(msg, metadataKeyTeamID),
})
agent, ok := al.registry.GetAgent(route.AgentID)
if !ok {
agent = al.registry.GetDefaultAgent()
}
if agent == nil {
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
}
return route, agent, nil
}
func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) {
return msgSessionKey
}
return route.SessionKey
}
func (al *AgentLoop) processSystemMessage(
ctx context.Context,
msg bus.InboundMessage,
@@ -675,9 +726,8 @@ func (al *AgentLoop) runAgentLoop(
agent *AgentInstance,
opts processOptions,
) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
// 0. Record last channel for heartbeat notifications (skip internal channels and cli)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
@@ -824,6 +874,12 @@ func (al *AgentLoop) runLLMIteration(
iteration := 0
var finalContent string
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
// all tool-follow-up iterations within the same turn so that a multi-step
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
for iteration < agent.MaxIterations {
iteration++
@@ -842,7 +898,7 @@ func (al *AgentLoop) runLLMIteration(
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"model": agent.Model,
"model": activeModel,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"max_tokens": agent.MaxTokens,
@@ -858,7 +914,7 @@ func (al *AgentLoop) runLLMIteration(
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM with fallback chain if candidates are configured.
// Call LLM with fallback chain if multiple candidates are configured.
var response *providers.LLMResponse
var err error
@@ -879,10 +935,10 @@ func (al *AgentLoop) runLLMIteration(
}
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
ctx,
agent.Candidates,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
},
@@ -900,7 +956,7 @@ func (al *AgentLoop) runLLMIteration(
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
}
// Retry loop for context/token errors
@@ -999,9 +1055,12 @@ func (al *AgentLoop) runLLMIteration(
"target_channel": al.targetReasoningChannelID(opts.Channel),
"channel": opts.Channel,
})
// Check if no tool calls - we're done
// Check if no tool calls - then check reasoning content if any
if len(response.ToolCalls) == 0 {
finalContent = response.Content
if finalContent == "" && response.ReasoningContent != "" {
finalContent = response.ReasoningContent
}
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
map[string]any{
"agent_id": agent.ID,
@@ -1087,15 +1146,47 @@ func (al *AgentLoop) runLLMIteration(
"iteration": iteration,
})
// Create async callback for tools that implement AsyncExecutor
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Create async callback for tools that implement AsyncExecutor.
// When the background work completes, this publishes the result
// as an inbound system message so processSystemMessage routes it
// back to the user via the normal agent loop.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
// Send ForUser content directly to the user (immediate feedback),
// mirroring the synchronous tool execution path.
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: result.ForUser,
})
}
// Determine content for the agent loop (ForLLM or error).
content := result.ForLLM
if content == "" && result.Err != nil {
content = result.Err.Error()
}
if content == "" {
return
}
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": tc.Name,
"content_len": len(content),
"channel": opts.Channel,
})
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", tc.Name),
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
Content: content,
})
}
toolResult := agent.Tools.ExecuteWithContext(
@@ -1128,7 +1219,7 @@ func (al *AgentLoop) runLLMIteration(
}
// If tool returned media refs, publish them as outbound media
if len(r.result.Media) > 0 && opts.SendResponse {
if len(r.result.Media) > 0 {
parts := make([]bus.MediaPart, 0, len(r.result.Media))
for _, ref := range r.result.Media {
part := bus.MediaPart{Ref: ref}
@@ -1169,6 +1260,44 @@ func (al *AgentLoop) runLLMIteration(
return finalContent, iteration, nil
}
// selectCandidates returns the model candidates and resolved model name to use
// for a conversation turn. When model routing is configured and the incoming
// message scores below the complexity threshold, it returns the light model
// candidates instead of the primary ones.
//
// The returned (candidates, model) pair is used for all LLM calls within one
// turn — tool follow-up iterations use the same tier as the initial call so
// that a multi-step tool chain doesn't switch models mid-way.
func (al *AgentLoop) selectCandidates(
agent *AgentInstance,
userMsg string,
history []providers.Message,
) (candidates []providers.FallbackCandidate, model string) {
if agent.Router == nil || len(agent.LightCandidates) == 0 {
return agent.Candidates, agent.Model
}
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
if !usedLight {
logger.DebugCF("agent", "Model routing: primary model selected",
map[string]any{
"agent_id": agent.ID,
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.Candidates, agent.Model
}
logger.InfoCF("agent", "Model routing: light model selected",
map[string]any{
"agent_id": agent.ID,
"light_model": agent.Router.LightModel(),
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, agent.Router.LightModel()
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
@@ -1460,94 +1589,87 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
return totalChars * 2 / 5
}
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
content := strings.TrimSpace(msg.Content)
if !strings.HasPrefix(content, "/") {
func (al *AgentLoop) handleCommand(
ctx context.Context,
msg bus.InboundMessage,
agent *AgentInstance,
) (string, bool) {
if !commands.HasCommandPrefix(msg.Content) {
return "", false
}
parts := strings.Fields(content)
if len(parts) == 0 {
if al.cmdRegistry == nil {
return "", false
}
cmd := parts[0]
args := parts[1:]
rt := al.buildCommandsRuntime(agent)
executor := commands.NewExecutor(al.cmdRegistry, rt)
switch cmd {
case "/show":
if len(args) < 1 {
return "Usage: /show [model|channel|agents]", true
}
switch args[0] {
case "model":
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
return "No default agent configured", true
}
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
case "channel":
return fmt.Sprintf("Current channel: %s", msg.Channel), true
case "agents":
agentIDs := al.registry.ListAgentIDs()
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
default:
return fmt.Sprintf("Unknown show target: %s", args[0]), true
}
var commandReply string
result := executor.Execute(ctx, commands.Request{
Channel: msg.Channel,
ChatID: msg.ChatID,
SenderID: msg.SenderID,
Text: msg.Content,
Reply: func(text string) error {
commandReply = text
return nil
},
})
case "/list":
if len(args) < 1 {
return "Usage: /list [models|channels|agents]", true
switch result.Outcome {
case commands.OutcomeHandled:
if result.Err != nil {
return mapCommandError(result), true
}
switch args[0] {
case "models":
return "Available models: configured in config.json per agent", true
case "channels":
if commandReply != "" {
return commandReply, true
}
return "", true
default: // OutcomePassthrough — let the message fall through to LLM
return "", false
}
}
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime {
rt := &commands.Runtime{
Config: al.cfg,
ListAgentIDs: al.registry.ListAgentIDs,
ListDefinitions: al.cmdRegistry.Definitions,
GetEnabledChannels: func() []string {
if al.channelManager == nil {
return "Channel manager not initialized", true
return nil
}
channels := al.channelManager.GetEnabledChannels()
if len(channels) == 0 {
return "No channels enabled", true
}
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
case "agents":
agentIDs := al.registry.ListAgentIDs()
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
default:
return fmt.Sprintf("Unknown list target: %s", args[0]), true
}
case "/switch":
if len(args) < 3 || args[1] != "to" {
return "Usage: /switch [model|channel] to <name>", true
}
target := args[0]
value := args[2]
switch target {
case "model":
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
return "No default agent configured", true
}
oldModel := defaultAgent.Model
defaultAgent.Model = value
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
case "channel":
return al.channelManager.GetEnabledChannels()
},
SwitchChannel: func(value string) error {
if al.channelManager == nil {
return "Channel manager not initialized", true
return fmt.Errorf("channel manager not initialized")
}
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
return fmt.Errorf("channel '%s' not found or not enabled", value)
}
return fmt.Sprintf("Switched target channel to %s", value), true
default:
return fmt.Sprintf("Unknown switch target: %s", target), true
return nil
},
}
if agent != nil {
rt.GetModelInfo = func() (string, string) {
return agent.Model, al.cfg.Agents.Defaults.Provider
}
rt.SwitchModel = func(value string) (string, error) {
oldModel := agent.Model
agent.Model = value
return oldModel, nil
}
}
return rt
}
return "", false
func mapCommandError(result commands.ExecuteResult) string {
if result.Command == "" {
return fmt.Sprintf("Failed to execute command: %v", result.Err)
}
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
}
// extractPeer extracts the routing peer from the inbound message's structured Peer field.
@@ -1566,10 +1688,17 @@ func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID}
}
func inboundMetadata(msg bus.InboundMessage, key string) string {
if msg.Metadata == nil {
return ""
}
return msg.Metadata[key]
}
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
parentKind := msg.Metadata["parent_peer_kind"]
parentID := msg.Metadata["parent_peer_id"]
parentKind := inboundMetadata(msg, metadataKeyParentPeerKind)
parentID := inboundMetadata(msg, metadataKeyParentPeerID)
if parentKind == "" || parentID == "" {
return nil
}
+216
View File
@@ -15,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -318,6 +319,29 @@ func (m *simpleMockProvider) GetDefaultModel() string {
return "mock-model"
}
type countingMockProvider struct {
response string
calls int
}
func (m *countingMockProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *countingMockProvider) GetDefaultModel() string {
return "counting-mock-model"
}
// mockCustomTool is a simple mock tool for registration testing
type mockCustomTool struct{}
@@ -359,6 +383,198 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms
const responseTimeout = 3 * time.Second
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "ok"}
al := NewAgentLoop(cfg, msgBus, provider)
msg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
Peer: extractPeer(msg),
})
sessionKey := route.SessionKey
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("No default agent found")
}
helper := testHelper{al: al}
_ = helper.executeAndGetResponse(t, context.Background(), msg)
history := defaultAgent.Sessions.GetHistory(sessionKey)
if len(history) != 2 {
t.Fatalf("expected session history len=2, got %d", len(history))
}
if history[0].Role != "user" || history[0].Content != "hello" {
t.Fatalf("unexpected first message in session: %+v", history[0])
}
}
func TestProcessMessage_CommandOutcomes(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
Session: config.SessionConfig{
DMScope: "per-channel-peer",
},
}
msgBus := bus.NewMessageBus()
provider := &countingMockProvider{response: "LLM reply"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
baseMsg := bus.InboundMessage{
Channel: "whatsapp",
SenderID: "user1",
ChatID: "chat1",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/show channel",
Peer: baseMsg.Peer,
})
if showResp != "Current Channel: whatsapp" {
t.Fatalf("unexpected /show reply: %q", showResp)
}
if provider.calls != 0 {
t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls)
}
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/foo",
Peer: baseMsg.Peer,
})
if fooResp != "LLM reply" {
t.Fatalf("unexpected /foo reply: %q", fooResp)
}
if provider.calls != 1 {
t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls)
}
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/new",
Peer: baseMsg.Peer,
})
if newResp != "LLM reply" {
t.Fatalf("unexpected /new reply: %q", newResp)
}
if provider.calls != 2 {
t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls)
}
}
func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "before-switch",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &countingMockProvider{response: "LLM reply"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to after-switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
}
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/show model",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
}
if provider.calls != 0 {
t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls)
}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
+14
View File
@@ -7,6 +7,7 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
// AgentRegistry manages multiple agent instances and routes messages to them.
@@ -100,6 +101,19 @@ func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bo
return false
}
// ForEachTool calls fn for every tool registered under the given name
// across all agents. This is useful for propagating dependencies (e.g.
// MediaStore) to tools after registry construction.
func (r *AgentRegistry) ForEachTool(name string, fn func(tools.Tool)) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, agent := range r.agents {
if t, ok := agent.Tools.Get(name); ok {
fn(t)
}
}
}
// GetDefaultAgent returns the default agent instance.
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
r.mu.RLock()
+71
View File
@@ -0,0 +1,71 @@
package auth
import (
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
const (
anthropicBetaHeader = "oauth-2025-04-20"
anthropicAPIVersion = "2023-06-01"
)
// anthropicUsageURL is the endpoint for fetching OAuth usage stats.
// It is a var (not const) to allow overriding in tests.
var anthropicUsageURL = "https://api.anthropic.com/api/oauth/usage"
func setAnthropicUsageURL(url string) { anthropicUsageURL = url }
type AnthropicUsage struct {
FiveHourUtilization float64
SevenDayUtilization float64
}
func FetchAnthropicUsage(token string) (*AnthropicUsage, error) {
req, err := http.NewRequest("GET", anthropicUsageURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Anthropic-Version", anthropicAPIVersion)
req.Header.Set("Anthropic-Beta", anthropicBetaHeader)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading usage response: %w", err)
}
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusForbidden {
return nil, fmt.Errorf("insufficient scope: usage endpoint requires oauth scope")
}
return nil, fmt.Errorf("usage request failed (%d): %s", resp.StatusCode, string(body))
}
var result struct {
FiveHour struct {
Utilization float64 `json:"utilization"`
} `json:"five_hour"`
SevenDay struct {
Utilization float64 `json:"utilization"`
} `json:"seven_day"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parsing usage response: %w", err)
}
return &AnthropicUsage{
FiveHourUtilization: result.FiveHour.Utilization,
SevenDayUtilization: result.SevenDay.Utilization,
}, nil
}
+98
View File
@@ -0,0 +1,98 @@
package auth
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestFetchAnthropicUsage_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
t.Errorf("Authorization = %q, want %q", got, "Bearer test-token")
}
if got := r.Header.Get("Anthropic-Beta"); got != anthropicBetaHeader {
t.Errorf("Anthropic-Beta = %q, want %q", got, anthropicBetaHeader)
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"five_hour":{"utilization":0.42},"seven_day":{"utilization":0.85}}`))
}))
defer srv.Close()
// Temporarily override the URL by using the test server
origURL := anthropicUsageURL
defer func() { setAnthropicUsageURL(origURL) }()
setAnthropicUsageURL(srv.URL)
usage, err := FetchAnthropicUsage("test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if usage.FiveHourUtilization != 0.42 {
t.Errorf("FiveHourUtilization = %v, want 0.42", usage.FiveHourUtilization)
}
if usage.SevenDayUtilization != 0.85 {
t.Errorf("SevenDayUtilization = %v, want 0.85", usage.SevenDayUtilization)
}
}
func TestFetchAnthropicUsage_Forbidden(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error":"forbidden"}`))
}))
defer srv.Close()
origURL := anthropicUsageURL
defer func() { setAnthropicUsageURL(origURL) }()
setAnthropicUsageURL(srv.URL)
_, err := FetchAnthropicUsage("test-token")
if err == nil {
t.Fatal("expected error for 403, got nil")
}
if !strings.Contains(err.Error(), "insufficient scope") {
t.Errorf("expected 'insufficient scope' error, got %q", err.Error())
}
}
func TestFetchAnthropicUsage_ServerError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`internal error`))
}))
defer srv.Close()
origURL := anthropicUsageURL
defer func() { setAnthropicUsageURL(origURL) }()
setAnthropicUsageURL(srv.URL)
_, err := FetchAnthropicUsage("test-token")
if err == nil {
t.Fatal("expected error for 500, got nil")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("expected error containing '500', got %q", err.Error())
}
}
func TestFetchAnthropicUsage_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`not json`))
}))
defer srv.Close()
origURL := anthropicUsageURL
defer func() { setAnthropicUsageURL(origURL) }()
setAnthropicUsageURL(srv.URL)
_, err := FetchAnthropicUsage("test-token")
if err == nil {
t.Fatal("expected error for malformed JSON, got nil")
}
if !strings.Contains(err.Error(), "parsing usage response") {
t.Errorf("expected 'parsing usage response' error, got %q", err.Error())
}
}
+29
View File
@@ -31,6 +31,35 @@ func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
}, nil
}
func LoginSetupToken(r io.Reader) (*AuthCredential, error) {
fmt.Println("Paste your setup token from `claude setup-token`:")
fmt.Print("> ")
scanner := bufio.NewScanner(r)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading token: %w", err)
}
return nil, fmt.Errorf("no input received")
}
token := strings.TrimSpace(scanner.Text())
if !strings.HasPrefix(token, "sk-ant-oat01-") {
return nil, fmt.Errorf("invalid setup token: expected prefix sk-ant-oat01-")
}
if len(token) < 80 {
return nil, fmt.Errorf("invalid setup token: too short (expected at least 80 characters)")
}
return &AuthCredential{
AccessToken: token,
Provider: "anthropic",
AuthMethod: "oauth",
}, nil
}
func providerDisplayName(provider string) string {
switch provider {
case "anthropic":
+61
View File
@@ -0,0 +1,61 @@
package auth
import (
"strings"
"testing"
)
func TestLoginSetupToken(t *testing.T) {
// A valid token: correct prefix + at least 80 chars
validToken := "sk-ant-oat01-" + strings.Repeat("a", 80)
tests := []struct {
name string
input string
wantErr string
}{
{"valid token", validToken, ""},
{"empty input", "", "expected prefix sk-ant-oat01-"},
{"wrong prefix", "sk-ant-api-" + strings.Repeat("a", 80), "expected prefix sk-ant-oat01-"},
{"too short", "sk-ant-oat01-short", "too short"},
{"whitespace only", " ", "expected prefix sk-ant-oat01-"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := strings.NewReader(tt.input + "\n")
cred, err := LoginSetupToken(r)
if tt.wantErr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErr)
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error())
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cred.AccessToken != validToken {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, validToken)
}
if cred.Provider != "anthropic" {
t.Errorf("Provider = %q, want %q", cred.Provider, "anthropic")
}
if cred.AuthMethod != "oauth" {
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
}
})
}
}
func TestLoginSetupToken_EmptyReader(t *testing.T) {
r := strings.NewReader("")
_, err := LoginSetupToken(r)
if err == nil {
t.Fatal("expected error for empty reader, got nil")
}
}
+70
View File
@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"os"
"regexp"
"strings"
"sync"
"time"
@@ -26,6 +27,12 @@ const (
sendTimeout = 10 * time.Second
)
var (
// Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call)
channelRefRe = regexp.MustCompile(`<#(\d+)>`)
msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`)
)
type DiscordChannel struct {
*channels.BaseChannel
session *discordgo.Session
@@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = c.stripBotMention(content)
}
// Resolve Discord refs in main content before concatenation to avoid
// double-expanding links that appear in the referenced message.
content = c.resolveDiscordRefs(s, content, m.GuildID)
// Prepend referenced (quoted) message content if this is a reply
if m.MessageReference != nil && m.ReferencedMessage != nil {
refContent := m.ReferencedMessage.Content
if refContent != "" {
refAuthor := "unknown"
if m.ReferencedMessage.Author != nil {
refAuthor = m.ReferencedMessage.Author.Username
}
refContent = c.resolveDiscordRefs(s, refContent, m.GuildID)
content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s",
refAuthor, refContent, content)
}
}
senderID := m.Author.ID
mediaPaths := make([]string, 0, len(m.Attachments))
@@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
return nil
}
// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and
// expands Discord message links to show the linked message content.
// Only links pointing to the same guild are expanded to prevent cross-guild leakage.
func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string {
// 1. Resolve channel references: <#id> → #channel-name
text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string {
parts := channelRefRe.FindStringSubmatch(match)
if len(parts) < 2 {
return match
}
// Prefer session state cache to avoid API calls
if ch, err := s.State.Channel(parts[1]); err == nil {
return "#" + ch.Name
}
if ch, err := s.Channel(parts[1]); err == nil {
return "#" + ch.Name
}
return match
})
// 2. Expand Discord message links (max 3, same guild only)
matches := msgLinkRe.FindAllStringSubmatch(text, 3)
for _, m := range matches {
if len(m) < 4 {
continue
}
linkGuildID, channelID, messageID := m[1], m[2], m[3]
// Security: only expand links from the same guild
if linkGuildID != guildID {
continue
}
msg, err := s.ChannelMessage(channelID, messageID)
if err != nil || msg == nil || msg.Content == "" {
continue
}
author := "unknown"
if msg.Author != nil {
author = msg.Author.Username
}
text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content)
}
return text
}
// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
@@ -0,0 +1,98 @@
package discord
import (
"testing"
)
func TestChannelRefRegex(t *testing.T) {
tests := []struct {
name string
input string
wantID string
wantOK bool
}{
{"basic channel ref", "<#123456789>", "123456789", true},
{"long id", "<#9876543210123456>", "9876543210123456", true},
{"no match plain text", "hello world", "", false},
{"no match partial", "<#>", "", false},
{"no match letters", "<#abc>", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := channelRefRe.FindStringSubmatch(tt.input)
if tt.wantOK {
if len(matches) < 2 || matches[1] != tt.wantID {
t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID)
}
} else {
if len(matches) >= 2 {
t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches)
}
}
})
}
}
func TestMsgLinkRegex(t *testing.T) {
tests := []struct {
name string
input string
wantGuild string
wantChan string
wantMsg string
wantOK bool
}{
{
"discord.com link",
"https://discord.com/channels/111/222/333",
"111", "222", "333", true,
},
{
"discordapp.com link",
"https://discordapp.com/channels/111/222/333",
"111", "222", "333", true,
},
{
"real world ids",
"check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please",
"9000000000000001", "9000000000000002", "9000000000000003", true,
},
{"no match http", "http://discord.com/channels/1/2/3", "", "", "", false},
{"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false},
{"no match plain text", "hello world", "", "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := msgLinkRe.FindStringSubmatch(tt.input)
if tt.wantOK {
if len(matches) < 4 {
t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s",
tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg)
}
if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg {
t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s",
tt.input, matches[1], matches[2], matches[3],
tt.wantGuild, tt.wantChan, tt.wantMsg)
}
} else {
if len(matches) >= 4 {
t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches)
}
}
})
}
}
func TestMsgLinkRegex_MultipleMatches(t *testing.T) {
input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12"
matches := msgLinkRe.FindAllStringSubmatch(input, 3)
if len(matches) != 3 {
t.Fatalf("expected 3 matches (capped), got %d", len(matches))
}
// Verify the 3rd match is 7/8/9 (not 10/11/12)
if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" {
t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2])
}
}
+12 -1
View File
@@ -1,6 +1,10 @@
package channels
import "context"
import (
"context"
"github.com/sipeed/picoclaw/pkg/commands"
)
// TypingCapable — channels that can show a typing/thinking indicator.
// StartTyping begins the indicator and returns a stop function.
@@ -39,3 +43,10 @@ type PlaceholderRecorder interface {
RecordTypingStop(channel, chatID string, stop func())
RecordReactionUndo(channel, chatID string, undo func())
}
// CommandRegistrarCapable is implemented by channels that can register
// command menus with their upstream platform (e.g. Telegram BotCommand).
// Channels that do not support platform-level command menus can ignore it.
type CommandRegistrarCapable interface {
RegisterCommands(ctx context.Context, defs []commands.Definition) error
}
+16
View File
@@ -0,0 +1,16 @@
package channels
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/commands"
)
type mockRegistrar struct{}
func (mockRegistrar) RegisterCommands(context.Context, []commands.Definition) error { return nil }
func TestCommandRegistrarCapable_Compiles(t *testing.T) {
var _ CommandRegistrarCapable = mockRegistrar{}
}
+154
View File
@@ -0,0 +1,154 @@
package irc
import (
"fmt"
"strings"
"time"
"unicode"
"github.com/ergochat/irc-go/ircevent"
"github.com/ergochat/irc-go/ircmsg"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
// onConnect is called after a successful connection (and on reconnect).
func (c *IRCChannel) onConnect(conn *ircevent.Connection) {
// NickServ auth (only if SASL is not configured)
if c.config.NickServPassword != "" && c.config.SASLUser == "" {
conn.Privmsg("NickServ", "IDENTIFY "+c.config.NickServPassword)
}
// Join configured channels
for _, ch := range c.config.Channels {
conn.Join(ch)
logger.InfoCF("irc", "Joined IRC channel", map[string]any{
"channel": ch,
})
}
}
// onPrivmsg handles incoming PRIVMSG events.
func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
if len(e.Params) < 2 {
return
}
nick := e.Nick()
currentNick := conn.CurrentNick()
// Ignore own messages
if strings.EqualFold(nick, currentNick) {
return
}
target := e.Params[0] // channel name or bot's nick
content := e.Params[1] // message text
// Determine if this is a DM or channel message
isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&")
var chatID string
var peer bus.Peer
if isDM {
chatID = nick
peer = bus.Peer{Kind: "direct", ID: nick}
} else {
chatID = target
peer = bus.Peer{Kind: "group", ID: target}
}
sender := bus.SenderInfo{
Platform: "irc",
PlatformID: nick,
CanonicalID: identity.BuildCanonicalID("irc", nick),
Username: nick,
DisplayName: nick,
}
if !c.IsAllowedSender(sender) {
return
}
// For channel messages, check group trigger (mention detection)
if !isDM {
isMentioned := isBotMentioned(content, currentNick)
if isMentioned {
content = stripBotMention(content, currentNick)
}
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
if !respond {
return
}
content = cleaned
}
if strings.TrimSpace(content) == "" {
return
}
messageID := fmt.Sprintf("%s-%d", nick, time.Now().UnixNano())
metadata := map[string]string{
"platform": "irc",
"server": c.config.Server,
}
if !isDM {
metadata["channel"] = target
}
c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender)
}
// nickMentionedAt returns the byte index where botNick is mentioned in content
// with word-boundary checks, or -1 if not found. Also checks for "nick:" /
// "nick," prefix convention.
func nickMentionedAt(content, botNick string) int {
lower := strings.ToLower(content)
lowerNick := strings.ToLower(botNick)
// "nick:" or "nick," at start (most common IRC convention)
if strings.HasPrefix(lower, lowerNick+":") || strings.HasPrefix(lower, lowerNick+",") {
return 0
}
// Word-boundary match anywhere in the message
idx := strings.Index(lower, lowerNick)
if idx < 0 {
return -1
}
runes := []rune(lower)
nickRunes := []rune(lowerNick)
endIdx := idx + len(string(nickRunes))
before := idx == 0 || !unicode.IsLetter(runes[idx-1]) && !unicode.IsDigit(runes[idx-1])
after := endIdx >= len(lower) || !unicode.IsLetter(rune(lower[endIdx])) && !unicode.IsDigit(rune(lower[endIdx]))
if before && after {
return idx
}
return -1
}
// isBotMentioned checks if the bot's nick appears in the message.
func isBotMentioned(content, botNick string) bool {
return nickMentionedAt(content, botNick) >= 0
}
// stripBotMention removes "nick: " or "nick, " prefix from content.
func stripBotMention(content, botNick string) string {
idx := nickMentionedAt(content, botNick)
if idx != 0 {
return content
}
lowerNick := strings.ToLower(botNick)
lower := strings.ToLower(content)
for _, sep := range []string{":", ","} {
prefix := lowerNick + sep
if strings.HasPrefix(lower, prefix) {
return strings.TrimSpace(content[len(prefix):])
}
}
return content
}
+16
View File
@@ -0,0 +1,16 @@
package irc
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func init() {
channels.RegisterFactory("irc", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
if !cfg.Channels.IRC.Enabled {
return nil, nil
}
return NewIRCChannel(cfg.Channels.IRC, b)
})
}
+194
View File
@@ -0,0 +1,194 @@
package irc
import (
"context"
"crypto/tls"
"fmt"
"strings"
"github.com/ergochat/irc-go/ircevent"
"github.com/ergochat/irc-go/ircmsg"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// IRCChannel implements the Channel interface for IRC servers.
type IRCChannel struct {
*channels.BaseChannel
config config.IRCConfig
conn *ircevent.Connection
ctx context.Context
cancel context.CancelFunc
}
// NewIRCChannel creates a new IRC channel.
func NewIRCChannel(cfg config.IRCConfig, messageBus *bus.MessageBus) (*IRCChannel, error) {
if cfg.Server == "" {
return nil, fmt.Errorf("irc server is required")
}
if cfg.Nick == "" {
return nil, fmt.Errorf("irc nick is required")
}
base := channels.NewBaseChannel("irc", cfg, messageBus, cfg.AllowFrom,
channels.WithMaxMessageLength(400),
channels.WithGroupTrigger(cfg.GroupTrigger),
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
return &IRCChannel{
BaseChannel: base,
config: cfg,
}, nil
}
// Start connects to the IRC server and begins listening.
func (c *IRCChannel) Start(ctx context.Context) error {
logger.InfoC("irc", "Starting IRC channel")
c.ctx, c.cancel = context.WithCancel(ctx)
user := c.config.User
if user == "" {
user = c.config.Nick
}
realName := c.config.RealName
if realName == "" {
realName = c.config.Nick
}
caps := []string(c.config.RequestCaps)
if len(caps) == 0 {
caps = []string{"server-time", "message-tags"}
}
conn := &ircevent.Connection{
Server: c.config.Server,
Nick: c.config.Nick,
User: user,
RealName: realName,
Password: c.config.Password,
UseTLS: c.config.TLS,
RequestCaps: caps,
QuitMessage: "Goodbye",
Debug: false,
Log: nil,
}
if c.config.TLS {
conn.TLSConfig = &tls.Config{
ServerName: extractHost(c.config.Server),
}
}
// SASL auth (takes priority over NickServ)
if c.config.SASLUser != "" && c.config.SASLPassword != "" {
conn.SASLLogin = c.config.SASLUser
conn.SASLPassword = c.config.SASLPassword
}
// Register event handlers
conn.AddConnectCallback(func(e ircmsg.Message) {
c.onConnect(conn)
})
conn.AddCallback("PRIVMSG", func(e ircmsg.Message) {
c.onPrivmsg(conn, e)
})
if err := conn.Connect(); err != nil {
return fmt.Errorf("irc connect failed: %w", err)
}
c.conn = conn
// ircevent.Connection.Loop() handles reconnection internally.
go conn.Loop()
c.SetRunning(true)
logger.InfoCF("irc", "IRC channel started", map[string]any{
"server": c.config.Server,
"nick": c.config.Nick,
})
return nil
}
// Stop disconnects from the IRC server.
func (c *IRCChannel) Stop(ctx context.Context) error {
logger.InfoC("irc", "Stopping IRC channel")
c.SetRunning(false)
if c.conn != nil {
c.conn.Quit()
}
if c.cancel != nil {
c.cancel()
}
logger.InfoC("irc", "IRC channel stopped")
return nil
}
// Send sends a message to an IRC channel or user.
func (c *IRCChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
target := msg.ChatID
if target == "" {
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
}
if strings.TrimSpace(msg.Content) == "" {
return nil
}
// Send each line separately (IRC is line-oriented)
lines := strings.Split(msg.Content, "\n")
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if line == "" {
continue
}
c.conn.Privmsg(target, line)
}
logger.DebugCF("irc", "Message sent", map[string]any{
"target": target,
"lines": len(lines),
})
return nil
}
// StartTyping implements channels.TypingCapable using IRCv3 +typing client tag.
// Requires typing.enabled in config and server support for message-tags capability.
func (c *IRCChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
noop := func() {}
if !c.config.Typing.Enabled || !c.IsRunning() || c.conn == nil {
return noop, nil
}
// Check if server supports message-tags (required for TAGMSG)
if _, ok := c.conn.AcknowledgedCaps()["message-tags"]; !ok {
return noop, nil
}
c.conn.SendWithTags(map[string]string{"+typing": "active"}, "TAGMSG", chatID)
return func() {
if c.IsRunning() && c.conn != nil {
c.conn.SendWithTags(map[string]string{"+typing": "done"}, "TAGMSG", chatID)
}
}, nil
}
// extractHost returns the hostname portion of a host:port string.
func extractHost(server string) string {
host, _, found := strings.Cut(server, ":")
if found {
return host
}
return server
}
+145
View File
@@ -0,0 +1,145 @@
package irc
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestNewIRCChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing server", func(t *testing.T) {
cfg := config.IRCConfig{Nick: "bot"}
_, err := NewIRCChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing server, got nil")
}
})
t.Run("missing nick", func(t *testing.T) {
cfg := config.IRCConfig{Server: "irc.example.com:6667"}
_, err := NewIRCChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing nick, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.IRCConfig{
Server: "irc.example.com:6667",
Nick: "testbot",
Channels: []string{"#test"},
}
ch, err := NewIRCChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "irc" {
t.Errorf("Name() = %q, want %q", ch.Name(), "irc")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
}
func TestExtractHost(t *testing.T) {
tests := []struct {
server string
want string
}{
{"irc.libera.chat:6697", "irc.libera.chat"},
{"localhost:6667", "localhost"},
{"irc.example.com", "irc.example.com"},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.server, func(t *testing.T) {
got := extractHost(tt.server)
if got != tt.want {
t.Errorf("extractHost(%q) = %q, want %q", tt.server, got, tt.want)
}
})
}
}
func TestNickMentionedAt(t *testing.T) {
tests := []struct {
name string
content string
nick string
want int
}{
{"colon prefix", "bot: hello", "bot", 0},
{"comma prefix", "bot, hello", "bot", 0},
{"case insensitive", "BOT: hello", "bot", 0},
{"word boundary mid", "hey bot what's up", "bot", 4},
{"no mention", "hello world", "bot", -1},
{"substring mismatch", "robotics are cool", "bot", -1},
{"nick at end", "hello bot", "bot", 6},
{"empty content", "", "bot", -1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := nickMentionedAt(tt.content, tt.nick)
if got != tt.want {
t.Errorf("nickMentionedAt(%q, %q) = %d, want %d", tt.content, tt.nick, got, tt.want)
}
})
}
}
func TestIsBotMentioned(t *testing.T) {
tests := []struct {
name string
content string
nick string
want bool
}{
{"colon prefix", "bot: hello", "bot", true},
{"comma prefix", "bot, hello", "bot", true},
{"case insensitive", "BOT: hello", "bot", true},
{"word boundary mid", "hey bot what's up", "bot", true},
{"no mention", "hello world", "bot", false},
{"substring mismatch", "robotics are cool", "bot", false},
{"nick at end", "hello bot", "bot", true},
{"empty content", "", "bot", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isBotMentioned(tt.content, tt.nick)
if got != tt.want {
t.Errorf("isBotMentioned(%q, %q) = %v, want %v", tt.content, tt.nick, got, tt.want)
}
})
}
}
func TestStripBotMention(t *testing.T) {
tests := []struct {
name string
content string
nick string
want string
}{
{"colon prefix", "bot: hello there", "bot", "hello there"},
{"comma prefix", "bot, help me", "bot", "help me"},
{"case insensitive", "BOT: hello", "bot", "hello"},
{"no prefix match", "hello bot", "bot", "hello bot"},
{"only prefix", "bot:", "bot", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := stripBotMention(tt.content, tt.nick)
if got != tt.want {
t.Errorf("stripBotMention(%q, %q) = %q, want %q", tt.content, tt.nick, got, tt.want)
}
})
}
}
+5
View File
@@ -62,6 +62,7 @@ var channelRateConfig = map[string]float64{
"discord": 1,
"slack": 1,
"line": 10,
"irc": 2,
}
type channelWorker struct {
@@ -267,6 +268,10 @@ func (m *Manager) initChannels() error {
m.initChannel("pico", "Pico")
}
if m.config.Channels.IRC.Enabled && m.config.Channels.IRC.Server != "" {
m.initChannel("irc", "IRC")
}
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
"enabled_channels": len(m.channels),
})
@@ -0,0 +1,116 @@
package telegram
import (
"context"
"math/rand"
"slices"
"time"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/logger"
)
var commandRegistrationBackoff = []time.Duration{
5 * time.Second,
15 * time.Second,
60 * time.Second,
5 * time.Minute,
10 * time.Minute,
}
func commandRegistrationDelay(attempt int) time.Duration {
if len(commandRegistrationBackoff) == 0 {
return 0
}
base := commandRegistrationBackoff[min(attempt, len(commandRegistrationBackoff)-1)]
// Full jitter in [0.5, 1.0) to avoid synchronized retries across instances.
return time.Duration(float64(base) * (0.5 + rand.Float64()*0.5))
}
// RegisterCommands registers bot commands on Telegram platform.
func (c *TelegramChannel) RegisterCommands(ctx context.Context, defs []commands.Definition) error {
botCommands := make([]telego.BotCommand, 0, len(defs))
for _, def := range defs {
if def.Name == "" || def.Description == "" {
continue
}
botCommands = append(botCommands, telego.BotCommand{
Command: def.Name,
Description: def.Description,
})
}
current, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{})
if err != nil {
// If we can't read current commands, fall through to set them.
logger.WarnCF("telegram", "Failed to get current commands, will set unconditionally",
map[string]any{"error": err.Error()})
} else if slices.Equal(current, botCommands) {
logger.DebugCF("telegram", "Bot commands are up to date", nil)
return nil
}
return c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
Commands: botCommands,
})
}
func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []commands.Definition) {
if len(defs) == 0 {
return
}
register := c.registerFunc
if register == nil {
register = c.RegisterCommands
}
regCtx, cancel := context.WithCancel(ctx)
c.commandRegCancel = cancel
// Registration runs asynchronously so Telegram message intake is never blocked
// by temporary upstream API failures. Retry stops on success or channel shutdown.
go func() {
attempt := 0
timer := time.NewTimer(0)
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
defer timer.Stop()
for {
err := register(regCtx, defs)
if err == nil {
logger.InfoCF("telegram", "Telegram commands registered", map[string]any{
"count": len(defs),
})
return
}
delay := commandRegistrationDelay(attempt)
logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{
"error": err.Error(),
"retry_after": delay.String(),
})
attempt++
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(delay)
select {
case <-regCtx.Done():
return
case <-timer.C:
}
}
}()
}
@@ -0,0 +1,96 @@
package telegram
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/commands"
)
func TestStartCommandRegistration_DoesNotBlock(t *testing.T) {
ch := &TelegramChannel{}
started := make(chan struct{}, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch.registerFunc = func(context.Context, []commands.Definition) error {
started <- struct{}{}
return errors.New("temporary failure")
}
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help"}})
select {
case <-started:
case <-time.After(time.Second):
t.Fatal("registration did not start asynchronously")
}
}
func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) {
ch := &TelegramChannel{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
origBackoff := commandRegistrationBackoff
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
defer func() { commandRegistrationBackoff = origBackoff }()
var attempts atomic.Int32
ch.registerFunc = func(context.Context, []commands.Definition) error {
n := attempts.Add(1)
if n < 3 {
return errors.New("temporary failure")
}
return nil
}
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
deadline := time.Now().Add(250 * time.Millisecond)
for time.Now().Before(deadline) {
if attempts.Load() >= 3 {
break
}
time.Sleep(5 * time.Millisecond)
}
if attempts.Load() < 3 {
t.Fatalf("expected at least 3 attempts, got %d", attempts.Load())
}
stable := attempts.Load()
time.Sleep(30 * time.Millisecond)
if attempts.Load() != stable {
t.Fatalf("expected retries to stop after success, got %d -> %d", stable, attempts.Load())
}
}
func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) {
ch := &TelegramChannel{}
ctx, cancel := context.WithCancel(context.Background())
origBackoff := commandRegistrationBackoff
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
defer func() { commandRegistrationBackoff = origBackoff }()
defer cancel()
var attempts atomic.Int32
ch.registerFunc = func(context.Context, []commands.Definition) error {
attempts.Add(1)
return errors.New("always fail")
}
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
time.Sleep(20 * time.Millisecond)
cancel()
time.Sleep(20 * time.Millisecond) // allow in-flight attempt to settle
stable := attempts.Load()
time.Sleep(30 * time.Millisecond)
if attempts.Load() != stable {
t.Fatalf("expected retries to quiesce after cancel, got %d -> %d", stable, attempts.Load())
}
}
+115 -100
View File
@@ -7,7 +7,6 @@ import (
"net/url"
"os"
"regexp"
"slices"
"strconv"
"strings"
"time"
@@ -18,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -40,13 +40,15 @@ var (
type TelegramChannel struct {
*channels.BaseChannel
bot *telego.Bot
bh *th.BotHandler
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
ctx context.Context
cancel context.CancelFunc
bot *telego.Bot
bh *th.BotHandler
config *config.Config
chatIDs map[string]int64
ctx context.Context
cancel context.CancelFunc
registerFunc func(context.Context, []commands.Definition) error
commandRegCancel context.CancelFunc
}
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
@@ -86,14 +88,13 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
telegramCfg,
bus,
telegramCfg.AllowFrom,
channels.WithMaxMessageLength(4096),
channels.WithMaxMessageLength(4000),
channels.WithGroupTrigger(telegramCfg.GroupTrigger),
channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID),
)
return &TelegramChannel{
BaseChannel: base,
commands: NewTelegramCommands(bot, cfg),
bot: bot,
config: cfg,
chatIDs: make(map[string]int64),
@@ -105,12 +106,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
c.ctx, c.cancel = context.WithCancel(ctx)
if err := c.initBotCommands(c.ctx); err != nil {
logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
"error": err.Error(),
})
}
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
@@ -126,21 +121,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}
c.bh = bh
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Start(ctx, message)
}, th.CommandEqual("start"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Help(ctx, message)
}, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Show(ctx, message)
}, th.CommandEqual("show"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.List(ctx, message)
}, th.CommandEqual("list"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.handleMessage(ctx, &message)
}, th.AnyMessage())
@@ -150,6 +130,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
"username": c.bot.Username(),
})
c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions())
go func() {
if err = bh.Start(); err != nil {
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
@@ -174,50 +156,8 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
if c.cancel != nil {
c.cancel()
}
return nil
}
func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
Scope: tu.ScopeDefault(),
})
if err != nil {
return fmt.Errorf("get commands: %w", err)
}
commands := []telego.BotCommand{
{
Command: "start",
Description: "Start the bot",
},
{
Command: "help",
Description: "Show a help message",
},
{
Command: "show",
Description: "Show current configuration",
},
{
Command: "list",
Description: "List available options",
},
}
// Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
if !slices.Equal(currentCommands, commands) {
logger.InfoC("telegram", "Updating bot commands")
err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
Commands: commands,
Scope: tu.ScopeDefault(),
})
if err != nil {
return fmt.Errorf("set commands: %w", err)
}
} else {
logger.DebugC("telegram", "Bot commands are up to date")
if c.commandRegCancel != nil {
c.commandRegCancel()
}
return nil
@@ -233,22 +173,57 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
}
htmlContent := markdownToTelegramHTML(msg.Content)
if msg.Content == "" {
return nil
}
// Typing/placeholder handled by Manager.preSend — just send the message
// The Manager already splits messages to ≤4000 chars (WithMaxMessageLength),
// so msg.Content is guaranteed to be within that limit. We still need to
// check if HTML expansion pushes it beyond Telegram's 4096-char API limit.
queue := []string{msg.Content}
for len(queue) > 0 {
chunk := queue[0]
queue = queue[1:]
htmlContent := markdownToTelegramHTML(chunk)
if len([]rune(htmlContent)) > 4096 {
ratio := float64(len([]rune(chunk))) / float64(len([]rune(htmlContent)))
smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin
if smallerLen < 100 {
smallerLen = 100
}
// Push sub-chunks back to the front of the queue for
// re-validation instead of sending them blindly.
subChunks := channels.SplitMessage(chunk, smallerLen)
queue = append(subChunks, queue...)
continue
}
if err := c.sendHTMLChunk(ctx, chatID, htmlContent, chunk); err != nil {
return err
}
}
return nil
}
// sendHTMLChunk sends a single HTML message, falling back to the original
// markdown as plain text on parse failure so users never see raw HTML tags.
func (c *TelegramChannel) sendHTMLChunk(ctx context.Context, chatID int64, htmlContent, mdFallback string) error {
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
tgMsg.ParseMode = telego.ModeHTML
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil {
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
"error": err.Error(),
})
tgMsg.Text = mdFallback
tgMsg.ParseMode = ""
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
return fmt.Errorf("telegram send: %w", channels.ErrTemporary)
}
}
return nil
}
@@ -721,34 +696,34 @@ func escapeHTML(text string) string {
// isBotMentioned checks if the bot is mentioned in the message via entities.
func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
botUsername := c.bot.Username()
if botUsername == "" {
text, entities := telegramEntityTextAndList(message)
if text == "" || len(entities) == 0 {
return false
}
entities := message.Entities
if entities == nil {
entities = message.CaptionEntities
botUsername := ""
if c.bot != nil {
botUsername = c.bot.Username()
}
runes := []rune(text)
for _, entity := range entities {
if entity.Type == "mention" {
// Extract the mention text from the message
text := message.Text
if text == "" {
text = message.Caption
}
runes := []rune(text)
end := entity.Offset + entity.Length
if end <= len(runes) {
mention := string(runes[entity.Offset:end])
if strings.EqualFold(mention, "@"+botUsername) {
return true
}
}
entityText, ok := telegramEntityText(runes, entity)
if !ok {
continue
}
if entity.Type == "text_mention" && entity.User != nil {
if entity.User.Username == botUsername {
switch entity.Type {
case telego.EntityTypeMention:
if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) {
return true
}
case telego.EntityTypeTextMention:
if botUsername != "" && entity.User != nil && strings.EqualFold(entity.User.Username, botUsername) {
return true
}
case telego.EntityTypeBotCommand:
if isBotCommandEntityForThisBot(entityText, botUsername) {
return true
}
}
@@ -756,6 +731,46 @@ func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
return false
}
func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) {
if message.Text != "" {
return message.Text, message.Entities
}
return message.Caption, message.CaptionEntities
}
func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) {
if entity.Offset < 0 || entity.Length <= 0 {
return "", false
}
end := entity.Offset + entity.Length
if entity.Offset >= len(runes) || end > len(runes) {
return "", false
}
return string(runes[entity.Offset:end]), true
}
func isBotCommandEntityForThisBot(entityText, botUsername string) bool {
if !strings.HasPrefix(entityText, "/") {
return false
}
command := strings.TrimPrefix(entityText, "/")
if command == "" {
return false
}
at := strings.IndexRune(command, '@')
if at == -1 {
// A bare /command delivered to this bot is intended for this bot.
return true
}
mentionUsername := command[at+1:]
if mentionUsername == "" || botUsername == "" {
return false
}
return strings.EqualFold(mentionUsername, botUsername)
}
// stripBotMention removes the @bot mention from the content.
func (c *TelegramChannel) stripBotMention(content string) string {
botUsername := c.bot.Username()
-156
View File
@@ -1,156 +0,0 @@
package telegram
import (
"context"
"fmt"
"strings"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/config"
)
type TelegramCommander interface {
Help(ctx context.Context, message telego.Message) error
Start(ctx context.Context, message telego.Message) error
Show(ctx context.Context, message telego.Message) error
List(ctx context.Context, message telego.Message) error
}
type cmd struct {
bot *telego.Bot
config *config.Config
}
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
return &cmd{
bot: bot,
config: cfg,
}
}
func commandArgs(text string) string {
parts := strings.SplitN(text, " ", 2)
if len(parts) < 2 {
return ""
}
return strings.TrimSpace(parts[1])
}
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
msg := `/start - Start the bot
/help - Show this help message
/show [model|channel] - Show current configuration
/list [models|channels] - List available options
`
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: msg,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Hello! I am PicoClaw 🦞",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /show [model|channel]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "model":
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
c.config.Agents.Defaults.GetModelName(),
c.config.Agents.Defaults.Provider)
case "channel":
response = "Current Channel: telegram"
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) List(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /list [models|channels]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "models":
provider := c.config.Agents.Defaults.Provider
if provider == "" {
provider = "configured default"
}
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
c.config.Agents.Defaults.GetModelName(), provider)
case "channels":
var enabled []string
if c.config.Channels.Telegram.Enabled {
enabled = append(enabled, "telegram")
}
if c.config.Channels.WhatsApp.Enabled {
enabled = append(enabled, "whatsapp")
}
if c.config.Channels.Feishu.Enabled {
enabled = append(enabled, "feishu")
}
if c.config.Channels.Discord.Enabled {
enabled = append(enabled, "discord")
}
if c.config.Channels.Slack.Enabled {
enabled = append(enabled, "slack")
}
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
@@ -0,0 +1,52 @@
package telegram
import (
"context"
"testing"
"time"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
)
func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
msg := &telego.Message{
Text: "/new",
MessageID: 9,
Chat: telego.Chat{
ID: 123,
Type: "private",
},
From: &telego.User{
ID: 42,
FirstName: "Alice",
},
}
if err := ch.handleMessage(context.Background(), msg); err != nil {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "telegram" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
@@ -0,0 +1,147 @@
package telegram
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
type getMeCaller struct {
username string
}
func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) {
if strings.HasSuffix(url, "/getMe") {
result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username)
return &ta.Response{Ok: true, Result: []byte(result)}, nil
}
return &ta.Response{Ok: true, Result: []byte("true")}, nil
}
func newTestTelegramBot(t *testing.T, username string) *telego.Bot {
t.Helper()
token := "123456:" + strings.Repeat("a", 35)
bot, err := telego.NewBot(token,
telego.WithAPICaller(getMeCaller{username: username}),
telego.WithDiscardLogger(),
)
if err != nil {
t.Fatalf("NewBot error: %v", err)
}
return bot
}
func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) {
t.Helper()
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil,
channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}),
),
bot: newTestTelegramBot(t, botUsername),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
return ch, messageBus
}
func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
tests := []struct {
name string
text string
wantForwarded bool
wantContent string
}{
{
name: "command with bot username",
text: "/new@testbot",
wantForwarded: true,
wantContent: "/new",
},
{
name: "bare command",
text: "/new",
wantForwarded: true,
wantContent: "/new",
},
{
name: "command for another bot",
text: "/new@otherbot",
wantForwarded: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ch, messageBus := newGroupMentionOnlyChannel(t, "testbot")
msg := &telego.Message{
Text: tc.text,
Entities: []telego.MessageEntity{{
Type: telego.EntityTypeBotCommand,
Offset: 0,
Length: len([]rune(tc.text)),
}},
MessageID: 42,
Chat: telego.Chat{
ID: 123,
Type: "group",
},
From: &telego.User{
ID: 7,
FirstName: "Alice",
},
}
if err := ch.handleMessage(context.Background(), msg); err != nil {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
}
return
}
if ok {
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
}
})
}
}
func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) {
ch, _ := newGroupMentionOnlyChannel(t, "testbot")
msg := &telego.Message{
Text: "@testbot hello",
Entities: []telego.MessageEntity{{
Type: telego.EntityTypeMention,
Offset: 0,
Length: len("@testbot"),
}},
}
if !ch.isBotMentioned(msg) {
t.Fatal("expected mention entity to be treated as bot mention")
}
}
+273
View File
@@ -0,0 +1,273 @@
package telegram
import (
"context"
"encoding/json"
"errors"
"strings"
"testing"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
)
const testToken = "1234567890:aaaabbbbaaaabbbbaaaabbbbaaaabbbbccc"
// stubCaller implements ta.Caller for testing.
type stubCaller struct {
calls []stubCall
callFn func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error)
}
type stubCall struct {
URL string
Data *ta.RequestData
}
func (s *stubCaller) Call(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
s.calls = append(s.calls, stubCall{URL: url, Data: data})
return s.callFn(ctx, url, data)
}
// stubConstructor implements ta.RequestConstructor for testing.
type stubConstructor struct{}
func (s *stubConstructor) JSONRequest(parameters any) (*ta.RequestData, error) {
return &ta.RequestData{}, nil
}
func (s *stubConstructor) MultipartRequest(
parameters map[string]string,
files map[string]ta.NamedReader,
) (*ta.RequestData, error) {
return &ta.RequestData{}, nil
}
// successResponse returns a ta.Response that telego will treat as a successful SendMessage.
func successResponse(t *testing.T) *ta.Response {
t.Helper()
msg := &telego.Message{MessageID: 1}
b, err := json.Marshal(msg)
require.NoError(t, err)
return &ta.Response{Ok: true, Result: b}
}
// newTestChannel creates a TelegramChannel with a mocked bot for unit testing.
func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel {
t.Helper()
bot, err := telego.NewBot(testToken,
telego.WithAPICaller(caller),
telego.WithRequestConstructor(&stubConstructor{}),
telego.WithDiscardLogger(),
)
require.NoError(t, err)
base := channels.NewBaseChannel("telegram", nil, nil, nil,
channels.WithMaxMessageLength(4000),
)
base.SetRunning(true)
return &TelegramChannel{
BaseChannel: base,
bot: bot,
chatIDs: make(map[string]int64),
}
}
func TestSend_EmptyContent(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
t.Fatal("SendMessage should not be called for empty content")
return nil, nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "",
})
assert.NoError(t, err)
assert.Empty(t, caller.calls, "no API calls should be made for empty content")
}
func TestSend_ShortMessage_SingleCall(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello, world!",
})
assert.NoError(t, err)
assert.Len(t, caller.calls, 1, "short message should result in exactly one SendMessage call")
}
func TestSend_LongMessage_SingleCall(t *testing.T) {
// With WithMaxMessageLength(4000), the Manager pre-splits messages before
// they reach Send(). A message at exactly 4000 chars should go through
// as a single SendMessage call (no re-split needed since HTML expansion
// won't exceed 4096 for plain text).
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
longContent := strings.Repeat("a", 4000)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: longContent,
})
assert.NoError(t, err)
assert.Len(t, caller.calls, 1, "pre-split message within limit should result in one SendMessage call")
}
func TestSend_HTMLFallback_PerChunk(t *testing.T) {
callCount := 0
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
callCount++
// Fail on odd calls (HTML attempt), succeed on even calls (plain text fallback)
if callCount%2 == 1 {
return nil, errors.New("Bad Request: can't parse entities")
}
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello **world**",
})
assert.NoError(t, err)
// One short message → 1 HTML attempt (fail) + 1 plain text fallback (success) = 2 calls
assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text fallback")
}
func TestSend_HTMLFallback_BothFail(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return nil, errors.New("send failed")
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello",
})
assert.Error(t, err)
assert.True(t, errors.Is(err, channels.ErrTemporary), "error should wrap ErrTemporary")
assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text attempt")
}
func TestSend_LongMessage_HTMLFallback_StopsOnError(t *testing.T) {
// With a long message that gets split into 2 chunks, if both HTML and
// plain text fail on the first chunk, Send should return early.
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return nil, errors.New("send failed")
},
}
ch := newTestChannel(t, caller)
longContent := strings.Repeat("x", 4001)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: longContent,
})
assert.Error(t, err)
// Should fail on the first chunk (2 calls: HTML + fallback), never reaching the second chunk.
assert.Equal(t, 2, len(caller.calls), "should stop after first chunk fails both HTML and plain text")
}
func TestSend_MarkdownShortButHTMLLong_MultipleCalls(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
return successResponse(t), nil
},
}
ch := newTestChannel(t, caller)
// Create markdown whose length is <= 4000 but whose HTML expansion is much longer.
// "**a** " (6 chars) becomes "<b>a</b> " (9 chars) in HTML, so repeating it many times
// yields HTML that exceeds Telegram's limit while markdown stays within it.
markdownContent := strings.Repeat("**a** ", 600) // 3600 chars markdown, HTML ~5400+ chars
assert.LessOrEqual(t, len([]rune(markdownContent)), 4000, "markdown content must not exceed chunk size")
htmlExpanded := markdownToTelegramHTML(markdownContent)
assert.Greater(
t, len([]rune(htmlExpanded)), 4096,
"HTML expansion must exceed Telegram limit for this test to be meaningful",
)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: markdownContent,
})
assert.NoError(t, err)
assert.Greater(
t, len(caller.calls), 1,
"markdown-short but HTML-long message should be split into multiple SendMessage calls",
)
}
func TestSend_NotRunning(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
t.Fatal("should not be called")
return nil, nil
},
}
ch := newTestChannel(t, caller)
ch.SetRunning(false)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "12345",
Content: "Hello",
})
assert.ErrorIs(t, err, channels.ErrNotRunning)
assert.Empty(t, caller.calls)
}
func TestSend_InvalidChatID(t *testing.T) {
caller := &stubCaller{
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
t.Fatal("should not be called")
return nil, nil
},
}
ch := newTestChannel(t, caller)
err := ch.Send(context.Background(), bus.OutboundMessage{
ChatID: "not-a-number",
Content: "Hello",
})
assert.Error(t, err)
assert.True(t, errors.Is(err, channels.ErrSendFailed), "error should wrap ErrSendFailed")
assert.Empty(t, caller.calls)
}
@@ -0,0 +1,41 @@
package whatsapp
import (
"context"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &WhatsAppChannel{
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil),
ctx: context.Background(),
}
ch.handleIncomingMessage(map[string]any{
"type": "message",
"id": "mid1",
"from": "user1",
"chat": "chat1",
"content": "/help",
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/help" {
t.Fatalf("content=%q", inbound.Content)
}
}
@@ -0,0 +1,56 @@
//go:build whatsapp_native
package whatsapp
import (
"context"
"testing"
"time"
"go.mau.fi/whatsmeow/proto/waE2E"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow/types/events"
"google.golang.org/protobuf/proto"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &WhatsAppNativeChannel{
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil),
runCtx: context.Background(),
}
evt := &events.Message{
Info: types.MessageInfo{
MessageSource: types.MessageSource{
Sender: types.NewJID("1001", types.DefaultUserServer),
Chat: types.NewJID("1001", types.DefaultUserServer),
},
ID: "mid1",
PushName: "Alice",
},
Message: &waE2E.Message{
Conversation: proto.String("/new"),
},
}
ch.handleIncoming(evt)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp_native" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
+16
View File
@@ -0,0 +1,16 @@
package commands
// BuiltinDefinitions returns all built-in command definitions.
// Each command group is defined in its own cmd_*.go file.
// Definitions are stateless — runtime dependencies are provided
// via the Runtime parameter passed to handlers at execution time.
func BuiltinDefinitions() []Definition {
return []Definition{
startCommand(),
helpCommand(),
showCommand(),
listCommand(),
switchCommand(),
checkCommand(),
}
}
+145
View File
@@ -0,0 +1,145 @@
package commands
import (
"context"
"strings"
"testing"
)
func findDefinitionByName(t *testing.T, defs []Definition, name string) Definition {
t.Helper()
for _, def := range defs {
if def.Name == name {
return def
}
}
t.Fatalf("missing /%s definition", name)
return Definition{}
}
func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
defs := BuiltinDefinitions()
helpDef := findDefinitionByName(t, defs, "help")
if helpDef.Handler == nil {
t.Fatalf("/help handler should not be nil")
}
var reply string
err := helpDef.Handler(context.Background(), Request{
Text: "/help",
Reply: func(text string) error {
reply = text
return nil
},
}, nil)
if err != nil {
t.Fatalf("/help handler error: %v", err)
}
// Now uses auto-generated EffectiveUsage which includes agents
if !strings.Contains(reply, "/show [model|channel|agents]") {
t.Fatalf("/help reply missing /show usage, got %q", reply)
}
if !strings.Contains(reply, "/list [models|channels|agents]") {
t.Fatalf("/help reply missing /list usage, got %q", reply)
}
}
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), nil)
cases := []string{"telegram", "whatsapp"}
for _, channel := range cases {
var reply string
res := ex.Execute(context.Background(), Request{
Channel: channel,
Text: "/show channel",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/show channel on %s: outcome=%v, want=%v", channel, res.Outcome, OutcomeHandled)
}
want := "Current Channel: " + channel
if reply != want {
t.Fatalf("/show channel reply=%q, want=%q", reply, want)
}
}
}
func TestBuiltinListChannels_UsesGetEnabledChannels(t *testing.T) {
rt := &Runtime{
GetEnabledChannels: func() []string {
return []string{"telegram", "slack"}
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/list channels",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/list channels: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "telegram") || !strings.Contains(reply, "slack") {
t.Fatalf("/list channels reply=%q, want telegram and slack", reply)
}
}
func TestBuiltinShowAgents_RestoresOldBehavior(t *testing.T) {
rt := &Runtime{
ListAgentIDs: func() []string {
return []string{"default", "coder"}
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/show agents",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/show agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
t.Fatalf("/show agents reply=%q, want agent IDs", reply)
}
}
func TestBuiltinListAgents_RestoresOldBehavior(t *testing.T) {
rt := &Runtime{
ListAgentIDs: func() []string {
return []string{"default", "coder"}
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/list agents",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/list agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
t.Fatalf("/list agents reply=%q, want agent IDs", reply)
}
}
+33
View File
@@ -0,0 +1,33 @@
package commands
import (
"context"
"fmt"
)
func checkCommand() Definition {
return Definition{
Name: "check",
Description: "Check channel availability",
SubCommands: []SubCommand{
{
Name: "channel",
Description: "Check if a channel is available",
ArgsUsage: "<name>",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.SwitchChannel == nil {
return req.Reply(unavailableMsg)
}
value := nthToken(req.Text, 2)
if value == "" {
return req.Reply("Usage: /check channel <name>")
}
if err := rt.SwitchChannel(value); err != nil {
return req.Reply(err.Error())
}
return req.Reply(fmt.Sprintf("Channel '%s' is available and enabled", value))
},
},
},
}
}
+44
View File
@@ -0,0 +1,44 @@
package commands
import (
"context"
"fmt"
"strings"
)
func helpCommand() Definition {
return Definition{
Name: "help",
Description: "Show this help message",
Usage: "/help",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
var defs []Definition
if rt != nil && rt.ListDefinitions != nil {
defs = rt.ListDefinitions()
} else {
defs = BuiltinDefinitions()
}
return req.Reply(formatHelpMessage(defs))
},
}
}
func formatHelpMessage(defs []Definition) string {
if len(defs) == 0 {
return "No commands available."
}
lines := make([]string, 0, len(defs))
for _, def := range defs {
usage := def.EffectiveUsage()
if usage == "" {
usage = "/" + def.Name
}
desc := def.Description
if desc == "" {
desc = "No description"
}
lines = append(lines, fmt.Sprintf("%s - %s", usage, desc))
}
return strings.Join(lines, "\n")
}
+52
View File
@@ -0,0 +1,52 @@
package commands
import (
"context"
"fmt"
"strings"
)
func listCommand() Definition {
return Definition{
Name: "list",
Description: "List available options",
SubCommands: []SubCommand{
{
Name: "models",
Description: "Configured models",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.GetModelInfo == nil {
return req.Reply(unavailableMsg)
}
name, provider := rt.GetModelInfo()
if provider == "" {
provider = "configured default"
}
return req.Reply(fmt.Sprintf(
"Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
name, provider,
))
},
},
{
Name: "channels",
Description: "Enabled channels",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.GetEnabledChannels == nil {
return req.Reply(unavailableMsg)
}
enabled := rt.GetEnabledChannels()
if len(enabled) == 0 {
return req.Reply("No channels enabled")
}
return req.Reply(fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")))
},
},
{
Name: "agents",
Description: "Registered agents",
Handler: agentsHandler(),
},
},
}
}
+38
View File
@@ -0,0 +1,38 @@
package commands
import (
"context"
"fmt"
)
func showCommand() Definition {
return Definition{
Name: "show",
Description: "Show current configuration",
SubCommands: []SubCommand{
{
Name: "model",
Description: "Current model and provider",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.GetModelInfo == nil {
return req.Reply(unavailableMsg)
}
name, provider := rt.GetModelInfo()
return req.Reply(fmt.Sprintf("Current Model: %s (Provider: %s)", name, provider))
},
},
{
Name: "channel",
Description: "Current channel",
Handler: func(_ context.Context, req Request, _ *Runtime) error {
return req.Reply(fmt.Sprintf("Current Channel: %s", req.Channel))
},
},
{
Name: "agents",
Description: "Registered agents",
Handler: agentsHandler(),
},
},
}
}
+14
View File
@@ -0,0 +1,14 @@
package commands
import "context"
func startCommand() Definition {
return Definition{
Name: "start",
Description: "Start the bot",
Usage: "/start",
Handler: func(_ context.Context, req Request, _ *Runtime) error {
return req.Reply("Hello! I am PicoClaw 🦞")
},
}
}
+42
View File
@@ -0,0 +1,42 @@
package commands
import (
"context"
"fmt"
)
func switchCommand() Definition {
return Definition{
Name: "switch",
Description: "Switch model",
SubCommands: []SubCommand{
{
Name: "model",
Description: "Switch to a different model",
ArgsUsage: "to <name>",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.SwitchModel == nil {
return req.Reply(unavailableMsg)
}
// Parse: /switch model to <value>
value := nthToken(req.Text, 3) // tokens: [/switch, model, to, <value>]
if nthToken(req.Text, 2) != "to" || value == "" {
return req.Reply("Usage: /switch model to <name>")
}
oldModel, err := rt.SwitchModel(value)
if err != nil {
return req.Reply(err.Error())
}
return req.Reply(fmt.Sprintf("Switched model from %s to %s", oldModel, value))
},
},
{
Name: "channel",
Description: "Moved to /check channel",
Handler: func(_ context.Context, req Request, _ *Runtime) error {
return req.Reply("This command has moved. Please use: /check channel <name>")
},
},
},
}
}
+279
View File
@@ -0,0 +1,279 @@
package commands
import (
"context"
"fmt"
"testing"
)
func TestSwitchModel_Success(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old-model", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
want := "Switched model from old-model to gpt-4"
if reply != want {
t.Fatalf("reply=%q, want=%q", reply, want)
}
}
func TestSwitchModel_MissingToKeyword(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /switch model to <name>" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestSwitchModel_MissingValue(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /switch model to <name>" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestSwitchModel_Error(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "", fmt.Errorf("model not found")
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to bad-model",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "model not found" {
t.Fatalf("reply=%q, want error message", reply)
}
}
func TestSwitchModel_NilDep(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Command unavailable in current context." {
t.Fatalf("reply=%q, want unavailable message", reply)
}
}
func TestSwitchChannel_Redirect(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch channel to telegram",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
want := "This command has moved. Please use: /check channel <name>"
if reply != want {
t.Fatalf("reply=%q, want=%q", reply, want)
}
}
func TestCheckChannel_Success(t *testing.T) {
rt := &Runtime{
SwitchChannel: func(value string) error {
return nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel telegram",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
want := "Channel 'telegram' is available and enabled"
if reply != want {
t.Fatalf("reply=%q, want=%q", reply, want)
}
}
func TestCheckChannel_Error(t *testing.T) {
rt := &Runtime{
SwitchChannel: func(value string) error {
return fmt.Errorf("channel '%s' not found", value)
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel unknown",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "channel 'unknown' not found" {
t.Fatalf("reply=%q, want error message", reply)
}
}
func TestCheckChannel_NilDep(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel telegram",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Command unavailable in current context." {
t.Fatalf("reply=%q, want unavailable message", reply)
}
}
func TestCheckChannel_MissingValue(t *testing.T) {
rt := &Runtime{
SwitchChannel: func(value string) error {
return nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /check channel <name>" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestSwitch_BangPrefix(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "!switch model to gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("! prefix: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Switched model from old to gpt-4" {
t.Fatalf("! prefix: reply=%q, want success message", reply)
}
}
func TestSwitch_NoSubCommand(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
// Should get usage message from executor's sub-command routing
if reply == "" {
t.Fatal("expected usage reply for bare /switch")
}
}
+48
View File
@@ -0,0 +1,48 @@
package commands
import (
"fmt"
"strings"
)
// SubCommand defines a single sub-command within a parent command.
type SubCommand struct {
Name string
Description string
ArgsUsage string // optional, e.g. "<session-id>"
Handler Handler
}
// Definition is the single-source metadata and behavior contract for a slash command.
//
// Design notes (phase 1):
// - Every channel reads command shape from this type instead of keeping local copies.
// - Visibility is global: all definitions are considered available to all channels.
// - Platform menu registration (for example Telegram BotCommand) also derives from this
// same definition so UI labels and runtime behavior stay aligned.
type Definition struct {
Name string
Description string
Usage string // for simple commands; ignored when SubCommands is set
Aliases []string
SubCommands []SubCommand // optional; when set, Executor routes to sub-command handlers
Handler Handler // for simple commands without sub-commands
}
// EffectiveUsage returns the usage string. When SubCommands are present,
// it is auto-generated from sub-command names so metadata and behavior
// cannot drift.
func (d Definition) EffectiveUsage() string {
if len(d.SubCommands) == 0 {
return d.Usage
}
names := make([]string, 0, len(d.SubCommands))
for _, sc := range d.SubCommands {
name := sc.Name
if sc.ArgsUsage != "" {
name += " " + sc.ArgsUsage
}
names = append(names, name)
}
return fmt.Sprintf("/%s [%s]", d.Name, strings.Join(names, "|"))
}
+41
View File
@@ -0,0 +1,41 @@
package commands
import (
"testing"
)
func TestDefinition_EffectiveUsage_NoSubCommands(t *testing.T) {
d := Definition{Name: "start", Usage: "/start"}
if got := d.EffectiveUsage(); got != "/start" {
t.Fatalf("EffectiveUsage()=%q, want %q", got, "/start")
}
}
func TestDefinition_EffectiveUsage_WithSubCommands(t *testing.T) {
d := Definition{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"},
{Name: "channel"},
{Name: "agents"},
},
}
want := "/show [model|channel|agents]"
if got := d.EffectiveUsage(); got != want {
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
}
}
func TestDefinition_EffectiveUsage_WithArgsUsage(t *testing.T) {
d := Definition{
Name: "session",
SubCommands: []SubCommand{
{Name: "list"},
{Name: "resume", ArgsUsage: "<id>"},
},
}
want := "/session [list|resume <id>]"
if got := d.EffectiveUsage(); got != want {
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
}
}
+89
View File
@@ -0,0 +1,89 @@
package commands
import (
"context"
"fmt"
)
type Outcome int
const (
// OutcomePassthrough means this input should continue through normal agent flow.
OutcomePassthrough Outcome = iota
// OutcomeHandled means a command handler executed (with or without handler error).
OutcomeHandled
)
type ExecuteResult struct {
Outcome Outcome
Command string
Err error
}
type Executor struct {
reg *Registry
rt *Runtime
}
func NewExecutor(reg *Registry, rt *Runtime) *Executor {
return &Executor{reg: reg, rt: rt}
}
// Execute implements a two-state command decision:
// 1) handled: execute command immediately;
// 2) passthrough: not a command or intentionally deferred to agent logic.
func (e *Executor) Execute(ctx context.Context, req Request) ExecuteResult {
cmdName, ok := parseCommandName(req.Text)
if !ok {
return ExecuteResult{Outcome: OutcomePassthrough}
}
if e == nil || e.reg == nil {
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
}
def, found := e.reg.Lookup(cmdName)
if !found {
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
}
return e.executeDefinition(ctx, req, def)
}
func (e *Executor) executeDefinition(ctx context.Context, req Request, def Definition) ExecuteResult {
// Ensure Reply is always non-nil so handlers don't need to check.
if req.Reply == nil {
req.Reply = func(string) error { return nil }
}
// Simple command — no sub-commands
if len(def.SubCommands) == 0 {
if def.Handler == nil {
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
}
err := def.Handler(ctx, req, e.rt)
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
// Sub-command routing
subName := nthToken(req.Text, 1)
if subName == "" {
err := req.Reply("Usage: " + def.EffectiveUsage())
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
normalized := normalizeCommandName(subName)
for _, sc := range def.SubCommands {
if normalizeCommandName(sc.Name) == normalized {
if sc.Handler == nil {
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
}
err := sc.Handler(ctx, req, e.rt)
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
}
// Unknown sub-command
err := req.Reply(fmt.Sprintf("Unknown option: %s. Usage: %s", subName, def.EffectiveUsage()))
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
+260
View File
@@ -0,0 +1,260 @@
package commands
import (
"context"
"errors"
"strings"
"testing"
)
func TestExecutor_RegisteredWithoutHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{{Name: "show"}}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/show"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
}
func TestExecutor_UnknownSlashCommand_ReturnsPassthrough(t *testing.T) {
defs := []Definition{{Name: "show"}}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/unknown"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
}
func TestExecutor_SupportedCommandWithHandler_ReturnsHandled(t *testing.T) {
called := false
defs := []Definition{
{
Name: "help",
Handler: func(context.Context, Request, *Runtime) error {
called = true
return nil
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help@my_bot"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !called {
t.Fatalf("expected handler to be called")
}
}
func TestExecutor_AliasWithoutHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{
{
Name: "show",
Aliases: []string{"display"},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/display"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
if res.Command != "show" {
t.Fatalf("command=%q, want=%q", res.Command, "show")
}
}
func TestExecutor_AliasWithHandler_ReturnsHandled(t *testing.T) {
called := false
defs := []Definition{
{
Name: "clear",
Aliases: []string{"reset"},
Handler: func(context.Context, Request, *Runtime) error {
called = true
return nil
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/reset"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if res.Command != "clear" {
t.Fatalf("command=%q, want=%q", res.Command, "clear")
}
if !called {
t.Fatalf("expected handler to be called")
}
}
func TestExecutor_SupportedCommandWithNilHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{
{Name: "placeholder"},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder list"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
if res.Command != "placeholder" {
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
}
}
func TestExecutor_NilHandlerDoesNotMaskLaterHandler(t *testing.T) {
// With Lookup-based dispatch, the first registered definition for a name wins.
// A definition with nil Handler and no SubCommands returns Passthrough.
defs := []Definition{
{Name: "placeholder"},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
if res.Command != "placeholder" {
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
}
}
func TestExecutor_HandlerErrorIsPropagated(t *testing.T) {
wantErr := errors.New("handler failed")
defs := []Definition{
{
Name: "help",
Handler: func(context.Context, Request, *Runtime) error {
return wantErr
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !errors.Is(res.Err, wantErr) {
t.Fatalf("err=%v, want=%v", res.Err, wantErr)
}
}
func TestExecutor_SupportsBangPrefixAndCaseInsensitiveCommand(t *testing.T) {
called := false
defs := []Definition{
{
Name: "help",
Handler: func(context.Context, Request, *Runtime) error {
called = true
return nil
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "!HELP"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !called {
t.Fatalf("expected handler to be called")
}
}
func TestExecutor_SubCommand_RoutesToCorrectHandler(t *testing.T) {
modelCalled := false
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model", Handler: func(_ context.Context, _ Request, _ *Runtime) error {
modelCalled = true
return nil
}},
{Name: "channel"},
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Text: "/show model"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !modelCalled {
t.Fatal("model sub-command handler was not called")
}
}
func TestExecutor_SubCommand_NoArg_RepliesUsage(t *testing.T) {
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"},
{Name: "channel"},
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/show",
Reply: func(text string) error { reply = text; return nil },
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /show [model|channel]" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestExecutor_SubCommand_UnknownArg_RepliesError(t *testing.T) {
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"},
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/show foobar",
Reply: func(text string) error { reply = text; return nil },
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "foobar") {
t.Fatalf("reply=%q, should mention unknown sub-command", reply)
}
}
func TestExecutor_SubCommand_NilHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"}, // nil Handler
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Text: "/show model"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
}
+21
View File
@@ -0,0 +1,21 @@
package commands
import (
"context"
"fmt"
"strings"
)
// agentsHandler returns a shared handler for both /show agents and /list agents.
func agentsHandler() Handler {
return func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.ListAgentIDs == nil {
return req.Reply(unavailableMsg)
}
ids := rt.ListAgentIDs()
if len(ids) == 0 {
return req.Reply("No agents registered")
}
return req.Reply(fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", ")))
}
}
+55
View File
@@ -0,0 +1,55 @@
package commands
type Registry struct {
defs []Definition
index map[string]int
}
// NewRegistry stores the canonical command set used by both dispatch and
// optional platform registration adapters.
func NewRegistry(defs []Definition) *Registry {
stored := make([]Definition, len(defs))
copy(stored, defs)
index := make(map[string]int, len(stored)*2)
for i, def := range stored {
registerCommandName(index, def.Name, i)
for _, alias := range def.Aliases {
registerCommandName(index, alias, i)
}
}
return &Registry{defs: stored, index: index}
}
// Definitions returns all registered command definitions.
// Command availability is global and no longer channel-scoped.
func (r *Registry) Definitions() []Definition {
out := make([]Definition, len(r.defs))
copy(out, r.defs)
return out
}
// Lookup returns a command definition by normalized command name or alias.
func (r *Registry) Lookup(name string) (Definition, bool) {
key := normalizeCommandName(name)
if key == "" {
return Definition{}, false
}
idx, ok := r.index[key]
if !ok {
return Definition{}, false
}
return r.defs[idx], true
}
func registerCommandName(index map[string]int, name string, defIndex int) {
key := normalizeCommandName(name)
if key == "" {
return
}
if _, exists := index[key]; exists {
return
}
index[key] = defIndex
}
+49
View File
@@ -0,0 +1,49 @@
package commands
import "testing"
func TestRegistry_Definitions_ReturnsCopy(t *testing.T) {
defs := []Definition{
{Name: "help", Description: "Show help"},
{Name: "admin", Description: "Admin command"},
}
r := NewRegistry(defs)
got := r.Definitions()
if len(got) != 2 {
t.Fatalf("definitions len = %d, want 2", len(got))
}
got[0].Name = "mutated"
again := r.Definitions()
if again[0].Name != "help" {
t.Fatalf("registry should not be mutated by caller, got first name %q", again[0].Name)
}
}
func TestRegistry_Lookup_MatchesByLowercaseNameAndAlias(t *testing.T) {
r := NewRegistry([]Definition{
{Name: "Help", Aliases: []string{"Assist"}},
{Name: "List"},
})
def, ok := r.Lookup("help")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by lowercase name failed: ok=%v def=%+v", ok, def)
}
def, ok = r.Lookup("HELP")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by uppercase name failed: ok=%v def=%+v", ok, def)
}
def, ok = r.Lookup("assist")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by lowercase alias failed: ok=%v def=%+v", ok, def)
}
def, ok = r.Lookup("ASSIST")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by uppercase alias failed: ok=%v def=%+v", ok, def)
}
}
+75
View File
@@ -0,0 +1,75 @@
package commands
import (
"context"
"strings"
)
type Handler func(ctx context.Context, req Request, rt *Runtime) error
type Request struct {
Channel string
ChatID string
SenderID string
Text string
Reply func(text string) error
}
const unavailableMsg = "Command unavailable in current context."
var commandPrefixes = []string{"/", "!"}
// parseCommandName accepts "/name", "!name", and Telegram's "/name@bot", then
// normalizes to lowercase command names.
func parseCommandName(input string) (string, bool) {
token := nthToken(input, 0)
if token == "" {
return "", false
}
name, ok := trimCommandPrefix(token)
if !ok {
return "", false
}
if i := strings.Index(name, "@"); i >= 0 {
name = name[:i]
}
name = normalizeCommandName(name)
if name == "" {
return "", false
}
return name, true
}
func trimCommandPrefix(token string) (string, bool) {
for _, prefix := range commandPrefixes {
if strings.HasPrefix(token, prefix) {
return strings.TrimPrefix(token, prefix), true
}
}
return "", false
}
// HasCommandPrefix returns true if the input starts with a recognized
// command prefix (e.g. "/" or "!").
func HasCommandPrefix(input string) bool {
token := nthToken(input, 0)
if token == "" {
return false
}
_, ok := trimCommandPrefix(token)
return ok
}
// nthToken returns the 0-indexed token from whitespace-split input.
func nthToken(input string, n int) string {
parts := strings.Fields(strings.TrimSpace(input))
if n >= len(parts) {
return ""
}
return parts[n]
}
func normalizeCommandName(name string) string {
return strings.ToLower(strings.TrimSpace(name))
}
+28
View File
@@ -0,0 +1,28 @@
package commands
import "testing"
func TestHasCommandPrefix(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"/help", true},
{"!help", true},
{"/switch model to gpt-4", true},
{"!switch model to gpt-4", true},
{"hello", false},
{"", false},
{" ", false},
{"hello /world", false},
{"/", true},
{"!", true},
{" /help", true},
}
for _, tt := range tests {
got := HasCommandPrefix(tt.input)
if got != tt.want {
t.Errorf("HasCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
+16
View File
@@ -0,0 +1,16 @@
package commands
import "github.com/sipeed/picoclaw/pkg/config"
// Runtime provides runtime dependencies to command handlers. It is constructed
// per-request by the agent loop so that per-request state (like session scope)
// can coexist with long-lived callbacks (like GetModelInfo).
type Runtime struct {
Config *config.Config
GetModelInfo func() (name, provider string)
ListAgentIDs func() []string
ListDefinitions func() []Definition
GetEnabledChannels func() []string
SwitchModel func(value string) (oldModel string, err error)
SwitchChannel func(value string) error
}
+85
View File
@@ -0,0 +1,85 @@
package commands
import (
"context"
"strings"
"testing"
)
func TestShowListHandlers_ChannelPolicy(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), nil)
var telegramReply string
handled := ex.Execute(context.Background(), Request{
Channel: "telegram",
Text: "/show channel",
Reply: func(text string) error {
telegramReply = text
return nil
},
})
if handled.Outcome != OutcomeHandled {
t.Fatalf("telegram /show outcome=%v, want=%v", handled.Outcome, OutcomeHandled)
}
if telegramReply != "Current Channel: telegram" {
t.Fatalf("telegram /show reply=%q, want=%q", telegramReply, "Current Channel: telegram")
}
var whatsappReply string
handledWhatsApp := ex.Execute(context.Background(), Request{
Channel: "whatsapp",
Text: "/show channel",
Reply: func(text string) error {
whatsappReply = text
return nil
},
})
if handledWhatsApp.Outcome != OutcomeHandled {
t.Fatalf("whatsapp /show outcome=%v, want=%v", handledWhatsApp.Outcome, OutcomeHandled)
}
if handledWhatsApp.Command != "show" {
t.Fatalf("whatsapp /show command=%q, want=%q", handledWhatsApp.Command, "show")
}
if whatsappReply != "Current Channel: whatsapp" {
t.Fatalf("whatsapp /show reply=%q, want=%q", whatsappReply, "Current Channel: whatsapp")
}
passthrough := ex.Execute(context.Background(), Request{
Channel: "whatsapp",
Text: "/foo",
})
if passthrough.Outcome != OutcomePassthrough {
t.Fatalf("whatsapp /foo outcome=%v, want=%v", passthrough.Outcome, OutcomePassthrough)
}
if passthrough.Command != "foo" {
t.Fatalf("whatsapp /foo command=%q, want=%q", passthrough.Command, "foo")
}
}
func TestShowListHandlers_ListHandledOnAllChannels(t *testing.T) {
rt := &Runtime{
GetEnabledChannels: func() []string {
return []string{"telegram"}
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Channel: "whatsapp",
Text: "/list channels",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("whatsapp /list outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if res.Command != "list" {
t.Fatalf("whatsapp /list command=%q, want=%q", res.Command, "list")
}
if !strings.Contains(reply, "telegram") {
t.Fatalf("whatsapp /list reply=%q, expected enabled channels content", reply)
}
}
+54 -15
View File
@@ -167,22 +167,35 @@ type SessionConfig struct {
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
}
// RoutingConfig controls the intelligent model routing feature.
// When enabled, each incoming message is scored against structural features
// (message length, code blocks, tool call history, conversation depth, attachments).
// Messages scoring below Threshold are sent to LightModel; all others use the
// agent's primary model. This reduces cost and latency for simple tasks without
// requiring any keyword matching — all scoring is language-agnostic.
type RoutingConfig struct {
Enabled bool `json:"enabled"`
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
@@ -218,6 +231,7 @@ type ChannelsConfig struct {
WeComApp WeComAppConfig `json:"wecom_app"`
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
Pico PicoConfig `json:"pico"`
IRC IRCConfig `json:"irc"`
}
// GroupTriggerConfig controls when the bot responds in group chats.
@@ -402,6 +416,25 @@ type PicoConfig struct {
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
}
type IRCConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_IRC_ENABLED"`
Server string `json:"server" env:"PICOCLAW_CHANNELS_IRC_SERVER"`
TLS bool `json:"tls" env:"PICOCLAW_CHANNELS_IRC_TLS"`
Nick string `json:"nick" env:"PICOCLAW_CHANNELS_IRC_NICK"`
User string `json:"user,omitempty" env:"PICOCLAW_CHANNELS_IRC_USER"`
RealName string `json:"real_name,omitempty" env:"PICOCLAW_CHANNELS_IRC_REAL_NAME"`
Password string `json:"password" env:"PICOCLAW_CHANNELS_IRC_PASSWORD"`
NickServPassword string `json:"nickserv_password" env:"PICOCLAW_CHANNELS_IRC_NICKSERV_PASSWORD"`
SASLUser string `json:"sasl_user" env:"PICOCLAW_CHANNELS_IRC_SASL_USER"`
SASLPassword string `json:"sasl_password" env:"PICOCLAW_CHANNELS_IRC_SASL_PASSWORD"`
Channels FlexibleStringSlice `json:"channels" env:"PICOCLAW_CHANNELS_IRC_CHANNELS"`
RequestCaps FlexibleStringSlice `json:"request_caps,omitempty" env:"PICOCLAW_CHANNELS_IRC_REQUEST_CAPS"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_IRC_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
Typing TypingConfig `json:"typing,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_IRC_REASONING_CHANNEL_ID"`
}
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
@@ -427,6 +460,7 @@ type ProvidersConfig struct {
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
Cerebras ProviderConfig `json:"cerebras"`
Vivgrid ProviderConfig `json:"vivgrid"`
VolcEngine ProviderConfig `json:"volcengine"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
Antigravity ProviderConfig `json:"antigravity"`
@@ -452,6 +486,7 @@ func (p ProvidersConfig) IsEmpty() bool {
p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" &&
p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" &&
p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" &&
p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" &&
p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" &&
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
@@ -595,6 +630,7 @@ type ExecConfig struct {
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
}
type SkillsToolsConfig struct {
@@ -627,6 +663,7 @@ type ToolsConfig struct {
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
@@ -900,6 +937,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
return t.Subagent.Enabled
case "web_fetch":
return t.WebFetch.Enabled
case "send_file":
return t.SendFile.Enabled
case "write_file":
return t.WriteFile.Enabled
case "mcp":
+12
View File
@@ -261,6 +261,14 @@ func DefaultConfig() *Config {
APIKey: "",
},
// Vivgrid - https://vivgrid.com
{
ModelName: "vivgrid-auto",
Model: "vivgrid/auto",
APIBase: "https://api.vivgrid.com/v1",
APIKey: "",
},
// Volcengine (火山引擎) - https://console.volcengine.com/ark
{
ModelName: "doubao-pro",
@@ -386,6 +394,7 @@ func DefaultConfig() *Config {
Enabled: true,
},
EnableDenyPatterns: true,
TimeoutSeconds: 60,
},
Skills: SkillsToolsConfig{
ToolConfig: ToolConfig{
@@ -403,6 +412,9 @@ func DefaultConfig() *Config {
TTLSeconds: 300,
},
},
SendFile: ToolConfig{
Enabled: true,
},
MCP: MCPConfig{
ToolConfig: ToolConfig{
Enabled: false,
+17
View File
@@ -292,6 +292,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
{
providerNames: []string{"vivgrid"},
protocol: "vivgrid",
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
if p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" {
return ModelConfig{}, false
}
return ModelConfig{
ModelName: "vivgrid",
Model: "vivgrid/auto",
APIKey: p.Vivgrid.APIKey,
APIBase: p.Vivgrid.APIBase,
Proxy: p.Vivgrid.Proxy,
RequestTimeout: p.Vivgrid.RequestTimeout,
}, true
},
},
{
providerNames: []string{"volcengine", "doubao"},
protocol: "volcengine",
+5 -4
View File
@@ -155,7 +155,8 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
ShengSuanYun: ProviderConfig{APIKey: "key11"},
DeepSeek: ProviderConfig{APIKey: "key12"},
Cerebras: ProviderConfig{APIKey: "key13"},
VolcEngine: ProviderConfig{APIKey: "key14"},
Vivgrid: ProviderConfig{APIKey: "key14"},
VolcEngine: ProviderConfig{APIKey: "key15"},
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key17"},
@@ -166,9 +167,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
result := ConvertProvidersToModelList(cfg)
// All 20 providers should be converted
if len(result) != 20 {
t.Errorf("len(result) = %d, want 20", len(result))
// All 21 providers should be converted
if len(result) != 21 {
t.Errorf("len(result) = %d, want 21", len(result))
}
}
+20
View File
@@ -190,14 +190,21 @@ func (cs *CronService) executeJobByID(jobID string) {
cs.mu.RUnlock()
if callbackJob == nil {
log.Printf("[cron] job %s not found, skipping", jobID)
return
}
// Log job execution start
log.Printf("[cron] ▶ executing job '%s' (id: %s, schedule: %s, channel: %s)",
callbackJob.Name, jobID, callbackJob.Schedule.Kind, callbackJob.Payload.Channel)
var err error
if cs.onJob != nil {
_, err = cs.onJob(callbackJob)
}
execDuration := time.Now().UnixMilli() - startTime
// Now acquire lock to update state
cs.mu.Lock()
defer cs.mu.Unlock()
@@ -220,22 +227,35 @@ func (cs *CronService) executeJobByID(jobID string) {
if err != nil {
job.State.LastStatus = "error"
job.State.LastError = err.Error()
log.Printf("[cron] ✗ job '%s' failed after %dms: %v", job.Name, execDuration, err)
} else {
job.State.LastStatus = "ok"
job.State.LastError = ""
}
// Compute next run time
var nextRunStr string
if job.Schedule.Kind == "at" {
if job.DeleteAfterRun {
cs.removeJobUnsafe(job.ID)
nextRunStr = "(deleted)"
} else {
job.Enabled = false
job.State.NextRunAtMS = nil
nextRunStr = "(disabled)"
}
} else {
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
job.State.NextRunAtMS = nextRun
if nextRun != nil {
nextRunStr = time.UnixMilli(*nextRun).Format("2006-01-02 15:04:05")
} else {
nextRunStr = "(none)"
}
}
if err == nil {
log.Printf("[cron] ✓ job '%s' completed in %dms, next run: %s", job.Name, execDuration, nextRunStr)
}
if err := cs.saveStoreUnsafe(); err != nil {
+50 -4
View File
@@ -23,7 +23,10 @@ type (
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
)
const defaultBaseURL = "https://api.anthropic.com"
const (
defaultBaseURL = "https://api.anthropic.com"
anthropicBetaHeader = "oauth-2025-04-20"
)
type Provider struct {
client *anthropic.Client
@@ -80,7 +83,10 @@ func (p *Provider) Chat(
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
opts = append(opts,
option.WithAuthToken(tok),
option.WithHeader("anthropic-beta", anthropicBetaHeader),
)
}
params, err := buildParams(messages, tools, model, options)
@@ -88,6 +94,11 @@ func (p *Provider) Chat(
return nil, err
}
// OAuth/setup-tokens require streaming; API keys use non-streaming.
if p.tokenSource != nil {
return p.chatStreaming(ctx, params, opts)
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
@@ -96,6 +107,28 @@ func (p *Provider) Chat(
return parseResponse(resp), nil
}
func (p *Provider) chatStreaming(
ctx context.Context,
params anthropic.MessageNewParams,
opts []option.RequestOption,
) (*LLMResponse, error) {
stream := p.client.Messages.NewStreaming(ctx, params, opts...)
defer stream.Close()
var msg anthropic.Message
for stream.Next() {
event := stream.Current()
if err := msg.Accumulate(event); err != nil {
return nil, fmt.Errorf("claude streaming accumulate: %w", err)
}
}
if err := stream.Err(); err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseResponse(&msg), nil
}
func (p *Provider) GetDefaultModel() string {
return "claude-sonnet-4.6"
}
@@ -147,7 +180,16 @@ func buildParams(
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
args := tc.Arguments
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
args = map[string]any{}
}
}
if args == nil {
args = map[string]any{}
}
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, args, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
@@ -167,8 +209,12 @@ func buildParams(
maxTokens = int64(mt)
}
// Normalize model ID: Anthropic API uses hyphens (claude-sonnet-4-6),
// but config may use dots (claude-sonnet-4.6).
apiModel := strings.ReplaceAll(model, ".", "-")
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Model: anthropic.Model(apiModel),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
+61 -2
View File
@@ -21,8 +21,8 @@ func TestBuildParams_BasicMessage(t *testing.T) {
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4.6" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4.6")
if string(params.Model) != "claude-sonnet-4-6" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-6")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
@@ -262,6 +262,65 @@ func TestProvider_ChatUsesTokenSource(t *testing.T) {
}
}
func TestProvider_ChatStreamingRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
t.Errorf("Authorization = %q, want %q", got, "Bearer refreshed-token")
}
if got := r.Header.Get("Anthropic-Beta"); got != anthropicBetaHeader {
t.Errorf("Anthropic-Beta = %q, want %q", got, anthropicBetaHeader)
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
events := []string{
"event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":12,\"output_tokens\":0}}}\n\n",
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\n",
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
}
for _, e := range events {
w.Write([]byte(e))
if flusher != nil {
flusher.Flush()
}
}
}))
defer server.Close()
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
return "refreshed-token", nil
}, server.URL)
resp, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "Hello"}},
nil,
"claude-sonnet-4.6",
map[string]any{},
)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello world" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello world")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.CompletionTokens != 5 {
t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens)
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
+16
View File
@@ -153,6 +153,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
}
case "vivgrid":
if cfg.Providers.Vivgrid.APIKey != "" {
sel.apiKey = cfg.Providers.Vivgrid.APIKey
sel.apiBase = cfg.Providers.Vivgrid.APIBase
sel.proxy = cfg.Providers.Vivgrid.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.vivgrid.com/v1"
}
}
case "claude-cli", "claude-code", "claudecode":
workspace := cfg.WorkspacePath()
if workspace == "" {
@@ -295,6 +304,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "":
sel.apiKey = cfg.Providers.Vivgrid.APIKey
sel.apiBase = cfg.Providers.Vivgrid.APIBase
sel.proxy = cfg.Providers.Vivgrid.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.vivgrid.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
sel.apiKey = cfg.Providers.Ollama.APIKey
sel.apiBase = cfg.Providers.Ollama.APIBase
+3 -1
View File
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"volcengine", "vllm", "qwen", "mistral", "avian":
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -200,6 +200,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://api.deepseek.com/v1"
case "cerebras":
return "https://api.cerebras.ai/v1"
case "vivgrid":
return "https://api.vivgrid.com/v1"
case "volcengine":
return "https://ark.cn-beijing.volces.com/api/v3"
case "qwen":
+1
View File
@@ -108,6 +108,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
{"groq", "groq"},
{"openrouter", "openrouter"},
{"cerebras", "cerebras"},
{"vivgrid", "vivgrid"},
{"qwen", "qwen"},
{"vllm", "vllm"},
{"deepseek", "deepseek"},
+11
View File
@@ -88,6 +88,17 @@ func TestResolveProviderSelection(t *testing.T) {
wantAPIBase: "https://integrate.api.nvidia.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit vivgrid provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "vivgrid"
cfg.Providers.Vivgrid.APIKey = "vivgrid-key"
cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.vivgrid.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "openrouter model uses openrouter defaults",
setup: func(cfg *config.Config) {
+87 -10
View File
@@ -1,6 +1,7 @@
package openai_compat
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -183,19 +184,94 @@ func (p *Provider) Chat(
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
contentType := resp.Header.Get("Content-Type")
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
if readErr != nil {
return nil, fmt.Errorf("failed to read response: %w", readErr)
}
if looksLikeHTML(body, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
}
return nil, fmt.Errorf(
"API request failed:\n Status: %d\n Body: %s",
resp.StatusCode,
responsePreview(body, 128),
)
}
return parseResponse(body)
// Peek without consuming so the full stream reaches the JSON decoder.
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if looksLikeHTML(prefix, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
}
out, err := parseResponse(reader)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return out, nil
}
func parseResponse(body []byte) (*LLMResponse, error) {
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
respPreview := responsePreview(body, 128)
return fmt.Errorf(
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
apiBase,
contentType,
statusCode,
respPreview,
)
}
func looksLikeHTML(body []byte, contentType string) bool {
contentType = strings.ToLower(strings.TrimSpace(contentType))
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
return true
}
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
bytes.HasPrefix(prefix, []byte("<html")) ||
bytes.HasPrefix(prefix, []byte("<head")) ||
bytes.HasPrefix(prefix, []byte("<body"))
}
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
i := 0
for i < len(body) {
switch body[i] {
case ' ', '\t', '\n', '\r', '\f', '\v':
i++
default:
end := i + maxLen
if end > len(body) {
end = len(body)
}
return body[i:end]
}
}
return nil
}
func responsePreview(body []byte, maxLen int) string {
trimmed := bytes.TrimSpace(body)
if len(trimmed) == 0 {
return "<empty>"
}
if len(trimmed) <= maxLen {
return string(trimmed)
}
return string(trimmed[:maxLen]) + "..."
}
func parseResponse(body io.Reader) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
@@ -222,8 +298,8 @@ func parseResponse(body []byte) (*LLMResponse, error) {
Usage *UsageInfo `json:"usage"`
}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
if len(apiResponse.Choices) == 0 {
@@ -363,7 +439,8 @@ func normalizeModel(model, apiBase string) string {
prefix := strings.ToLower(before)
switch prefix {
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google",
"openrouter", "zhipu", "mistral", "vivgrid":
return after
default:
return model
+175 -1
View File
@@ -1,7 +1,10 @@
package openai_compat
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
@@ -212,6 +215,132 @@ func TestProviderChat_HTTPError(t *testing.T) {
}
}
func TestProviderChat_JSONHTTPErrorDoesNotReportHTML(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"bad request"}`))
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "Status: 400") {
t.Fatalf("expected status code in error, got %v", err)
}
if strings.Contains(err.Error(), "returned HTML instead of JSON") {
t.Fatalf("expected non-HTML http error, got %v", err)
}
}
func TestProviderChat_HTMLResponsesReturnHelpfulError(t *testing.T) {
tests := []struct {
name string
contentType string
statusCode int
body string
}{
{
name: "html success response",
contentType: "text/html; charset=utf-8",
statusCode: http.StatusOK,
body: "<!DOCTYPE html><html><body>gateway login</body></html>",
},
{
name: "html error response",
contentType: "text/html; charset=utf-8",
statusCode: http.StatusBadGateway,
body: "<!DOCTYPE html><html><body>bad gateway</body></html>",
},
{
name: "mislabeled html success response",
contentType: "application/json",
statusCode: http.StatusOK,
body: " \r\n\t<!DOCTYPE html><html><body>gateway login</body></html>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", tt.contentType)
w.WriteHeader(tt.statusCode)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), fmt.Sprintf("Status: %d", tt.statusCode)) {
t.Fatalf("expected status code in error, got %v", err)
}
if !strings.Contains(err.Error(), "returned HTML instead of JSON") {
t.Fatalf("expected helpful HTML error, got %v", err)
}
if !strings.Contains(err.Error(), "check api_base or proxy configuration") {
t.Fatalf("expected configuration hint, got %v", err)
}
})
}
}
func TestProviderChat_SuccessResponseUsesStreamingDecoder(t *testing.T) {
content := strings.Repeat("a", 1024)
body := `{"choices":[{"message":{"content":"` + content + `"},"finish_reason":"stop"}]}`
p := NewProvider("key", "https://example.com/v1", "")
p.httpClient = &http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: &errAfterDataReadCloser{
data: []byte(body),
chunkSize: 64,
},
}, nil
}),
}
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if out.Content != content {
t.Fatalf("Content = %q, want %q", out.Content, content)
}
}
func TestProviderChat_LargeHTMLResponsePreviewIsTruncated(t *testing.T) {
body := append([]byte("<!DOCTYPE html><html><body>"), bytes.Repeat([]byte("A"), 2048)...)
body = append(body, []byte("</body></html>")...)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write(body)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "Body: <!DOCTYPE html><html><body>") {
t.Fatalf("expected html preview in error, got %v", err)
}
if !strings.Contains(err.Error(), "...") {
t.Fatalf("expected truncated preview, got %v", err)
}
}
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
var requestBody map[string]any
@@ -253,7 +382,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
}
}
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
tests := []struct {
name string
input string
@@ -279,6 +408,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
input: "deepseek/deepseek-chat",
wantModel: "deepseek-chat",
},
{
name: "strips vivgrid prefix",
input: "vivgrid/auto",
wantModel: "auto",
},
}
for _, tt := range tests {
@@ -383,6 +517,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
if got := normalizeModel("vivgrid/managed", "https://api.vivgrid.com/v1"); got != "managed" {
t.Fatalf("normalizeModel(vivgrid) = %q, want %q", got, "managed")
}
if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" {
t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto")
}
}
func TestProvider_RequestTimeoutDefault(t *testing.T) {
@@ -399,6 +539,40 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
type errAfterDataReadCloser struct {
data []byte
chunkSize int
offset int
}
func (r *errAfterDataReadCloser) Read(p []byte) (int, error) {
if r.offset >= len(r.data) {
return 0, io.ErrUnexpectedEOF
}
n := r.chunkSize
if n <= 0 || n > len(p) {
n = len(p)
}
remaining := len(r.data) - r.offset
if n > remaining {
n = remaining
}
copy(p, r.data[r.offset:r.offset+n])
r.offset += n
return n, nil
}
func (r *errAfterDataReadCloser) Close() error {
return nil
}
func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
if p.maxTokensField != "max_completion_tokens" {
+80
View File
@@ -0,0 +1,80 @@
package routing
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
// A higher score indicates a more complex task that benefits from a heavy model.
// The score is compared against the configured threshold: score >= threshold selects
// the primary (heavy) model; score < threshold selects the light model.
//
// Classifier is an interface so that future implementations (ML-based, embedding-based,
// or any other approach) can be swapped in without changing routing infrastructure.
type Classifier interface {
Score(f Features) float64
}
// RuleClassifier is the v1 implementation.
// It uses a weighted sum of structural signals with no external dependencies,
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
// that the returned score always falls within the [0, 1] contract.
//
// Individual weights (multiple signals can fire simultaneously):
//
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
// token 50-200: 0.15 — medium length; may or may not be complex
// code block present: 0.40 — coding tasks need the heavy model
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
// tool calls 1-3 (recent): 0.10 — some tool activity
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
//
// Default threshold is 0.35, so:
// - Pure greetings / trivial Q&A: 0.00 → light ✓
// - Medium prose message (50200 tokens): 0.15 → light ✓
// - Message with code block: 0.40 → heavy ✓
// - Long message (>200 tokens): 0.35 → heavy ✓
// - Active tool session + medium message: 0.25 → light (acceptable)
// - Any message with an image/audio attachment: 1.00 → heavy ✓
type RuleClassifier struct{}
// Score computes the complexity score for the given feature set.
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
func (c *RuleClassifier) Score(f Features) float64 {
// Hard gate: multi-modal inputs always require the heavy model.
if f.HasAttachments {
return 1.0
}
var score float64
// Token estimate — primary verbosity signal
switch {
case f.TokenEstimate > 200:
score += 0.35
case f.TokenEstimate > 50:
score += 0.15
}
// Fenced code blocks — strongest indicator of a coding/technical task
if f.CodeBlockCount > 0 {
score += 0.40
}
// Recent tool call density — indicates an ongoing agentic workflow
switch {
case f.RecentToolCalls > 3:
score += 0.25
case f.RecentToolCalls > 0:
score += 0.10
}
// Conversation depth — accumulated context implies compound task
if f.ConversationDepth > 10 {
score += 0.10
}
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
if score > 1.0 {
score = 1.0
}
return score
}
+127
View File
@@ -0,0 +1,127 @@
package routing
import (
"strings"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// lookbackWindow is the number of recent history entries scanned for tool calls.
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
const lookbackWindow = 6
// Features holds the structural signals extracted from a message and its session context.
// Every dimension is language-agnostic by construction — no keyword or pattern matching
// against natural-language content. This ensures consistent routing for all locales.
type Features struct {
// TokenEstimate is a proxy for token count.
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
// This avoids API calls while giving accurate estimates for all scripts.
TokenEstimate int
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
// Coding tasks almost always require the heavy model.
CodeBlockCount int
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
// history entries. A high density indicates an active agentic workflow.
RecentToolCalls int
// ConversationDepth is the total number of messages in the session history.
// Deep sessions tend to carry implicit complexity built up over many turns.
ConversationDepth int
// HasAttachments is true when the message appears to contain media (images,
// audio, video). Multi-modal inputs require vision-capable heavy models.
HasAttachments bool
}
// ExtractFeatures computes the structural feature vector for a message.
// It is a pure function with no side effects and zero allocations beyond
// the returned struct.
func ExtractFeatures(msg string, history []providers.Message) Features {
return Features{
TokenEstimate: estimateTokens(msg),
CodeBlockCount: countCodeBlocks(msg),
RecentToolCalls: countRecentToolCalls(history),
ConversationDepth: len(history),
HasAttachments: hasAttachments(msg),
}
}
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
// CJK runes (U+2E80U+9FFF, U+F900U+FAFF, U+AC00U+D7AF) map to roughly one
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
// for English). Splitting the count this way avoids the 3x underestimation that a
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
func estimateTokens(msg string) int {
total := utf8.RuneCountInString(msg)
if total == 0 {
return 0
}
cjk := 0
for _, r := range msg {
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
cjk++
}
}
return cjk + (total-cjk)/4
}
// countCodeBlocks counts the number of complete fenced code blocks.
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
// An unclosed opening fence (odd count) is treated as zero complete blocks
// since it may just be an inline code span or a typo.
func countCodeBlocks(msg string) int {
n := strings.Count(msg, "```")
return n / 2
}
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
// entries of history. It examines the ToolCalls field rather than parsing
// the content string, so it is robust to any message format.
func countRecentToolCalls(history []providers.Message) int {
start := len(history) - lookbackWindow
if start < 0 {
start = 0
}
count := 0
for _, msg := range history[start:] {
if len(msg.ToolCalls) > 0 {
count += len(msg.ToolCalls)
}
}
return count
}
// hasAttachments returns true when the message content contains embedded media.
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
// common image/audio URL extensions. This is intentionally conservative —
// false negatives (missing an attachment) just mean the routing falls back to
// the primary model anyway.
func hasAttachments(msg string) bool {
lower := strings.ToLower(msg)
// Base64 data URIs embedded directly in the message
if strings.Contains(lower, "data:image/") ||
strings.Contains(lower, "data:audio/") ||
strings.Contains(lower, "data:video/") {
return true
}
// Common image/audio extensions in URLs or file references
mediaExts := []string{
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
".mp3", ".wav", ".ogg", ".m4a", ".flac",
".mp4", ".avi", ".mov", ".webm",
}
for _, ext := range mediaExts {
if strings.Contains(lower, ext) {
return true
}
}
return false
}
+82
View File
@@ -0,0 +1,82 @@
package routing
import (
"github.com/sipeed/picoclaw/pkg/providers"
)
// defaultThreshold is used when the config threshold is zero or negative.
// At 0.35 a message needs at least one strong signal (code block, long text,
// or an attachment) before the heavy model is chosen.
const defaultThreshold = 0.35
// RouterConfig holds the validated model routing settings.
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
type RouterConfig struct {
// LightModel is the model_name (from model_list) used for simple tasks.
LightModel string
// Threshold is the complexity score cutoff in [0, 1].
// score >= Threshold → primary (heavy) model.
// score < Threshold → light model.
Threshold float64
}
// Router selects the appropriate model tier for each incoming message.
// It is safe for concurrent use from multiple goroutines.
type Router struct {
cfg RouterConfig
classifier Classifier
}
// New creates a Router with the given config and the default RuleClassifier.
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
func New(cfg RouterConfig) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{
cfg: cfg,
classifier: &RuleClassifier{},
}
}
// newWithClassifier creates a Router with a custom Classifier.
// Intended for unit tests that need to inject a deterministic scorer.
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{cfg: cfg, classifier: c}
}
// SelectModel returns the model to use for this conversation turn along with
// the computed complexity score (for logging and debugging).
//
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
// - Otherwise: returns (primaryModel, false, score)
//
// The caller is responsible for resolving the returned model name into
// provider candidates (see AgentInstance.LightCandidates).
func (r *Router) SelectModel(
msg string,
history []providers.Message,
primaryModel string,
) (model string, usedLight bool, score float64) {
features := ExtractFeatures(msg, history)
score = r.classifier.Score(features)
if score < r.cfg.Threshold {
return r.cfg.LightModel, true, score
}
return primaryModel, false, score
}
// LightModel returns the configured light model name.
func (r *Router) LightModel() string {
return r.cfg.LightModel
}
// Threshold returns the complexity threshold in use.
func (r *Router) Threshold() float64 {
return r.cfg.Threshold
}
+414
View File
@@ -0,0 +1,414 @@
package routing
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ── ExtractFeatures ──────────────────────────────────────────────────────────
func TestExtractFeatures_EmptyMessage(t *testing.T) {
f := ExtractFeatures("", nil)
if f.TokenEstimate != 0 {
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
}
if f.CodeBlockCount != 0 {
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
}
if f.RecentToolCalls != 0 {
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
}
if f.ConversationDepth != 0 {
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
}
if f.HasAttachments {
t.Error("HasAttachments: got true, want false")
}
}
func TestExtractFeatures_TokenEstimate(t *testing.T) {
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
msg := strings.Repeat("a", 30)
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 7 {
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
// Using a rune slice literal avoids CJK string literals in source.
msg := string([]rune{
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60,
})
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 9 {
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 6 {
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
}
}
func TestExtractFeatures_CodeBlocks(t *testing.T) {
cases := []struct {
msg string
want int
}{
{"no code here", 0},
{"```go\nfmt.Println()\n```", 1},
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.CodeBlockCount != tc.want {
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
}
}
}
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
// History longer than lookbackWindow — only last lookbackWindow entries count.
history := make([]providers.Message, 10)
// Put 2 tool calls at positions 8 and 9 (within the last 6)
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
history[9] = providers.Message{
Role: "assistant",
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
}
// Position 3 is outside the lookback window and must NOT be counted
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
f := ExtractFeatures("test", history)
// 1 (position 8) + 2 (position 9) = 3
if f.RecentToolCalls != 3 {
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
}
}
func TestExtractFeatures_ConversationDepth(t *testing.T) {
history := make([]providers.Message, 7)
f := ExtractFeatures("msg", history)
if f.ConversationDepth != 7 {
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
}
}
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"plain text", false},
{"here is an image: data:image/png;base64,abc123", true},
{"audio: data:audio/mp3;base64,xyz", true},
{"video: data:video/mp4;base64,xyz", true},
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"check out photo.jpg", true},
{"see screenshot.png", true},
{"listen to audio.mp3", true},
{"watch clip.mp4", true},
{"just a .go file", false},
{"document.pdf", false}, // pdf is not in the media list
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
// ── RuleClassifier ───────────────────────────────────────────────────────────
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{})
if score != 0.0 {
t.Errorf("zero features: got %f, want 0.0", score)
}
}
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{HasAttachments: true})
if score != 1.0 {
t.Errorf("attachments: got %f, want 1.0", score)
}
}
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
c := &RuleClassifier{}
// Code block alone = 0.40, above default threshold 0.35
score := c.Score(Features{CodeBlockCount: 1})
if score < 0.35 {
t.Errorf("code block: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_LongMessage(t *testing.T) {
c := &RuleClassifier{}
// >200 tokens = 0.35, exactly at default threshold → heavy
score := c.Score(Features{TokenEstimate: 250})
if score < 0.35 {
t.Errorf("long message: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_MediumMessage(t *testing.T) {
c := &RuleClassifier{}
// 50-200 tokens = 0.15, below threshold → light
score := c.Score(Features{TokenEstimate: 100})
if score >= 0.35 {
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
}
}
func TestRuleClassifier_ShortMessage(t *testing.T) {
c := &RuleClassifier{}
// <50 tokens, no other signals = 0.0 → light
score := c.Score(Features{TokenEstimate: 10})
if score != 0.0 {
t.Errorf("short message: got %f, want 0.0", score)
}
}
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
c := &RuleClassifier{}
scoreNone := c.Score(Features{RecentToolCalls: 0})
scoreLow := c.Score(Features{RecentToolCalls: 2})
scoreHigh := c.Score(Features{RecentToolCalls: 5})
if scoreNone != 0.0 {
t.Errorf("no tools: got %f, want 0.0", scoreNone)
}
if scoreLow <= scoreNone {
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
}
if scoreHigh <= scoreLow {
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
}
}
func TestRuleClassifier_DeepConversation(t *testing.T) {
c := &RuleClassifier{}
shallow := c.Score(Features{ConversationDepth: 5})
deep := c.Score(Features{ConversationDepth: 15})
if deep <= shallow {
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
}
}
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
c := &RuleClassifier{}
// Max all signals simultaneously
f := Features{
TokenEstimate: 500,
CodeBlockCount: 3,
RecentToolCalls: 10,
ConversationDepth: 20,
}
score := c.Score(f)
if score > 1.0 {
t.Errorf("score %f exceeds 1.0", score)
}
}
// ── Router ───────────────────────────────────────────────────────────────────
func TestRouter_DefaultThreshold(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash"})
if r.Threshold() != defaultThreshold {
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
if r.Threshold() != defaultThreshold {
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "hi"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("simple message: expected light model to be selected")
}
if model != "gemini-flash" {
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
}
}
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "```go\nfmt.Println(\"hello\")\n```"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("code block: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "can you analyze this? data:image/png;base64,abc123"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("attachment: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
// >200 token estimate: 210 * 3 = 630 chars
msg := strings.Repeat("word ", 210)
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("long message: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
// Routing is conservative: only promote to heavy when the signal is unambiguous.
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
}
msg := "ok"
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if !usedLight {
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
}
}
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
}},
}
// ~55 tokens * 3 = 165 chars
msg := strings.Repeat("word ", 55)
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if usedLight {
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
}
}
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
// Very low threshold: even a short message triggers heavy model
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("low threshold: medium message should use primary model")
}
}
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
// Very high threshold: even code blocks route to light
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
msg := "```go\nfmt.Println()\n```"
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("very high threshold: code block (0.40) should route to light model")
}
}
func TestRouter_LightModel(t *testing.T) {
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
if r.LightModel() != "my-fast-model" {
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
}
}
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
type fixedScoreClassifier struct{ score float64 }
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.2},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if !usedLight {
t.Error("low score with custom classifier: expected light model")
}
}
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.8},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("high score with custom classifier: expected primary model")
}
}
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
// score == threshold → primary (uses >= comparison)
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.5},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("score == threshold: expected primary model (>= threshold → primary)")
}
}
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.42},
)
_, _, score := r.SelectModel("anything", nil, "heavy")
if score != 0.42 {
t.Errorf("score: got %f, want 0.42", score)
}
}
+6
View File
@@ -141,6 +141,12 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
everySeconds, hasEvery := args["every_seconds"].(float64)
cronExpr, hasCron := args["cron_expr"].(string)
// Fix: type assertions return true for zero values, need additional validity checks
// This prevents LLMs that fill unused optional parameters with defaults (0) from triggering wrong type
hasAt = hasAt && atSeconds > 0
hasEvery = hasEvery && everySeconds > 0
hasCron = hasCron && cronExpr != ""
// Priority: at_seconds > every_seconds > cron_expr
if hasAt {
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
+150
View File
@@ -0,0 +1,150 @@
package tools
import (
"context"
"fmt"
"mime"
"os"
"path/filepath"
"strings"
"github.com/h2non/filetype"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
// SendFileTool allows the LLM to send a local file (image, document, etc.)
// to the user on the current chat channel via the MediaStore pipeline.
type SendFileTool struct {
workspace string
restrict bool
maxFileSize int
mediaStore media.MediaStore
defaultChannel string
defaultChatID string
}
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
if maxFileSize <= 0 {
maxFileSize = config.DefaultMaxMediaSize
}
return &SendFileTool{
workspace: workspace,
restrict: restrict,
maxFileSize: maxFileSize,
mediaStore: store,
}
}
func (t *SendFileTool) Name() string { return "send_file" }
func (t *SendFileTool) Description() string {
return "Send a local file (image, document, etc.) to the user on the current chat channel."
}
func (t *SendFileTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the local file. Relative paths are resolved from workspace.",
},
"filename": map[string]any{
"type": "string",
"description": "Optional display filename. Defaults to the basename of path.",
},
},
"required": []string{"path"},
}
}
func (t *SendFileTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
}
func (t *SendFileTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, _ := args["path"].(string)
if strings.TrimSpace(path) == "" {
return ErrorResult("path is required")
}
// Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values.
channel := ToolChannel(ctx)
if channel == "" {
channel = t.defaultChannel
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = t.defaultChatID
}
if channel == "" || chatID == "" {
return ErrorResult("no target channel/chat available")
}
if t.mediaStore == nil {
return ErrorResult("media store not configured")
}
resolved, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
}
info, err := os.Stat(resolved)
if err != nil {
return ErrorResult(fmt.Sprintf("file not found: %v", err))
}
if info.IsDir() {
return ErrorResult("path is a directory, expected a file")
}
if info.Size() > int64(t.maxFileSize) {
return ErrorResult(fmt.Sprintf(
"file too large: %d bytes (max %d bytes)",
info.Size(), t.maxFileSize,
))
}
filename, _ := args["filename"].(string)
if filename == "" {
filename = filepath.Base(resolved)
}
mediaType := detectMediaType(resolved)
scope := fmt.Sprintf("tool:send_file:%s:%s", channel, chatID)
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
Filename: filename,
ContentType: mediaType,
Source: "tool:send_file",
}, scope)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
}
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref})
}
// detectMediaType determines the MIME type of a file.
// Uses magic-bytes detection (h2non/filetype) first, then falls back to
// extension-based lookup via mime.TypeByExtension.
func detectMediaType(path string) string {
kind, err := filetype.MatchFile(path)
if err == nil && kind != filetype.Unknown {
return kind.MIME.Value
}
if ext := filepath.Ext(path); ext != "" {
if t := mime.TypeByExtension(ext); t != "" {
return t
}
}
return "application/octet-stream"
}
+176
View File
@@ -0,0 +1,176 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestSendFileTool_MissingPath(t *testing.T) {
store := media.NewFileMediaStore()
tool := NewSendFileTool("/tmp", false, 0, store)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{})
if !result.IsError {
t.Fatal("expected error for missing path")
}
}
func TestSendFileTool_NoContext(t *testing.T) {
store := media.NewFileMediaStore()
tool := NewSendFileTool("/tmp", false, 0, store)
// no SetContext call
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"})
if !result.IsError {
t.Fatal("expected error when no channel context")
}
}
func TestSendFileTool_NoMediaStore(t *testing.T) {
tool := NewSendFileTool("/tmp", false, 0, nil)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"})
if !result.IsError {
t.Fatal("expected error when no media store")
}
}
func TestSendFileTool_Directory(t *testing.T) {
store := media.NewFileMediaStore()
tool := NewSendFileTool("/tmp", false, 0, store)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp"})
if !result.IsError {
t.Fatal("expected error for directory path")
}
}
func TestSendFileTool_FileTooLarge(t *testing.T) {
dir := t.TempDir()
testFile := filepath.Join(dir, "big.bin")
// Create a file larger than the limit
if err := os.WriteFile(testFile, make([]byte, 1024), 0o644); err != nil {
t.Fatal(err)
}
store := media.NewFileMediaStore()
tool := NewSendFileTool(dir, false, 512, store) // 512 byte limit
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": testFile})
if !result.IsError {
t.Fatal("expected error for oversized file")
}
if !strings.Contains(result.ForLLM, "too large") {
t.Errorf("expected 'too large' in error, got %q", result.ForLLM)
}
}
func TestSendFileTool_DefaultMaxSize(t *testing.T) {
tool := NewSendFileTool("/tmp", false, 0, nil)
if tool.maxFileSize != config.DefaultMaxMediaSize {
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
}
}
func TestSendFileTool_Success(t *testing.T) {
dir := t.TempDir()
testFile := filepath.Join(dir, "photo.png")
if err := os.WriteFile(testFile, []byte("fake png"), 0o644); err != nil {
t.Fatal(err)
}
store := media.NewFileMediaStore()
tool := NewSendFileTool(dir, false, 0, store)
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": testFile})
if result.IsError {
t.Fatalf("unexpected error: %s", result.ForLLM)
}
if len(result.Media) != 1 {
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
}
if result.Media[0][:8] != "media://" {
t.Errorf("expected media:// ref, got %q", result.Media[0])
}
}
func TestSendFileTool_CustomFilename(t *testing.T) {
dir := t.TempDir()
testFile := filepath.Join(dir, "img.jpg")
if err := os.WriteFile(testFile, []byte("fake jpg"), 0o644); err != nil {
t.Fatal(err)
}
store := media.NewFileMediaStore()
tool := NewSendFileTool(dir, false, 0, store)
tool.SetContext("telegram", "chat456")
result := tool.Execute(context.Background(), map[string]any{
"path": testFile,
"filename": "my-photo.jpg",
})
if result.IsError {
t.Fatalf("unexpected error: %s", result.ForLLM)
}
if len(result.Media) != 1 {
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
}
}
func TestDetectMediaType_MagicBytes(t *testing.T) {
dir := t.TempDir()
// Minimal valid PNG header
pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
pngFile := filepath.Join(dir, "image.dat") // wrong extension, but valid PNG bytes
if err := os.WriteFile(pngFile, pngHeader, 0o644); err != nil {
t.Fatal(err)
}
got := detectMediaType(pngFile)
if got != "image/png" {
t.Errorf("expected image/png from magic bytes, got %q", got)
}
}
func TestDetectMediaType_FallbackToExtension(t *testing.T) {
dir := t.TempDir()
// File with unrecognizable content but known extension
txtFile := filepath.Join(dir, "readme.txt")
if err := os.WriteFile(txtFile, []byte("hello world"), 0o644); err != nil {
t.Fatal(err)
}
got := detectMediaType(txtFile)
// text/plain or similar — just verify it's not application/octet-stream
if got == "application/octet-stream" {
t.Errorf("expected extension-based MIME for .txt, got %q", got)
}
}
func TestDetectMediaType_UnknownFallsToOctetStream(t *testing.T) {
dir := t.TempDir()
// File with no extension and random bytes
unknownFile := filepath.Join(dir, "mystery")
if err := os.WriteFile(unknownFile, []byte{0x00, 0x01, 0x02}, 0o644); err != nil {
t.Fatal(err)
}
got := detectMediaType(unknownFile)
if got != "application/octet-stream" {
t.Errorf("expected application/octet-stream, got %q", got)
}
}
+7 -2
View File
@@ -59,7 +59,7 @@ var (
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bkill\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
@@ -131,9 +131,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
timeout := 60 * time.Second
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
}
return &ExecTool{
workingDir: workingDir,
timeout: 60 * time.Second,
timeout: timeout,
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
+20
View File
@@ -151,6 +151,26 @@ func TestShellTool_DangerousCommand(t *testing.T) {
}
}
func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) {
tool, err := NewExecTool("", false)
if err != nil {
t.Errorf("unable to configure exec tool: %s", err)
}
ctx := context.Background()
args := map[string]any{
"command": "kill 12345",
}
result := tool.Execute(ctx, args)
if !result.IsError {
t.Errorf("Expected kill command to be blocked")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected blocked message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool, err := NewExecTool("", false)
+2 -2
View File
@@ -8,7 +8,7 @@ import (
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnTool(manager)
ctx := context.Background()
@@ -42,7 +42,7 @@ func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
func TestSpawnTool_Execute_ValidTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnTool(manager)
ctx := context.Background()
-18
View File
@@ -6,7 +6,6 @@ import (
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -27,7 +26,6 @@ type SubagentManager struct {
mu sync.RWMutex
provider providers.LLMProvider
defaultModel string
bus *bus.MessageBus
workspace string
tools *ToolRegistry
maxIterations int
@@ -41,13 +39,11 @@ type SubagentManager struct {
func NewSubagentManager(
provider providers.LLMProvider,
defaultModel, workspace string,
bus *bus.MessageBus,
) *SubagentManager {
return &SubagentManager{
tasks: make(map[string]*SubagentTask),
provider: provider,
defaultModel: defaultModel,
bus: bus,
workspace: workspace,
tools: NewToolRegistry(),
maxIterations: 10,
@@ -214,20 +210,6 @@ After completing the task, provide a clear summary of what was done.`
Async: false,
}
}
// Send announce message back to main agent
if sm.bus != nil {
announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result)
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
sm.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("subagent:%s", task.ID),
// Format: "original_channel:original_chat_id" for routing back
ChatID: fmt.Sprintf("%s:%s", task.OriginChannel, task.OriginChatID),
Content: announceContent,
})
}
}
func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
+9 -14
View File
@@ -5,7 +5,6 @@ import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -47,7 +46,7 @@ func (m *MockLLMProvider) GetContextWindow() int {
func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.SetLLMOptions(2048, 0.6)
tool := NewSubagentTool(manager)
@@ -73,7 +72,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
// TestSubagentTool_Name verifies tool name
func TestSubagentTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
if tool.Name() != "subagent" {
@@ -84,7 +83,7 @@ func TestSubagentTool_Name(t *testing.T) {
// TestSubagentTool_Description verifies tool description
func TestSubagentTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
desc := tool.Description()
@@ -99,7 +98,7 @@ func TestSubagentTool_Description(t *testing.T) {
// TestSubagentTool_Parameters verifies tool parameters schema
func TestSubagentTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
params := tool.Parameters()
@@ -149,8 +148,7 @@ func TestSubagentTool_Parameters(t *testing.T) {
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
@@ -204,8 +202,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
// TestSubagentTool_Execute_NoLabel tests execution without label
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
ctx := context.Background()
@@ -228,7 +225,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
ctx := context.Background()
@@ -278,8 +275,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) {
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
channel := "test-channel"
@@ -304,8 +300,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a mock provider that returns very long content
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
ctx := context.Background()