mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main and resolve conflicts in .env.example
This commit is contained in:
+54
-1
@@ -605,7 +605,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
// Second pass: ensure every assistant message with tool_calls has matching
|
||||
// tool result messages following it. This is required by strict providers
|
||||
// like DeepSeek that enforce: "An assistant message with 'tool_calls' must
|
||||
// be followed by tool messages responding to each 'tool_call_id'."
|
||||
final := make([]providers.Message, 0, len(sanitized))
|
||||
for i := 0; i < len(sanitized); i++ {
|
||||
msg := sanitized[i]
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
// Collect expected tool_call IDs
|
||||
expected := make(map[string]bool, len(msg.ToolCalls))
|
||||
for _, tc := range msg.ToolCalls {
|
||||
expected[tc.ID] = false
|
||||
}
|
||||
|
||||
// Check following messages for matching tool results
|
||||
toolMsgCount := 0
|
||||
for j := i + 1; j < len(sanitized); j++ {
|
||||
if sanitized[j].Role != "tool" {
|
||||
break
|
||||
}
|
||||
toolMsgCount++
|
||||
if _, exists := expected[sanitized[j].ToolCallID]; exists {
|
||||
expected[sanitized[j].ToolCallID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// If any tool_call_id is missing, drop this assistant message and its partial tool messages
|
||||
allFound := true
|
||||
for toolCallID, found := range expected {
|
||||
if !found {
|
||||
allFound = false
|
||||
logger.DebugCF(
|
||||
"agent",
|
||||
"Dropping assistant message with incomplete tool results",
|
||||
map[string]any{
|
||||
"missing_tool_call_id": toolCallID,
|
||||
"expected_count": len(expected),
|
||||
"found_count": toolMsgCount,
|
||||
},
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allFound {
|
||||
// Skip this assistant message and its tool messages
|
||||
i += toolMsgCount
|
||||
continue
|
||||
}
|
||||
}
|
||||
final = append(final, msg)
|
||||
}
|
||||
|
||||
return final
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) AddToolResult(
|
||||
|
||||
@@ -207,3 +207,77 @@ func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSanitizeHistoryForProvider_IncompleteToolResults tests the forward validation
|
||||
// that ensures assistant messages with tool_calls have ALL matching tool results.
|
||||
// This fixes the DeepSeek error: "An assistant message with 'tool_calls' must be
|
||||
// followed by tool messages responding to each 'tool_call_id'."
|
||||
func TestSanitizeHistoryForProvider_IncompleteToolResults(t *testing.T) {
|
||||
// Assistant expects tool results for both A and B, but only A is present
|
||||
history := []providers.Message{
|
||||
msg("user", "do two things"),
|
||||
assistantWithTools("A", "B"),
|
||||
toolResult("A"),
|
||||
// toolResult("B") is missing - this would cause DeepSeek to fail
|
||||
msg("user", "next question"),
|
||||
msg("assistant", "answer"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
// The assistant message with incomplete tool results should be dropped,
|
||||
// along with its partial tool result. The remaining messages are:
|
||||
// user ("do two things"), user ("next question"), assistant ("answer")
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "user", "assistant")
|
||||
}
|
||||
|
||||
// TestSanitizeHistoryForProvider_MissingAllToolResults tests the case where
|
||||
// an assistant message has tool_calls but no tool results follow at all.
|
||||
func TestSanitizeHistoryForProvider_MissingAllToolResults(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msg("user", "do something"),
|
||||
assistantWithTools("A"),
|
||||
// No tool results at all
|
||||
msg("user", "hello"),
|
||||
msg("assistant", "hi"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
// The assistant message with no tool results should be dropped.
|
||||
// Remaining: user ("do something"), user ("hello"), assistant ("hi")
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "user", "assistant")
|
||||
}
|
||||
|
||||
// TestSanitizeHistoryForProvider_PartialToolResultsInMiddle tests that
|
||||
// incomplete tool results in the middle of a conversation are properly handled.
|
||||
func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msg("user", "first"),
|
||||
assistantWithTools("A"),
|
||||
toolResult("A"),
|
||||
msg("assistant", "done"),
|
||||
msg("user", "second"),
|
||||
assistantWithTools("B", "C"),
|
||||
toolResult("B"),
|
||||
// toolResult("C") is missing
|
||||
msg("user", "third"),
|
||||
assistantWithTools("D"),
|
||||
toolResult("D"),
|
||||
msg("assistant", "all done"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
// First round is complete (user, assistant+tools, tool, assistant),
|
||||
// second round is incomplete and dropped (assistant+tools, partial tool),
|
||||
// third round is complete (user, assistant+tools, tool, assistant).
|
||||
// Remaining: user, assistant, tool, assistant, user, user, assistant, tool, assistant
|
||||
if len(result) != 9 {
|
||||
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant")
|
||||
}
|
||||
|
||||
@@ -37,6 +37,14 @@ type AgentInstance struct {
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
|
||||
// Router is non-nil when model routing is configured and the light model
|
||||
// was successfully resolved. It scores each incoming message and decides
|
||||
// whether to route to LightCandidates or stay with Candidates.
|
||||
Router *routing.Router
|
||||
// LightCandidates holds the resolved provider candidates for the light model.
|
||||
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
|
||||
LightCandidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -180,6 +188,25 @@ func NewAgentInstance(
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
|
||||
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
} else {
|
||||
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
|
||||
rc.LightModel, agentID)
|
||||
}
|
||||
}
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
@@ -200,6 +227,8 @@ func NewAgentInstance(
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+246
-117
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -46,6 +47,7 @@ type AgentLoop struct {
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -61,7 +63,15 @@ type processOptions struct {
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
const (
|
||||
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
sessionKeyAgentPrefix = "agent:"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
metadataKeyTeamID = "team_id"
|
||||
metadataKeyParentPeerKind = "parent_peer_kind"
|
||||
metadataKeyParentPeerID = "parent_peer_id"
|
||||
)
|
||||
|
||||
func NewAgentLoop(
|
||||
cfg *config.Config,
|
||||
@@ -84,14 +94,17 @@ func NewAgentLoop(
|
||||
stateManager = state.NewManager(defaultAgent.Workspace)
|
||||
}
|
||||
|
||||
return &AgentLoop{
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
}
|
||||
|
||||
return al
|
||||
}
|
||||
|
||||
// registerSharedTools registers tools that are shared across all agents (web, message, spawn).
|
||||
@@ -170,6 +183,17 @@ func registerSharedTools(
|
||||
agent.Tools.Register(messageTool)
|
||||
}
|
||||
|
||||
// Send file tool (outbound media via MediaStore — store injected later by SetMediaStore)
|
||||
if cfg.Tools.IsToolEnabled("send_file") {
|
||||
sendFileTool := tools.NewSendFileTool(
|
||||
agent.Workspace,
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
cfg.Agents.Defaults.GetMaxMediaSize(),
|
||||
nil,
|
||||
)
|
||||
agent.Tools.Register(sendFileTool)
|
||||
}
|
||||
|
||||
// Skill discovery and installation tools
|
||||
skills_enabled := cfg.Tools.IsToolEnabled("skills")
|
||||
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
|
||||
@@ -196,7 +220,7 @@ func registerSharedTools(
|
||||
// Spawn tool with allowlist checker
|
||||
if cfg.Tools.IsToolEnabled("spawn") {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
@@ -371,6 +395,13 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
// SetMediaStore injects a MediaStore for media lifecycle management.
|
||||
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
|
||||
al.mediaStore = s
|
||||
|
||||
// Propagate store to send_file tools in all agents.
|
||||
al.registry.ForEachTool("send_file", func(t tools.Tool) {
|
||||
if sf, ok := t.(*tools.SendFileTool); ok {
|
||||
sf.SetMediaStore(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
|
||||
@@ -549,27 +580,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return al.processSystemMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// Check for commands
|
||||
if response, handled := al.handleCommand(ctx, msg); handled {
|
||||
route, agent, routeErr := al.resolveMessageRoute(msg)
|
||||
|
||||
// Commands are checked before requiring a successful route.
|
||||
// Global commands (/help, /show, /switch) work even when routing fails;
|
||||
// context-dependent commands check their own Runtime fields and report
|
||||
// "unavailable" when the required capability is nil.
|
||||
if response, handled := al.handleCommand(ctx, msg, agent); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Route to determine agent and session key
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
AccountID: msg.Metadata["account_id"],
|
||||
Peer: extractPeer(msg),
|
||||
ParentPeer: extractParentPeer(msg),
|
||||
GuildID: msg.Metadata["guild_id"],
|
||||
TeamID: msg.Metadata["team_id"],
|
||||
})
|
||||
|
||||
agent, ok := al.registry.GetAgent(route.AgentID)
|
||||
if !ok {
|
||||
agent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
|
||||
if routeErr != nil {
|
||||
return "", routeErr
|
||||
}
|
||||
|
||||
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
|
||||
@@ -579,17 +601,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
|
||||
sessionKey := route.SessionKey
|
||||
if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") {
|
||||
sessionKey = msg.SessionKey
|
||||
}
|
||||
// Resolve session key from route, while preserving explicit agent-scoped keys.
|
||||
scopeKey := resolveScopeKey(route, msg.SessionKey)
|
||||
sessionKey := scopeKey
|
||||
|
||||
logger.InfoCF("agent", "Routed message",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"session_key": sessionKey,
|
||||
"matched_by": route.MatchedBy,
|
||||
"agent_id": agent.ID,
|
||||
"scope_key": scopeKey,
|
||||
"session_key": sessionKey,
|
||||
"matched_by": route.MatchedBy,
|
||||
"route_agent": route.AgentID,
|
||||
"route_channel": route.Channel,
|
||||
})
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
@@ -604,6 +627,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
AccountID: inboundMetadata(msg, metadataKeyAccountID),
|
||||
Peer: extractPeer(msg),
|
||||
ParentPeer: extractParentPeer(msg),
|
||||
GuildID: inboundMetadata(msg, metadataKeyGuildID),
|
||||
TeamID: inboundMetadata(msg, metadataKeyTeamID),
|
||||
})
|
||||
|
||||
agent, ok := al.registry.GetAgent(route.AgentID)
|
||||
if !ok {
|
||||
agent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
if agent == nil {
|
||||
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
|
||||
}
|
||||
|
||||
return route, agent, nil
|
||||
}
|
||||
|
||||
func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
|
||||
if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) {
|
||||
return msgSessionKey
|
||||
}
|
||||
return route.SessionKey
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
@@ -675,9 +726,8 @@ func (al *AgentLoop) runAgentLoop(
|
||||
agent *AgentInstance,
|
||||
opts processOptions,
|
||||
) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels and cli)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
@@ -824,6 +874,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
// all tool-follow-up iterations within the same turn so that a multi-step
|
||||
// tool chain doesn't switch models mid-way through.
|
||||
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
|
||||
|
||||
for iteration < agent.MaxIterations {
|
||||
iteration++
|
||||
|
||||
@@ -842,7 +898,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": agent.Model,
|
||||
"model": activeModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": agent.MaxTokens,
|
||||
@@ -858,7 +914,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
// Call LLM with fallback chain if candidates are configured.
|
||||
// Call LLM with fallback chain if multiple candidates are configured.
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
@@ -879,10 +935,10 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
|
||||
},
|
||||
@@ -900,7 +956,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
|
||||
}
|
||||
|
||||
// Retry loop for context/token errors
|
||||
@@ -999,9 +1055,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"target_channel": al.targetReasoningChannelID(opts.Channel),
|
||||
"channel": opts.Channel,
|
||||
})
|
||||
// Check if no tool calls - we're done
|
||||
// Check if no tool calls - then check reasoning content if any
|
||||
if len(response.ToolCalls) == 0 {
|
||||
finalContent = response.Content
|
||||
if finalContent == "" && response.ReasoningContent != "" {
|
||||
finalContent = response.ReasoningContent
|
||||
}
|
||||
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
@@ -1087,15 +1146,47 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Create async callback for tools that implement AsyncExecutor
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
// Create async callback for tools that implement AsyncExecutor.
|
||||
// When the background work completes, this publishes the result
|
||||
// as an inbound system message so processSystemMessage routes it
|
||||
// back to the user via the normal agent loop.
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
// Send ForUser content directly to the user (immediate feedback),
|
||||
// mirroring the synchronous tool execution path.
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(result.ForUser),
|
||||
})
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer outCancel()
|
||||
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: result.ForUser,
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for the agent loop (ForLLM or error).
|
||||
content := result.ForLLM
|
||||
if content == "" && result.Err != nil {
|
||||
content = result.Err.Error()
|
||||
}
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(content),
|
||||
"channel": opts.Channel,
|
||||
})
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", tc.Name),
|
||||
ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
toolResult := agent.Tools.ExecuteWithContext(
|
||||
@@ -1128,7 +1219,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
// If tool returned media refs, publish them as outbound media
|
||||
if len(r.result.Media) > 0 && opts.SendResponse {
|
||||
if len(r.result.Media) > 0 {
|
||||
parts := make([]bus.MediaPart, 0, len(r.result.Media))
|
||||
for _, ref := range r.result.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
@@ -1169,6 +1260,44 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return finalContent, iteration, nil
|
||||
}
|
||||
|
||||
// selectCandidates returns the model candidates and resolved model name to use
|
||||
// for a conversation turn. When model routing is configured and the incoming
|
||||
// message scores below the complexity threshold, it returns the light model
|
||||
// candidates instead of the primary ones.
|
||||
//
|
||||
// The returned (candidates, model) pair is used for all LLM calls within one
|
||||
// turn — tool follow-up iterations use the same tier as the initial call so
|
||||
// that a multi-step tool chain doesn't switch models mid-way.
|
||||
func (al *AgentLoop) selectCandidates(
|
||||
agent *AgentInstance,
|
||||
userMsg string,
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
if !usedLight {
|
||||
logger.DebugCF("agent", "Model routing: primary model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"light_model": agent.Router.LightModel(),
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, agent.Router.LightModel()
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
@@ -1460,94 +1589,87 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if !strings.HasPrefix(content, "/") {
|
||||
func (al *AgentLoop) handleCommand(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
agent *AgentInstance,
|
||||
) (string, bool) {
|
||||
if !commands.HasCommandPrefix(msg.Content) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.Fields(content)
|
||||
if len(parts) == 0 {
|
||||
if al.cmdRegistry == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cmd := parts[0]
|
||||
args := parts[1:]
|
||||
rt := al.buildCommandsRuntime(agent)
|
||||
executor := commands.NewExecutor(al.cmdRegistry, rt)
|
||||
|
||||
switch cmd {
|
||||
case "/show":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /show [model|channel|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
|
||||
case "channel":
|
||||
return fmt.Sprintf("Current channel: %s", msg.Channel), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown show target: %s", args[0]), true
|
||||
}
|
||||
var commandReply string
|
||||
result := executor.Execute(ctx, commands.Request{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
SenderID: msg.SenderID,
|
||||
Text: msg.Content,
|
||||
Reply: func(text string) error {
|
||||
commandReply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
case "/list":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /list [models|channels|agents]", true
|
||||
switch result.Outcome {
|
||||
case commands.OutcomeHandled:
|
||||
if result.Err != nil {
|
||||
return mapCommandError(result), true
|
||||
}
|
||||
switch args[0] {
|
||||
case "models":
|
||||
return "Available models: configured in config.json per agent", true
|
||||
case "channels":
|
||||
if commandReply != "" {
|
||||
return commandReply, true
|
||||
}
|
||||
return "", true
|
||||
default: // OutcomePassthrough — let the message fall through to LLM
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime {
|
||||
rt := &commands.Runtime{
|
||||
Config: al.cfg,
|
||||
ListAgentIDs: al.registry.ListAgentIDs,
|
||||
ListDefinitions: al.cmdRegistry.Definitions,
|
||||
GetEnabledChannels: func() []string {
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
return nil
|
||||
}
|
||||
channels := al.channelManager.GetEnabledChannels()
|
||||
if len(channels) == 0 {
|
||||
return "No channels enabled", true
|
||||
}
|
||||
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown list target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/switch":
|
||||
if len(args) < 3 || args[1] != "to" {
|
||||
return "Usage: /switch [model|channel] to <name>", true
|
||||
}
|
||||
target := args[0]
|
||||
value := args[2]
|
||||
|
||||
switch target {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
oldModel := defaultAgent.Model
|
||||
defaultAgent.Model = value
|
||||
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
|
||||
case "channel":
|
||||
return al.channelManager.GetEnabledChannels()
|
||||
},
|
||||
SwitchChannel: func(value string) error {
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
return fmt.Errorf("channel manager not initialized")
|
||||
}
|
||||
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
|
||||
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
|
||||
return fmt.Errorf("channel '%s' not found or not enabled", value)
|
||||
}
|
||||
return fmt.Sprintf("Switched target channel to %s", value), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown switch target: %s", target), true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if agent != nil {
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, al.cfg.Agents.Defaults.Provider
|
||||
}
|
||||
rt.SwitchModel = func(value string) (string, error) {
|
||||
oldModel := agent.Model
|
||||
agent.Model = value
|
||||
return oldModel, nil
|
||||
}
|
||||
}
|
||||
return rt
|
||||
}
|
||||
|
||||
return "", false
|
||||
func mapCommandError(result commands.ExecuteResult) string {
|
||||
if result.Command == "" {
|
||||
return fmt.Sprintf("Failed to execute command: %v", result.Err)
|
||||
}
|
||||
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
|
||||
}
|
||||
|
||||
// extractPeer extracts the routing peer from the inbound message's structured Peer field.
|
||||
@@ -1566,10 +1688,17 @@ func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID}
|
||||
}
|
||||
|
||||
func inboundMetadata(msg bus.InboundMessage, key string) string {
|
||||
if msg.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
return msg.Metadata[key]
|
||||
}
|
||||
|
||||
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
|
||||
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
parentKind := msg.Metadata["parent_peer_kind"]
|
||||
parentID := msg.Metadata["parent_peer_id"]
|
||||
parentKind := inboundMetadata(msg, metadataKeyParentPeerKind)
|
||||
parentID := inboundMetadata(msg, metadataKeyParentPeerID)
|
||||
if parentKind == "" || parentID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -318,6 +319,29 @@ func (m *simpleMockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type countingMockProvider struct {
|
||||
response string
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *countingMockProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
return &providers.LLMResponse{
|
||||
Content: m.response,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *countingMockProvider) GetDefaultModel() string {
|
||||
return "counting-mock-model"
|
||||
}
|
||||
|
||||
// mockCustomTool is a simple mock tool for registration testing
|
||||
type mockCustomTool struct{}
|
||||
|
||||
@@ -359,6 +383,198 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms
|
||||
|
||||
const responseTimeout = 3 * time.Second
|
||||
|
||||
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "ok"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
Peer: extractPeer(msg),
|
||||
})
|
||||
sessionKey := route.SessionKey
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
|
||||
helper := testHelper{al: al}
|
||||
_ = helper.executeAndGetResponse(t, context.Background(), msg)
|
||||
|
||||
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected session history len=2, got %d", len(history))
|
||||
}
|
||||
if history[0].Role != "user" || history[0].Content != "hello" {
|
||||
t.Fatalf("unexpected first message in session: %+v", history[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-channel-peer",
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
baseMsg := bus.InboundMessage{
|
||||
Channel: "whatsapp",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/show channel",
|
||||
Peer: baseMsg.Peer,
|
||||
})
|
||||
if showResp != "Current Channel: whatsapp" {
|
||||
t.Fatalf("unexpected /show reply: %q", showResp)
|
||||
}
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls)
|
||||
}
|
||||
|
||||
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/foo",
|
||||
Peer: baseMsg.Peer,
|
||||
})
|
||||
if fooResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /foo reply: %q", fooResp)
|
||||
}
|
||||
if provider.calls != 1 {
|
||||
t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls)
|
||||
}
|
||||
|
||||
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/new",
|
||||
Peer: baseMsg.Peer,
|
||||
})
|
||||
if newResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /new reply: %q", newResp)
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "before-switch",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to after-switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
}
|
||||
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// AgentRegistry manages multiple agent instances and routes messages to them.
|
||||
@@ -100,6 +101,19 @@ func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bo
|
||||
return false
|
||||
}
|
||||
|
||||
// ForEachTool calls fn for every tool registered under the given name
|
||||
// across all agents. This is useful for propagating dependencies (e.g.
|
||||
// MediaStore) to tools after registry construction.
|
||||
func (r *AgentRegistry) ForEachTool(name string, fn func(tools.Tool)) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
for _, agent := range r.agents {
|
||||
if t, ok := agent.Tools.Get(name); ok {
|
||||
fn(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetDefaultAgent returns the default agent instance.
|
||||
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
|
||||
r.mu.RLock()
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
anthropicBetaHeader = "oauth-2025-04-20"
|
||||
anthropicAPIVersion = "2023-06-01"
|
||||
)
|
||||
|
||||
// anthropicUsageURL is the endpoint for fetching OAuth usage stats.
|
||||
// It is a var (not const) to allow overriding in tests.
|
||||
var anthropicUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
func setAnthropicUsageURL(url string) { anthropicUsageURL = url }
|
||||
|
||||
type AnthropicUsage struct {
|
||||
FiveHourUtilization float64
|
||||
SevenDayUtilization float64
|
||||
}
|
||||
|
||||
func FetchAnthropicUsage(token string) (*AnthropicUsage, error) {
|
||||
req, err := http.NewRequest("GET", anthropicUsageURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Anthropic-Version", anthropicAPIVersion)
|
||||
req.Header.Set("Anthropic-Beta", anthropicBetaHeader)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading usage response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("insufficient scope: usage endpoint requires oauth scope")
|
||||
}
|
||||
return nil, fmt.Errorf("usage request failed (%d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
FiveHour struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
} `json:"five_hour"`
|
||||
SevenDay struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
} `json:"seven_day"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parsing usage response: %w", err)
|
||||
}
|
||||
|
||||
return &AnthropicUsage{
|
||||
FiveHourUtilization: result.FiveHour.Utilization,
|
||||
SevenDayUtilization: result.SevenDay.Utilization,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFetchAnthropicUsage_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer test-token")
|
||||
}
|
||||
if got := r.Header.Get("Anthropic-Beta"); got != anthropicBetaHeader {
|
||||
t.Errorf("Anthropic-Beta = %q, want %q", got, anthropicBetaHeader)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"five_hour":{"utilization":0.42},"seven_day":{"utilization":0.85}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Temporarily override the URL by using the test server
|
||||
origURL := anthropicUsageURL
|
||||
defer func() { setAnthropicUsageURL(origURL) }()
|
||||
setAnthropicUsageURL(srv.URL)
|
||||
|
||||
usage, err := FetchAnthropicUsage("test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if usage.FiveHourUtilization != 0.42 {
|
||||
t.Errorf("FiveHourUtilization = %v, want 0.42", usage.FiveHourUtilization)
|
||||
}
|
||||
if usage.SevenDayUtilization != 0.85 {
|
||||
t.Errorf("SevenDayUtilization = %v, want 0.85", usage.SevenDayUtilization)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAnthropicUsage_Forbidden(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte(`{"error":"forbidden"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
origURL := anthropicUsageURL
|
||||
defer func() { setAnthropicUsageURL(origURL) }()
|
||||
setAnthropicUsageURL(srv.URL)
|
||||
|
||||
_, err := FetchAnthropicUsage("test-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "insufficient scope") {
|
||||
t.Errorf("expected 'insufficient scope' error, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAnthropicUsage_ServerError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(`internal error`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
origURL := anthropicUsageURL
|
||||
defer func() { setAnthropicUsageURL(origURL) }()
|
||||
setAnthropicUsageURL(srv.URL)
|
||||
|
||||
_, err := FetchAnthropicUsage("test-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 500, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected error containing '500', got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAnthropicUsage_MalformedJSON(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`not json`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
origURL := anthropicUsageURL
|
||||
defer func() { setAnthropicUsageURL(origURL) }()
|
||||
setAnthropicUsageURL(srv.URL)
|
||||
|
||||
_, err := FetchAnthropicUsage("test-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for malformed JSON, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "parsing usage response") {
|
||||
t.Errorf("expected 'parsing usage response' error, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
@@ -31,6 +31,35 @@ func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func LoginSetupToken(r io.Reader) (*AuthCredential, error) {
|
||||
fmt.Println("Paste your setup token from `claude setup-token`:")
|
||||
fmt.Print("> ")
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("reading token: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("no input received")
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if !strings.HasPrefix(token, "sk-ant-oat01-") {
|
||||
return nil, fmt.Errorf("invalid setup token: expected prefix sk-ant-oat01-")
|
||||
}
|
||||
|
||||
if len(token) < 80 {
|
||||
return nil, fmt.Errorf("invalid setup token: too short (expected at least 80 characters)")
|
||||
}
|
||||
|
||||
return &AuthCredential{
|
||||
AccessToken: token,
|
||||
Provider: "anthropic",
|
||||
AuthMethod: "oauth",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func providerDisplayName(provider string) string {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoginSetupToken(t *testing.T) {
|
||||
// A valid token: correct prefix + at least 80 chars
|
||||
validToken := "sk-ant-oat01-" + strings.Repeat("a", 80)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr string
|
||||
}{
|
||||
{"valid token", validToken, ""},
|
||||
{"empty input", "", "expected prefix sk-ant-oat01-"},
|
||||
{"wrong prefix", "sk-ant-api-" + strings.Repeat("a", 80), "expected prefix sk-ant-oat01-"},
|
||||
{"too short", "sk-ant-oat01-short", "too short"},
|
||||
{"whitespace only", " ", "expected prefix sk-ant-oat01-"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := strings.NewReader(tt.input + "\n")
|
||||
cred, err := LoginSetupToken(r)
|
||||
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cred.AccessToken != validToken {
|
||||
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, validToken)
|
||||
}
|
||||
if cred.Provider != "anthropic" {
|
||||
t.Errorf("Provider = %q, want %q", cred.Provider, "anthropic")
|
||||
}
|
||||
if cred.AuthMethod != "oauth" {
|
||||
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginSetupToken_EmptyReader(t *testing.T) {
|
||||
r := strings.NewReader("")
|
||||
_, err := LoginSetupToken(r)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty reader, got nil")
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -26,6 +27,12 @@ const (
|
||||
sendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call)
|
||||
channelRefRe = regexp.MustCompile(`<#(\d+)>`)
|
||||
msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`)
|
||||
)
|
||||
|
||||
type DiscordChannel struct {
|
||||
*channels.BaseChannel
|
||||
session *discordgo.Session
|
||||
@@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
content = c.stripBotMention(content)
|
||||
}
|
||||
|
||||
// Resolve Discord refs in main content before concatenation to avoid
|
||||
// double-expanding links that appear in the referenced message.
|
||||
content = c.resolveDiscordRefs(s, content, m.GuildID)
|
||||
|
||||
// Prepend referenced (quoted) message content if this is a reply
|
||||
if m.MessageReference != nil && m.ReferencedMessage != nil {
|
||||
refContent := m.ReferencedMessage.Content
|
||||
if refContent != "" {
|
||||
refAuthor := "unknown"
|
||||
if m.ReferencedMessage.Author != nil {
|
||||
refAuthor = m.ReferencedMessage.Author.Username
|
||||
}
|
||||
refContent = c.resolveDiscordRefs(s, refContent, m.GuildID)
|
||||
content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s",
|
||||
refAuthor, refContent, content)
|
||||
}
|
||||
}
|
||||
|
||||
senderID := m.Author.ID
|
||||
|
||||
mediaPaths := make([]string, 0, len(m.Attachments))
|
||||
@@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and
|
||||
// expands Discord message links to show the linked message content.
|
||||
// Only links pointing to the same guild are expanded to prevent cross-guild leakage.
|
||||
func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string {
|
||||
// 1. Resolve channel references: <#id> → #channel-name
|
||||
text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string {
|
||||
parts := channelRefRe.FindStringSubmatch(match)
|
||||
if len(parts) < 2 {
|
||||
return match
|
||||
}
|
||||
// Prefer session state cache to avoid API calls
|
||||
if ch, err := s.State.Channel(parts[1]); err == nil {
|
||||
return "#" + ch.Name
|
||||
}
|
||||
if ch, err := s.Channel(parts[1]); err == nil {
|
||||
return "#" + ch.Name
|
||||
}
|
||||
return match
|
||||
})
|
||||
|
||||
// 2. Expand Discord message links (max 3, same guild only)
|
||||
matches := msgLinkRe.FindAllStringSubmatch(text, 3)
|
||||
for _, m := range matches {
|
||||
if len(m) < 4 {
|
||||
continue
|
||||
}
|
||||
linkGuildID, channelID, messageID := m[1], m[2], m[3]
|
||||
// Security: only expand links from the same guild
|
||||
if linkGuildID != guildID {
|
||||
continue
|
||||
}
|
||||
msg, err := s.ChannelMessage(channelID, messageID)
|
||||
if err != nil || msg == nil || msg.Content == "" {
|
||||
continue
|
||||
}
|
||||
author := "unknown"
|
||||
if msg.Author != nil {
|
||||
author = msg.Author.Username
|
||||
}
|
||||
text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content)
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// stripBotMention removes the bot mention from the message content.
|
||||
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
|
||||
func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChannelRefRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantID string
|
||||
wantOK bool
|
||||
}{
|
||||
{"basic channel ref", "<#123456789>", "123456789", true},
|
||||
{"long id", "<#9876543210123456>", "9876543210123456", true},
|
||||
{"no match plain text", "hello world", "", false},
|
||||
{"no match partial", "<#>", "", false},
|
||||
{"no match letters", "<#abc>", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := channelRefRe.FindStringSubmatch(tt.input)
|
||||
if tt.wantOK {
|
||||
if len(matches) < 2 || matches[1] != tt.wantID {
|
||||
t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID)
|
||||
}
|
||||
} else {
|
||||
if len(matches) >= 2 {
|
||||
t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgLinkRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantGuild string
|
||||
wantChan string
|
||||
wantMsg string
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
"discord.com link",
|
||||
"https://discord.com/channels/111/222/333",
|
||||
"111", "222", "333", true,
|
||||
},
|
||||
{
|
||||
"discordapp.com link",
|
||||
"https://discordapp.com/channels/111/222/333",
|
||||
"111", "222", "333", true,
|
||||
},
|
||||
{
|
||||
"real world ids",
|
||||
"check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please",
|
||||
"9000000000000001", "9000000000000002", "9000000000000003", true,
|
||||
},
|
||||
{"no match http", "http://discord.com/channels/1/2/3", "", "", "", false},
|
||||
{"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false},
|
||||
{"no match plain text", "hello world", "", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := msgLinkRe.FindStringSubmatch(tt.input)
|
||||
if tt.wantOK {
|
||||
if len(matches) < 4 {
|
||||
t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s",
|
||||
tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg)
|
||||
}
|
||||
if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg {
|
||||
t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s",
|
||||
tt.input, matches[1], matches[2], matches[3],
|
||||
tt.wantGuild, tt.wantChan, tt.wantMsg)
|
||||
}
|
||||
} else {
|
||||
if len(matches) >= 4 {
|
||||
t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgLinkRegex_MultipleMatches(t *testing.T) {
|
||||
input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12"
|
||||
matches := msgLinkRe.FindAllStringSubmatch(input, 3)
|
||||
if len(matches) != 3 {
|
||||
t.Fatalf("expected 3 matches (capped), got %d", len(matches))
|
||||
}
|
||||
// Verify the 3rd match is 7/8/9 (not 10/11/12)
|
||||
if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" {
|
||||
t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2])
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
package channels
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
// TypingCapable — channels that can show a typing/thinking indicator.
|
||||
// StartTyping begins the indicator and returns a stop function.
|
||||
@@ -39,3 +43,10 @@ type PlaceholderRecorder interface {
|
||||
RecordTypingStop(channel, chatID string, stop func())
|
||||
RecordReactionUndo(channel, chatID string, undo func())
|
||||
}
|
||||
|
||||
// CommandRegistrarCapable is implemented by channels that can register
|
||||
// command menus with their upstream platform (e.g. Telegram BotCommand).
|
||||
// Channels that do not support platform-level command menus can ignore it.
|
||||
type CommandRegistrarCapable interface {
|
||||
RegisterCommands(ctx context.Context, defs []commands.Definition) error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
type mockRegistrar struct{}
|
||||
|
||||
func (mockRegistrar) RegisterCommands(context.Context, []commands.Definition) error { return nil }
|
||||
|
||||
func TestCommandRegistrarCapable_Compiles(t *testing.T) {
|
||||
var _ CommandRegistrarCapable = mockRegistrar{}
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package irc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/ergochat/irc-go/ircevent"
|
||||
"github.com/ergochat/irc-go/ircmsg"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// onConnect is called after a successful connection (and on reconnect).
|
||||
func (c *IRCChannel) onConnect(conn *ircevent.Connection) {
|
||||
// NickServ auth (only if SASL is not configured)
|
||||
if c.config.NickServPassword != "" && c.config.SASLUser == "" {
|
||||
conn.Privmsg("NickServ", "IDENTIFY "+c.config.NickServPassword)
|
||||
}
|
||||
|
||||
// Join configured channels
|
||||
for _, ch := range c.config.Channels {
|
||||
conn.Join(ch)
|
||||
logger.InfoCF("irc", "Joined IRC channel", map[string]any{
|
||||
"channel": ch,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// onPrivmsg handles incoming PRIVMSG events.
|
||||
func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
if len(e.Params) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
nick := e.Nick()
|
||||
currentNick := conn.CurrentNick()
|
||||
|
||||
// Ignore own messages
|
||||
if strings.EqualFold(nick, currentNick) {
|
||||
return
|
||||
}
|
||||
|
||||
target := e.Params[0] // channel name or bot's nick
|
||||
content := e.Params[1] // message text
|
||||
|
||||
// Determine if this is a DM or channel message
|
||||
isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&")
|
||||
|
||||
var chatID string
|
||||
var peer bus.Peer
|
||||
|
||||
if isDM {
|
||||
chatID = nick
|
||||
peer = bus.Peer{Kind: "direct", ID: nick}
|
||||
} else {
|
||||
chatID = target
|
||||
peer = bus.Peer{Kind: "group", ID: target}
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "irc",
|
||||
PlatformID: nick,
|
||||
CanonicalID: identity.BuildCanonicalID("irc", nick),
|
||||
Username: nick,
|
||||
DisplayName: nick,
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(sender) {
|
||||
return
|
||||
}
|
||||
|
||||
// For channel messages, check group trigger (mention detection)
|
||||
if !isDM {
|
||||
isMentioned := isBotMentioned(content, currentNick)
|
||||
if isMentioned {
|
||||
content = stripBotMention(content, currentNick)
|
||||
}
|
||||
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
|
||||
if !respond {
|
||||
return
|
||||
}
|
||||
content = cleaned
|
||||
}
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("%s-%d", nick, time.Now().UnixNano())
|
||||
|
||||
metadata := map[string]string{
|
||||
"platform": "irc",
|
||||
"server": c.config.Server,
|
||||
}
|
||||
if !isDM {
|
||||
metadata["channel"] = target
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender)
|
||||
}
|
||||
|
||||
// nickMentionedAt returns the byte index where botNick is mentioned in content
|
||||
// with word-boundary checks, or -1 if not found. Also checks for "nick:" /
|
||||
// "nick," prefix convention.
|
||||
func nickMentionedAt(content, botNick string) int {
|
||||
lower := strings.ToLower(content)
|
||||
lowerNick := strings.ToLower(botNick)
|
||||
|
||||
// "nick:" or "nick," at start (most common IRC convention)
|
||||
if strings.HasPrefix(lower, lowerNick+":") || strings.HasPrefix(lower, lowerNick+",") {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Word-boundary match anywhere in the message
|
||||
idx := strings.Index(lower, lowerNick)
|
||||
if idx < 0 {
|
||||
return -1
|
||||
}
|
||||
runes := []rune(lower)
|
||||
nickRunes := []rune(lowerNick)
|
||||
endIdx := idx + len(string(nickRunes))
|
||||
before := idx == 0 || !unicode.IsLetter(runes[idx-1]) && !unicode.IsDigit(runes[idx-1])
|
||||
after := endIdx >= len(lower) || !unicode.IsLetter(rune(lower[endIdx])) && !unicode.IsDigit(rune(lower[endIdx]))
|
||||
if before && after {
|
||||
return idx
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot's nick appears in the message.
|
||||
func isBotMentioned(content, botNick string) bool {
|
||||
return nickMentionedAt(content, botNick) >= 0
|
||||
}
|
||||
|
||||
// stripBotMention removes "nick: " or "nick, " prefix from content.
|
||||
func stripBotMention(content, botNick string) string {
|
||||
idx := nickMentionedAt(content, botNick)
|
||||
if idx != 0 {
|
||||
return content
|
||||
}
|
||||
lowerNick := strings.ToLower(botNick)
|
||||
lower := strings.ToLower(content)
|
||||
for _, sep := range []string{":", ","} {
|
||||
prefix := lowerNick + sep
|
||||
if strings.HasPrefix(lower, prefix) {
|
||||
return strings.TrimSpace(content[len(prefix):])
|
||||
}
|
||||
}
|
||||
return content
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package irc
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("irc", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
if !cfg.Channels.IRC.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
return NewIRCChannel(cfg.Channels.IRC, b)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,194 @@
|
||||
package irc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ergochat/irc-go/ircevent"
|
||||
"github.com/ergochat/irc-go/ircmsg"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// IRCChannel implements the Channel interface for IRC servers.
|
||||
type IRCChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.IRCConfig
|
||||
conn *ircevent.Connection
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewIRCChannel creates a new IRC channel.
|
||||
func NewIRCChannel(cfg config.IRCConfig, messageBus *bus.MessageBus) (*IRCChannel, error) {
|
||||
if cfg.Server == "" {
|
||||
return nil, fmt.Errorf("irc server is required")
|
||||
}
|
||||
if cfg.Nick == "" {
|
||||
return nil, fmt.Errorf("irc nick is required")
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel("irc", cfg, messageBus, cfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(400),
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &IRCChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start connects to the IRC server and begins listening.
|
||||
func (c *IRCChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("irc", "Starting IRC channel")
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
user := c.config.User
|
||||
if user == "" {
|
||||
user = c.config.Nick
|
||||
}
|
||||
realName := c.config.RealName
|
||||
if realName == "" {
|
||||
realName = c.config.Nick
|
||||
}
|
||||
caps := []string(c.config.RequestCaps)
|
||||
if len(caps) == 0 {
|
||||
caps = []string{"server-time", "message-tags"}
|
||||
}
|
||||
|
||||
conn := &ircevent.Connection{
|
||||
Server: c.config.Server,
|
||||
Nick: c.config.Nick,
|
||||
User: user,
|
||||
RealName: realName,
|
||||
Password: c.config.Password,
|
||||
UseTLS: c.config.TLS,
|
||||
RequestCaps: caps,
|
||||
QuitMessage: "Goodbye",
|
||||
Debug: false,
|
||||
Log: nil,
|
||||
}
|
||||
|
||||
if c.config.TLS {
|
||||
conn.TLSConfig = &tls.Config{
|
||||
ServerName: extractHost(c.config.Server),
|
||||
}
|
||||
}
|
||||
|
||||
// SASL auth (takes priority over NickServ)
|
||||
if c.config.SASLUser != "" && c.config.SASLPassword != "" {
|
||||
conn.SASLLogin = c.config.SASLUser
|
||||
conn.SASLPassword = c.config.SASLPassword
|
||||
}
|
||||
|
||||
// Register event handlers
|
||||
conn.AddConnectCallback(func(e ircmsg.Message) {
|
||||
c.onConnect(conn)
|
||||
})
|
||||
conn.AddCallback("PRIVMSG", func(e ircmsg.Message) {
|
||||
c.onPrivmsg(conn, e)
|
||||
})
|
||||
|
||||
if err := conn.Connect(); err != nil {
|
||||
return fmt.Errorf("irc connect failed: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
|
||||
// ircevent.Connection.Loop() handles reconnection internally.
|
||||
go conn.Loop()
|
||||
|
||||
c.SetRunning(true)
|
||||
logger.InfoCF("irc", "IRC channel started", map[string]any{
|
||||
"server": c.config.Server,
|
||||
"nick": c.config.Nick,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop disconnects from the IRC server.
|
||||
func (c *IRCChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("irc", "Stopping IRC channel")
|
||||
c.SetRunning(false)
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Quit()
|
||||
}
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
logger.InfoC("irc", "IRC channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message to an IRC channel or user.
|
||||
func (c *IRCChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
target := msg.ChatID
|
||||
if target == "" {
|
||||
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(msg.Content) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send each line separately (IRC is line-oriented)
|
||||
lines := strings.Split(msg.Content, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
c.conn.Privmsg(target, line)
|
||||
}
|
||||
|
||||
logger.DebugCF("irc", "Message sent", map[string]any{
|
||||
"target": target,
|
||||
"lines": len(lines),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartTyping implements channels.TypingCapable using IRCv3 +typing client tag.
|
||||
// Requires typing.enabled in config and server support for message-tags capability.
|
||||
func (c *IRCChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
|
||||
noop := func() {}
|
||||
|
||||
if !c.config.Typing.Enabled || !c.IsRunning() || c.conn == nil {
|
||||
return noop, nil
|
||||
}
|
||||
|
||||
// Check if server supports message-tags (required for TAGMSG)
|
||||
if _, ok := c.conn.AcknowledgedCaps()["message-tags"]; !ok {
|
||||
return noop, nil
|
||||
}
|
||||
|
||||
c.conn.SendWithTags(map[string]string{"+typing": "active"}, "TAGMSG", chatID)
|
||||
|
||||
return func() {
|
||||
if c.IsRunning() && c.conn != nil {
|
||||
c.conn.SendWithTags(map[string]string{"+typing": "done"}, "TAGMSG", chatID)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractHost returns the hostname portion of a host:port string.
|
||||
func extractHost(server string) string {
|
||||
host, _, found := strings.Cut(server, ":")
|
||||
if found {
|
||||
return host
|
||||
}
|
||||
return server
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package irc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestNewIRCChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing server", func(t *testing.T) {
|
||||
cfg := config.IRCConfig{Nick: "bot"}
|
||||
_, err := NewIRCChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing server, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing nick", func(t *testing.T) {
|
||||
cfg := config.IRCConfig{Server: "irc.example.com:6667"}
|
||||
_, err := NewIRCChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing nick, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.IRCConfig{
|
||||
Server: "irc.example.com:6667",
|
||||
Nick: "testbot",
|
||||
Channels: []string{"#test"},
|
||||
}
|
||||
ch, err := NewIRCChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "irc" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "irc")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
server string
|
||||
want string
|
||||
}{
|
||||
{"irc.libera.chat:6697", "irc.libera.chat"},
|
||||
{"localhost:6667", "localhost"},
|
||||
{"irc.example.com", "irc.example.com"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.server, func(t *testing.T) {
|
||||
got := extractHost(tt.server)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractHost(%q) = %q, want %q", tt.server, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNickMentionedAt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
nick string
|
||||
want int
|
||||
}{
|
||||
{"colon prefix", "bot: hello", "bot", 0},
|
||||
{"comma prefix", "bot, hello", "bot", 0},
|
||||
{"case insensitive", "BOT: hello", "bot", 0},
|
||||
{"word boundary mid", "hey bot what's up", "bot", 4},
|
||||
{"no mention", "hello world", "bot", -1},
|
||||
{"substring mismatch", "robotics are cool", "bot", -1},
|
||||
{"nick at end", "hello bot", "bot", 6},
|
||||
{"empty content", "", "bot", -1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := nickMentionedAt(tt.content, tt.nick)
|
||||
if got != tt.want {
|
||||
t.Errorf("nickMentionedAt(%q, %q) = %d, want %d", tt.content, tt.nick, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBotMentioned(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
nick string
|
||||
want bool
|
||||
}{
|
||||
{"colon prefix", "bot: hello", "bot", true},
|
||||
{"comma prefix", "bot, hello", "bot", true},
|
||||
{"case insensitive", "BOT: hello", "bot", true},
|
||||
{"word boundary mid", "hey bot what's up", "bot", true},
|
||||
{"no mention", "hello world", "bot", false},
|
||||
{"substring mismatch", "robotics are cool", "bot", false},
|
||||
{"nick at end", "hello bot", "bot", true},
|
||||
{"empty content", "", "bot", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isBotMentioned(tt.content, tt.nick)
|
||||
if got != tt.want {
|
||||
t.Errorf("isBotMentioned(%q, %q) = %v, want %v", tt.content, tt.nick, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBotMention(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
nick string
|
||||
want string
|
||||
}{
|
||||
{"colon prefix", "bot: hello there", "bot", "hello there"},
|
||||
{"comma prefix", "bot, help me", "bot", "help me"},
|
||||
{"case insensitive", "BOT: hello", "bot", "hello"},
|
||||
{"no prefix match", "hello bot", "bot", "hello bot"},
|
||||
{"only prefix", "bot:", "bot", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := stripBotMention(tt.content, tt.nick)
|
||||
if got != tt.want {
|
||||
t.Errorf("stripBotMention(%q, %q) = %q, want %q", tt.content, tt.nick, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,7 @@ var channelRateConfig = map[string]float64{
|
||||
"discord": 1,
|
||||
"slack": 1,
|
||||
"line": 10,
|
||||
"irc": 2,
|
||||
}
|
||||
|
||||
type channelWorker struct {
|
||||
@@ -267,6 +268,10 @@ func (m *Manager) initChannels() error {
|
||||
m.initChannel("pico", "Pico")
|
||||
}
|
||||
|
||||
if m.config.Channels.IRC.Enabled && m.config.Channels.IRC.Server != "" {
|
||||
m.initChannel("irc", "IRC")
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
|
||||
"enabled_channels": len(m.channels),
|
||||
})
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
var commandRegistrationBackoff = []time.Duration{
|
||||
5 * time.Second,
|
||||
15 * time.Second,
|
||||
60 * time.Second,
|
||||
5 * time.Minute,
|
||||
10 * time.Minute,
|
||||
}
|
||||
|
||||
func commandRegistrationDelay(attempt int) time.Duration {
|
||||
if len(commandRegistrationBackoff) == 0 {
|
||||
return 0
|
||||
}
|
||||
base := commandRegistrationBackoff[min(attempt, len(commandRegistrationBackoff)-1)]
|
||||
// Full jitter in [0.5, 1.0) to avoid synchronized retries across instances.
|
||||
return time.Duration(float64(base) * (0.5 + rand.Float64()*0.5))
|
||||
}
|
||||
|
||||
// RegisterCommands registers bot commands on Telegram platform.
|
||||
func (c *TelegramChannel) RegisterCommands(ctx context.Context, defs []commands.Definition) error {
|
||||
botCommands := make([]telego.BotCommand, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
if def.Name == "" || def.Description == "" {
|
||||
continue
|
||||
}
|
||||
botCommands = append(botCommands, telego.BotCommand{
|
||||
Command: def.Name,
|
||||
Description: def.Description,
|
||||
})
|
||||
}
|
||||
|
||||
current, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{})
|
||||
if err != nil {
|
||||
// If we can't read current commands, fall through to set them.
|
||||
logger.WarnCF("telegram", "Failed to get current commands, will set unconditionally",
|
||||
map[string]any{"error": err.Error()})
|
||||
} else if slices.Equal(current, botCommands) {
|
||||
logger.DebugCF("telegram", "Bot commands are up to date", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
|
||||
Commands: botCommands,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []commands.Definition) {
|
||||
if len(defs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
register := c.registerFunc
|
||||
if register == nil {
|
||||
register = c.RegisterCommands
|
||||
}
|
||||
|
||||
regCtx, cancel := context.WithCancel(ctx)
|
||||
c.commandRegCancel = cancel
|
||||
|
||||
// Registration runs asynchronously so Telegram message intake is never blocked
|
||||
// by temporary upstream API failures. Retry stops on success or channel shutdown.
|
||||
go func() {
|
||||
attempt := 0
|
||||
timer := time.NewTimer(0)
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
defer timer.Stop()
|
||||
for {
|
||||
err := register(regCtx, defs)
|
||||
if err == nil {
|
||||
logger.InfoCF("telegram", "Telegram commands registered", map[string]any{
|
||||
"count": len(defs),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
delay := commandRegistrationDelay(attempt)
|
||||
logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry_after": delay.String(),
|
||||
})
|
||||
attempt++
|
||||
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(delay)
|
||||
|
||||
select {
|
||||
case <-regCtx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
func TestStartCommandRegistration_DoesNotBlock(t *testing.T) {
|
||||
ch := &TelegramChannel{}
|
||||
started := make(chan struct{}, 1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ch.registerFunc = func(context.Context, []commands.Definition) error {
|
||||
started <- struct{}{}
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
|
||||
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help"}})
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("registration did not start asynchronously")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) {
|
||||
ch := &TelegramChannel{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
origBackoff := commandRegistrationBackoff
|
||||
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
|
||||
defer func() { commandRegistrationBackoff = origBackoff }()
|
||||
|
||||
var attempts atomic.Int32
|
||||
ch.registerFunc = func(context.Context, []commands.Definition) error {
|
||||
n := attempts.Add(1)
|
||||
if n < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
|
||||
|
||||
deadline := time.Now().Add(250 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if attempts.Load() >= 3 {
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
if attempts.Load() < 3 {
|
||||
t.Fatalf("expected at least 3 attempts, got %d", attempts.Load())
|
||||
}
|
||||
|
||||
stable := attempts.Load()
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
if attempts.Load() != stable {
|
||||
t.Fatalf("expected retries to stop after success, got %d -> %d", stable, attempts.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) {
|
||||
ch := &TelegramChannel{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
origBackoff := commandRegistrationBackoff
|
||||
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
|
||||
defer func() { commandRegistrationBackoff = origBackoff }()
|
||||
defer cancel()
|
||||
|
||||
var attempts atomic.Int32
|
||||
ch.registerFunc = func(context.Context, []commands.Definition) error {
|
||||
attempts.Add(1)
|
||||
return errors.New("always fail")
|
||||
}
|
||||
|
||||
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
time.Sleep(20 * time.Millisecond) // allow in-flight attempt to settle
|
||||
stable := attempts.Load()
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
if attempts.Load() != stable {
|
||||
t.Fatalf("expected retries to quiesce after cancel, got %d -> %d", stable, attempts.Load())
|
||||
}
|
||||
}
|
||||
+115
-100
@@ -7,7 +7,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -40,13 +40,15 @@ var (
|
||||
|
||||
type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
bot *telego.Bot
|
||||
bh *th.BotHandler
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
bot *telego.Bot
|
||||
bh *th.BotHandler
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
registerFunc func(context.Context, []commands.Definition) error
|
||||
commandRegCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
@@ -86,14 +88,13 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
telegramCfg,
|
||||
bus,
|
||||
telegramCfg.AllowFrom,
|
||||
channels.WithMaxMessageLength(4096),
|
||||
channels.WithMaxMessageLength(4000),
|
||||
channels.WithGroupTrigger(telegramCfg.GroupTrigger),
|
||||
channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
commands: NewTelegramCommands(bot, cfg),
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
chatIDs: make(map[string]int64),
|
||||
@@ -105,12 +106,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.initBotCommands(c.ctx); err != nil {
|
||||
logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
})
|
||||
@@ -126,21 +121,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
c.bh = bh
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Help(ctx, message)
|
||||
}, th.CommandEqual("help"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
}, th.CommandEqual("show"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.List(ctx, message)
|
||||
}, th.CommandEqual("list"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.handleMessage(ctx, &message)
|
||||
}, th.AnyMessage())
|
||||
@@ -150,6 +130,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions())
|
||||
|
||||
go func() {
|
||||
if err = bh.Start(); err != nil {
|
||||
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
|
||||
@@ -174,50 +156,8 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
|
||||
currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get commands: %w", err)
|
||||
}
|
||||
|
||||
commands := []telego.BotCommand{
|
||||
{
|
||||
Command: "start",
|
||||
Description: "Start the bot",
|
||||
},
|
||||
{
|
||||
Command: "help",
|
||||
Description: "Show a help message",
|
||||
},
|
||||
{
|
||||
Command: "show",
|
||||
Description: "Show current configuration",
|
||||
},
|
||||
{
|
||||
Command: "list",
|
||||
Description: "List available options",
|
||||
},
|
||||
}
|
||||
|
||||
// Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
|
||||
if !slices.Equal(currentCommands, commands) {
|
||||
logger.InfoC("telegram", "Updating bot commands")
|
||||
|
||||
err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
|
||||
Commands: commands,
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set commands: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.DebugC("telegram", "Bot commands are up to date")
|
||||
if c.commandRegCancel != nil {
|
||||
c.commandRegCancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -233,22 +173,57 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
htmlContent := markdownToTelegramHTML(msg.Content)
|
||||
if msg.Content == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Typing/placeholder handled by Manager.preSend — just send the message
|
||||
// The Manager already splits messages to ≤4000 chars (WithMaxMessageLength),
|
||||
// so msg.Content is guaranteed to be within that limit. We still need to
|
||||
// check if HTML expansion pushes it beyond Telegram's 4096-char API limit.
|
||||
queue := []string{msg.Content}
|
||||
for len(queue) > 0 {
|
||||
chunk := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
htmlContent := markdownToTelegramHTML(chunk)
|
||||
|
||||
if len([]rune(htmlContent)) > 4096 {
|
||||
ratio := float64(len([]rune(chunk))) / float64(len([]rune(htmlContent)))
|
||||
smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin
|
||||
if smallerLen < 100 {
|
||||
smallerLen = 100
|
||||
}
|
||||
// Push sub-chunks back to the front of the queue for
|
||||
// re-validation instead of sending them blindly.
|
||||
subChunks := channels.SplitMessage(chunk, smallerLen)
|
||||
queue = append(subChunks, queue...)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.sendHTMLChunk(ctx, chatID, htmlContent, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendHTMLChunk sends a single HTML message, falling back to the original
|
||||
// markdown as plain text on parse failure so users never see raw HTML tags.
|
||||
func (c *TelegramChannel) sendHTMLChunk(ctx context.Context, chatID int64, htmlContent, mdFallback string) error {
|
||||
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
|
||||
tgMsg.ParseMode = telego.ModeHTML
|
||||
|
||||
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
tgMsg.Text = mdFallback
|
||||
tgMsg.ParseMode = ""
|
||||
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
return fmt.Errorf("telegram send: %w", channels.ErrTemporary)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -721,34 +696,34 @@ func escapeHTML(text string) string {
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message via entities.
|
||||
func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
|
||||
botUsername := c.bot.Username()
|
||||
if botUsername == "" {
|
||||
text, entities := telegramEntityTextAndList(message)
|
||||
if text == "" || len(entities) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
entities := message.Entities
|
||||
if entities == nil {
|
||||
entities = message.CaptionEntities
|
||||
botUsername := ""
|
||||
if c.bot != nil {
|
||||
botUsername = c.bot.Username()
|
||||
}
|
||||
runes := []rune(text)
|
||||
|
||||
for _, entity := range entities {
|
||||
if entity.Type == "mention" {
|
||||
// Extract the mention text from the message
|
||||
text := message.Text
|
||||
if text == "" {
|
||||
text = message.Caption
|
||||
}
|
||||
runes := []rune(text)
|
||||
end := entity.Offset + entity.Length
|
||||
if end <= len(runes) {
|
||||
mention := string(runes[entity.Offset:end])
|
||||
if strings.EqualFold(mention, "@"+botUsername) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
entityText, ok := telegramEntityText(runes, entity)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if entity.Type == "text_mention" && entity.User != nil {
|
||||
if entity.User.Username == botUsername {
|
||||
|
||||
switch entity.Type {
|
||||
case telego.EntityTypeMention:
|
||||
if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) {
|
||||
return true
|
||||
}
|
||||
case telego.EntityTypeTextMention:
|
||||
if botUsername != "" && entity.User != nil && strings.EqualFold(entity.User.Username, botUsername) {
|
||||
return true
|
||||
}
|
||||
case telego.EntityTypeBotCommand:
|
||||
if isBotCommandEntityForThisBot(entityText, botUsername) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -756,6 +731,46 @@ func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) {
|
||||
if message.Text != "" {
|
||||
return message.Text, message.Entities
|
||||
}
|
||||
return message.Caption, message.CaptionEntities
|
||||
}
|
||||
|
||||
func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) {
|
||||
if entity.Offset < 0 || entity.Length <= 0 {
|
||||
return "", false
|
||||
}
|
||||
end := entity.Offset + entity.Length
|
||||
if entity.Offset >= len(runes) || end > len(runes) {
|
||||
return "", false
|
||||
}
|
||||
return string(runes[entity.Offset:end]), true
|
||||
}
|
||||
|
||||
func isBotCommandEntityForThisBot(entityText, botUsername string) bool {
|
||||
if !strings.HasPrefix(entityText, "/") {
|
||||
return false
|
||||
}
|
||||
command := strings.TrimPrefix(entityText, "/")
|
||||
if command == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
at := strings.IndexRune(command, '@')
|
||||
if at == -1 {
|
||||
// A bare /command delivered to this bot is intended for this bot.
|
||||
return true
|
||||
}
|
||||
|
||||
mentionUsername := command[at+1:]
|
||||
if mentionUsername == "" || botUsername == "" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(mentionUsername, botUsername)
|
||||
}
|
||||
|
||||
// stripBotMention removes the @bot mention from the content.
|
||||
func (c *TelegramChannel) stripBotMention(content string) string {
|
||||
botUsername := c.bot.Username()
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type TelegramCommander interface {
|
||||
Help(ctx context.Context, message telego.Message) error
|
||||
Start(ctx context.Context, message telego.Message) error
|
||||
Show(ctx context.Context, message telego.Message) error
|
||||
List(ctx context.Context, message telego.Message) error
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
bot *telego.Bot
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
|
||||
return &cmd{
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func commandArgs(text string) string {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
|
||||
msg := `/start - Start the bot
|
||||
/help - Show this help message
|
||||
/show [model|channel] - Show current configuration
|
||||
/list [models|channels] - List available options
|
||||
`
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: msg,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Hello! I am PicoClaw 🦞",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /show [model|channel]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "model":
|
||||
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
|
||||
c.config.Agents.Defaults.GetModelName(),
|
||||
c.config.Agents.Defaults.Provider)
|
||||
case "channel":
|
||||
response = "Current Channel: telegram"
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) List(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /list [models|channels]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "models":
|
||||
provider := c.config.Agents.Defaults.Provider
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
|
||||
c.config.Agents.Defaults.GetModelName(), provider)
|
||||
|
||||
case "channels":
|
||||
var enabled []string
|
||||
if c.config.Channels.Telegram.Enabled {
|
||||
enabled = append(enabled, "telegram")
|
||||
}
|
||||
if c.config.Channels.WhatsApp.Enabled {
|
||||
enabled = append(enabled, "whatsapp")
|
||||
}
|
||||
if c.config.Channels.Feishu.Enabled {
|
||||
enabled = append(enabled, "feishu")
|
||||
}
|
||||
if c.config.Channels.Discord.Enabled {
|
||||
enabled = append(enabled, "discord")
|
||||
}
|
||||
if c.config.Channels.Slack.Enabled {
|
||||
enabled = append(enabled, "slack")
|
||||
}
|
||||
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
|
||||
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
|
||||
chatIDs: make(map[string]int64),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: "/new",
|
||||
MessageID: 9,
|
||||
Chat: telego.Chat{
|
||||
ID: 123,
|
||||
Type: "private",
|
||||
},
|
||||
From: &telego.User{
|
||||
ID: 42,
|
||||
FirstName: "Alice",
|
||||
},
|
||||
}
|
||||
|
||||
if err := ch.handleMessage(context.Background(), msg); err != nil {
|
||||
t.Fatalf("handleMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "telegram" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/new" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
ta "github.com/mymmrac/telego/telegoapi"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type getMeCaller struct {
|
||||
username string
|
||||
}
|
||||
|
||||
func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) {
|
||||
if strings.HasSuffix(url, "/getMe") {
|
||||
result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username)
|
||||
return &ta.Response{Ok: true, Result: []byte(result)}, nil
|
||||
}
|
||||
return &ta.Response{Ok: true, Result: []byte("true")}, nil
|
||||
}
|
||||
|
||||
func newTestTelegramBot(t *testing.T, username string) *telego.Bot {
|
||||
t.Helper()
|
||||
|
||||
token := "123456:" + strings.Repeat("a", 35)
|
||||
bot, err := telego.NewBot(token,
|
||||
telego.WithAPICaller(getMeCaller{username: username}),
|
||||
telego.WithDiscardLogger(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewBot error: %v", err)
|
||||
}
|
||||
return bot
|
||||
}
|
||||
|
||||
func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) {
|
||||
t.Helper()
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil,
|
||||
channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}),
|
||||
),
|
||||
bot: newTestTelegramBot(t, botUsername),
|
||||
chatIDs: make(map[string]int64),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
return ch, messageBus
|
||||
}
|
||||
|
||||
func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
wantForwarded bool
|
||||
wantContent string
|
||||
}{
|
||||
{
|
||||
name: "command with bot username",
|
||||
text: "/new@testbot",
|
||||
wantForwarded: true,
|
||||
wantContent: "/new",
|
||||
},
|
||||
{
|
||||
name: "bare command",
|
||||
text: "/new",
|
||||
wantForwarded: true,
|
||||
wantContent: "/new",
|
||||
},
|
||||
{
|
||||
name: "command for another bot",
|
||||
text: "/new@otherbot",
|
||||
wantForwarded: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ch, messageBus := newGroupMentionOnlyChannel(t, "testbot")
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: tc.text,
|
||||
Entities: []telego.MessageEntity{{
|
||||
Type: telego.EntityTypeBotCommand,
|
||||
Offset: 0,
|
||||
Length: len([]rune(tc.text)),
|
||||
}},
|
||||
MessageID: 42,
|
||||
Chat: telego.Chat{
|
||||
ID: 123,
|
||||
Type: "group",
|
||||
},
|
||||
From: &telego.User{
|
||||
ID: 7,
|
||||
FirstName: "Alice",
|
||||
},
|
||||
}
|
||||
|
||||
if err := ch.handleMessage(context.Background(), msg); err != nil {
|
||||
t.Fatalf("handleMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if tc.wantForwarded {
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Content != tc.wantContent {
|
||||
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ok {
|
||||
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) {
|
||||
ch, _ := newGroupMentionOnlyChannel(t, "testbot")
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: "@testbot hello",
|
||||
Entities: []telego.MessageEntity{{
|
||||
Type: telego.EntityTypeMention,
|
||||
Offset: 0,
|
||||
Length: len("@testbot"),
|
||||
}},
|
||||
}
|
||||
|
||||
if !ch.isBotMentioned(msg) {
|
||||
t.Fatal("expected mention entity to be treated as bot mention")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
ta "github.com/mymmrac/telego/telegoapi"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
const testToken = "1234567890:aaaabbbbaaaabbbbaaaabbbbaaaabbbbccc"
|
||||
|
||||
// stubCaller implements ta.Caller for testing.
|
||||
type stubCaller struct {
|
||||
calls []stubCall
|
||||
callFn func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error)
|
||||
}
|
||||
|
||||
type stubCall struct {
|
||||
URL string
|
||||
Data *ta.RequestData
|
||||
}
|
||||
|
||||
func (s *stubCaller) Call(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
s.calls = append(s.calls, stubCall{URL: url, Data: data})
|
||||
return s.callFn(ctx, url, data)
|
||||
}
|
||||
|
||||
// stubConstructor implements ta.RequestConstructor for testing.
|
||||
type stubConstructor struct{}
|
||||
|
||||
func (s *stubConstructor) JSONRequest(parameters any) (*ta.RequestData, error) {
|
||||
return &ta.RequestData{}, nil
|
||||
}
|
||||
|
||||
func (s *stubConstructor) MultipartRequest(
|
||||
parameters map[string]string,
|
||||
files map[string]ta.NamedReader,
|
||||
) (*ta.RequestData, error) {
|
||||
return &ta.RequestData{}, nil
|
||||
}
|
||||
|
||||
// successResponse returns a ta.Response that telego will treat as a successful SendMessage.
|
||||
func successResponse(t *testing.T) *ta.Response {
|
||||
t.Helper()
|
||||
msg := &telego.Message{MessageID: 1}
|
||||
b, err := json.Marshal(msg)
|
||||
require.NoError(t, err)
|
||||
return &ta.Response{Ok: true, Result: b}
|
||||
}
|
||||
|
||||
// newTestChannel creates a TelegramChannel with a mocked bot for unit testing.
|
||||
func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel {
|
||||
t.Helper()
|
||||
|
||||
bot, err := telego.NewBot(testToken,
|
||||
telego.WithAPICaller(caller),
|
||||
telego.WithRequestConstructor(&stubConstructor{}),
|
||||
telego.WithDiscardLogger(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
base := channels.NewBaseChannel("telegram", nil, nil, nil,
|
||||
channels.WithMaxMessageLength(4000),
|
||||
)
|
||||
base.SetRunning(true)
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
bot: bot,
|
||||
chatIDs: make(map[string]int64),
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_EmptyContent(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
t.Fatal("SendMessage should not be called for empty content")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: "",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, caller.calls, "no API calls should be made for empty content")
|
||||
}
|
||||
|
||||
func TestSend_ShortMessage_SingleCall(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: "Hello, world!",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, caller.calls, 1, "short message should result in exactly one SendMessage call")
|
||||
}
|
||||
|
||||
func TestSend_LongMessage_SingleCall(t *testing.T) {
|
||||
// With WithMaxMessageLength(4000), the Manager pre-splits messages before
|
||||
// they reach Send(). A message at exactly 4000 chars should go through
|
||||
// as a single SendMessage call (no re-split needed since HTML expansion
|
||||
// won't exceed 4096 for plain text).
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
longContent := strings.Repeat("a", 4000)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: longContent,
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, caller.calls, 1, "pre-split message within limit should result in one SendMessage call")
|
||||
}
|
||||
|
||||
func TestSend_HTMLFallback_PerChunk(t *testing.T) {
|
||||
callCount := 0
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
callCount++
|
||||
// Fail on odd calls (HTML attempt), succeed on even calls (plain text fallback)
|
||||
if callCount%2 == 1 {
|
||||
return nil, errors.New("Bad Request: can't parse entities")
|
||||
}
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: "Hello **world**",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
// One short message → 1 HTML attempt (fail) + 1 plain text fallback (success) = 2 calls
|
||||
assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text fallback")
|
||||
}
|
||||
|
||||
func TestSend_HTMLFallback_BothFail(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return nil, errors.New("send failed")
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: "Hello",
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, channels.ErrTemporary), "error should wrap ErrTemporary")
|
||||
assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text attempt")
|
||||
}
|
||||
|
||||
func TestSend_LongMessage_HTMLFallback_StopsOnError(t *testing.T) {
|
||||
// With a long message that gets split into 2 chunks, if both HTML and
|
||||
// plain text fail on the first chunk, Send should return early.
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return nil, errors.New("send failed")
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
longContent := strings.Repeat("x", 4001)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: longContent,
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
// Should fail on the first chunk (2 calls: HTML + fallback), never reaching the second chunk.
|
||||
assert.Equal(t, 2, len(caller.calls), "should stop after first chunk fails both HTML and plain text")
|
||||
}
|
||||
|
||||
func TestSend_MarkdownShortButHTMLLong_MultipleCalls(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
return successResponse(t), nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
// Create markdown whose length is <= 4000 but whose HTML expansion is much longer.
|
||||
// "**a** " (6 chars) becomes "<b>a</b> " (9 chars) in HTML, so repeating it many times
|
||||
// yields HTML that exceeds Telegram's limit while markdown stays within it.
|
||||
markdownContent := strings.Repeat("**a** ", 600) // 3600 chars markdown, HTML ~5400+ chars
|
||||
assert.LessOrEqual(t, len([]rune(markdownContent)), 4000, "markdown content must not exceed chunk size")
|
||||
|
||||
htmlExpanded := markdownToTelegramHTML(markdownContent)
|
||||
assert.Greater(
|
||||
t, len([]rune(htmlExpanded)), 4096,
|
||||
"HTML expansion must exceed Telegram limit for this test to be meaningful",
|
||||
)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: markdownContent,
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(
|
||||
t, len(caller.calls), 1,
|
||||
"markdown-short but HTML-long message should be split into multiple SendMessage calls",
|
||||
)
|
||||
}
|
||||
|
||||
func TestSend_NotRunning(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
t.Fatal("should not be called")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
ch.SetRunning(false)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "12345",
|
||||
Content: "Hello",
|
||||
})
|
||||
|
||||
assert.ErrorIs(t, err, channels.ErrNotRunning)
|
||||
assert.Empty(t, caller.calls)
|
||||
}
|
||||
|
||||
func TestSend_InvalidChatID(t *testing.T) {
|
||||
caller := &stubCaller{
|
||||
callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) {
|
||||
t.Fatal("should not be called")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
ch := newTestChannel(t, caller)
|
||||
|
||||
err := ch.Send(context.Background(), bus.OutboundMessage{
|
||||
ChatID: "not-a-number",
|
||||
Content: "Hello",
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, channels.ErrSendFailed), "error should wrap ErrSendFailed")
|
||||
assert.Empty(t, caller.calls)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package whatsapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &WhatsAppChannel{
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
ch.handleIncomingMessage(map[string]any{
|
||||
"type": "message",
|
||||
"id": "mid1",
|
||||
"from": "user1",
|
||||
"chat": "chat1",
|
||||
"content": "/help",
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "whatsapp" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/help" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
//go:build whatsapp_native
|
||||
|
||||
package whatsapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/whatsmeow/proto/waE2E"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"go.mau.fi/whatsmeow/types/events"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &WhatsAppNativeChannel{
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil),
|
||||
runCtx: context.Background(),
|
||||
}
|
||||
|
||||
evt := &events.Message{
|
||||
Info: types.MessageInfo{
|
||||
MessageSource: types.MessageSource{
|
||||
Sender: types.NewJID("1001", types.DefaultUserServer),
|
||||
Chat: types.NewJID("1001", types.DefaultUserServer),
|
||||
},
|
||||
ID: "mid1",
|
||||
PushName: "Alice",
|
||||
},
|
||||
Message: &waE2E.Message{
|
||||
Conversation: proto.String("/new"),
|
||||
},
|
||||
}
|
||||
|
||||
ch.handleIncoming(evt)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "whatsapp_native" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/new" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package commands
|
||||
|
||||
// BuiltinDefinitions returns all built-in command definitions.
|
||||
// Each command group is defined in its own cmd_*.go file.
|
||||
// Definitions are stateless — runtime dependencies are provided
|
||||
// via the Runtime parameter passed to handlers at execution time.
|
||||
func BuiltinDefinitions() []Definition {
|
||||
return []Definition{
|
||||
startCommand(),
|
||||
helpCommand(),
|
||||
showCommand(),
|
||||
listCommand(),
|
||||
switchCommand(),
|
||||
checkCommand(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func findDefinitionByName(t *testing.T, defs []Definition, name string) Definition {
|
||||
t.Helper()
|
||||
for _, def := range defs {
|
||||
if def.Name == name {
|
||||
return def
|
||||
}
|
||||
}
|
||||
t.Fatalf("missing /%s definition", name)
|
||||
return Definition{}
|
||||
}
|
||||
|
||||
func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
|
||||
defs := BuiltinDefinitions()
|
||||
helpDef := findDefinitionByName(t, defs, "help")
|
||||
if helpDef.Handler == nil {
|
||||
t.Fatalf("/help handler should not be nil")
|
||||
}
|
||||
|
||||
var reply string
|
||||
err := helpDef.Handler(context.Background(), Request{
|
||||
Text: "/help",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("/help handler error: %v", err)
|
||||
}
|
||||
// Now uses auto-generated EffectiveUsage which includes agents
|
||||
if !strings.Contains(reply, "/show [model|channel|agents]") {
|
||||
t.Fatalf("/help reply missing /show usage, got %q", reply)
|
||||
}
|
||||
if !strings.Contains(reply, "/list [models|channels|agents]") {
|
||||
t.Fatalf("/help reply missing /list usage, got %q", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
cases := []string{"telegram", "whatsapp"}
|
||||
for _, channel := range cases {
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Channel: channel,
|
||||
Text: "/show channel",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/show channel on %s: outcome=%v, want=%v", channel, res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "Current Channel: " + channel
|
||||
if reply != want {
|
||||
t.Fatalf("/show channel reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinListChannels_UsesGetEnabledChannels(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
GetEnabledChannels: func() []string {
|
||||
return []string{"telegram", "slack"}
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/list channels",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/list channels: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "telegram") || !strings.Contains(reply, "slack") {
|
||||
t.Fatalf("/list channels reply=%q, want telegram and slack", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinShowAgents_RestoresOldBehavior(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
ListAgentIDs: func() []string {
|
||||
return []string{"default", "coder"}
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/show agents",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/show agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
|
||||
t.Fatalf("/show agents reply=%q, want agent IDs", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinListAgents_RestoresOldBehavior(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
ListAgentIDs: func() []string {
|
||||
return []string{"default", "coder"}
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/list agents",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/list agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
|
||||
t.Fatalf("/list agents reply=%q, want agent IDs", reply)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func checkCommand() Definition {
|
||||
return Definition{
|
||||
Name: "check",
|
||||
Description: "Check channel availability",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "channel",
|
||||
Description: "Check if a channel is available",
|
||||
ArgsUsage: "<name>",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.SwitchChannel == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
value := nthToken(req.Text, 2)
|
||||
if value == "" {
|
||||
return req.Reply("Usage: /check channel <name>")
|
||||
}
|
||||
if err := rt.SwitchChannel(value); err != nil {
|
||||
return req.Reply(err.Error())
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Channel '%s' is available and enabled", value))
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func helpCommand() Definition {
|
||||
return Definition{
|
||||
Name: "help",
|
||||
Description: "Show this help message",
|
||||
Usage: "/help",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
var defs []Definition
|
||||
if rt != nil && rt.ListDefinitions != nil {
|
||||
defs = rt.ListDefinitions()
|
||||
} else {
|
||||
defs = BuiltinDefinitions()
|
||||
}
|
||||
return req.Reply(formatHelpMessage(defs))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func formatHelpMessage(defs []Definition) string {
|
||||
if len(defs) == 0 {
|
||||
return "No commands available."
|
||||
}
|
||||
|
||||
lines := make([]string, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
usage := def.EffectiveUsage()
|
||||
if usage == "" {
|
||||
usage = "/" + def.Name
|
||||
}
|
||||
desc := def.Description
|
||||
if desc == "" {
|
||||
desc = "No description"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%s - %s", usage, desc))
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func listCommand() Definition {
|
||||
return Definition{
|
||||
Name: "list",
|
||||
Description: "List available options",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "models",
|
||||
Description: "Configured models",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.GetModelInfo == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
name, provider := rt.GetModelInfo()
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
return req.Reply(fmt.Sprintf(
|
||||
"Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
|
||||
name, provider,
|
||||
))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "channels",
|
||||
Description: "Enabled channels",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.GetEnabledChannels == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
enabled := rt.GetEnabledChannels()
|
||||
if len(enabled) == 0 {
|
||||
return req.Reply("No channels enabled")
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agents",
|
||||
Description: "Registered agents",
|
||||
Handler: agentsHandler(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func showCommand() Definition {
|
||||
return Definition{
|
||||
Name: "show",
|
||||
Description: "Show current configuration",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "model",
|
||||
Description: "Current model and provider",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.GetModelInfo == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
name, provider := rt.GetModelInfo()
|
||||
return req.Reply(fmt.Sprintf("Current Model: %s (Provider: %s)", name, provider))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "channel",
|
||||
Description: "Current channel",
|
||||
Handler: func(_ context.Context, req Request, _ *Runtime) error {
|
||||
return req.Reply(fmt.Sprintf("Current Channel: %s", req.Channel))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agents",
|
||||
Description: "Registered agents",
|
||||
Handler: agentsHandler(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package commands
|
||||
|
||||
import "context"
|
||||
|
||||
func startCommand() Definition {
|
||||
return Definition{
|
||||
Name: "start",
|
||||
Description: "Start the bot",
|
||||
Usage: "/start",
|
||||
Handler: func(_ context.Context, req Request, _ *Runtime) error {
|
||||
return req.Reply("Hello! I am PicoClaw 🦞")
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func switchCommand() Definition {
|
||||
return Definition{
|
||||
Name: "switch",
|
||||
Description: "Switch model",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "model",
|
||||
Description: "Switch to a different model",
|
||||
ArgsUsage: "to <name>",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.SwitchModel == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
// Parse: /switch model to <value>
|
||||
value := nthToken(req.Text, 3) // tokens: [/switch, model, to, <value>]
|
||||
if nthToken(req.Text, 2) != "to" || value == "" {
|
||||
return req.Reply("Usage: /switch model to <name>")
|
||||
}
|
||||
oldModel, err := rt.SwitchModel(value)
|
||||
if err != nil {
|
||||
return req.Reply(err.Error())
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Switched model from %s to %s", oldModel, value))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "channel",
|
||||
Description: "Moved to /check channel",
|
||||
Handler: func(_ context.Context, req Request, _ *Runtime) error {
|
||||
return req.Reply("This command has moved. Please use: /check channel <name>")
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSwitchModel_Success(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old-model", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "Switched model from old-model to gpt-4"
|
||||
if reply != want {
|
||||
t.Fatalf("reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_MissingToKeyword(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /switch model to <name>" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_MissingValue(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /switch model to <name>" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_Error(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "", fmt.Errorf("model not found")
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to bad-model",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "model not found" {
|
||||
t.Fatalf("reply=%q, want error message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_NilDep(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Command unavailable in current context." {
|
||||
t.Fatalf("reply=%q, want unavailable message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchChannel_Redirect(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch channel to telegram",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "This command has moved. Please use: /check channel <name>"
|
||||
if reply != want {
|
||||
t.Fatalf("reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_Success(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchChannel: func(value string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel telegram",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "Channel 'telegram' is available and enabled"
|
||||
if reply != want {
|
||||
t.Fatalf("reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_Error(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchChannel: func(value string) error {
|
||||
return fmt.Errorf("channel '%s' not found", value)
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel unknown",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "channel 'unknown' not found" {
|
||||
t.Fatalf("reply=%q, want error message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_NilDep(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel telegram",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Command unavailable in current context." {
|
||||
t.Fatalf("reply=%q, want unavailable message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_MissingValue(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchChannel: func(value string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /check channel <name>" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitch_BangPrefix(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "!switch model to gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("! prefix: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Switched model from old to gpt-4" {
|
||||
t.Fatalf("! prefix: reply=%q, want success message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitch_NoSubCommand(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
// Should get usage message from executor's sub-command routing
|
||||
if reply == "" {
|
||||
t.Fatal("expected usage reply for bare /switch")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SubCommand defines a single sub-command within a parent command.
|
||||
type SubCommand struct {
|
||||
Name string
|
||||
Description string
|
||||
ArgsUsage string // optional, e.g. "<session-id>"
|
||||
Handler Handler
|
||||
}
|
||||
|
||||
// Definition is the single-source metadata and behavior contract for a slash command.
|
||||
//
|
||||
// Design notes (phase 1):
|
||||
// - Every channel reads command shape from this type instead of keeping local copies.
|
||||
// - Visibility is global: all definitions are considered available to all channels.
|
||||
// - Platform menu registration (for example Telegram BotCommand) also derives from this
|
||||
// same definition so UI labels and runtime behavior stay aligned.
|
||||
type Definition struct {
|
||||
Name string
|
||||
Description string
|
||||
Usage string // for simple commands; ignored when SubCommands is set
|
||||
Aliases []string
|
||||
SubCommands []SubCommand // optional; when set, Executor routes to sub-command handlers
|
||||
Handler Handler // for simple commands without sub-commands
|
||||
}
|
||||
|
||||
// EffectiveUsage returns the usage string. When SubCommands are present,
|
||||
// it is auto-generated from sub-command names so metadata and behavior
|
||||
// cannot drift.
|
||||
func (d Definition) EffectiveUsage() string {
|
||||
if len(d.SubCommands) == 0 {
|
||||
return d.Usage
|
||||
}
|
||||
names := make([]string, 0, len(d.SubCommands))
|
||||
for _, sc := range d.SubCommands {
|
||||
name := sc.Name
|
||||
if sc.ArgsUsage != "" {
|
||||
name += " " + sc.ArgsUsage
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return fmt.Sprintf("/%s [%s]", d.Name, strings.Join(names, "|"))
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefinition_EffectiveUsage_NoSubCommands(t *testing.T) {
|
||||
d := Definition{Name: "start", Usage: "/start"}
|
||||
if got := d.EffectiveUsage(); got != "/start" {
|
||||
t.Fatalf("EffectiveUsage()=%q, want %q", got, "/start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefinition_EffectiveUsage_WithSubCommands(t *testing.T) {
|
||||
d := Definition{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"},
|
||||
{Name: "channel"},
|
||||
{Name: "agents"},
|
||||
},
|
||||
}
|
||||
want := "/show [model|channel|agents]"
|
||||
if got := d.EffectiveUsage(); got != want {
|
||||
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefinition_EffectiveUsage_WithArgsUsage(t *testing.T) {
|
||||
d := Definition{
|
||||
Name: "session",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "list"},
|
||||
{Name: "resume", ArgsUsage: "<id>"},
|
||||
},
|
||||
}
|
||||
want := "/session [list|resume <id>]"
|
||||
if got := d.EffectiveUsage(); got != want {
|
||||
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Outcome int
|
||||
|
||||
const (
|
||||
// OutcomePassthrough means this input should continue through normal agent flow.
|
||||
OutcomePassthrough Outcome = iota
|
||||
// OutcomeHandled means a command handler executed (with or without handler error).
|
||||
OutcomeHandled
|
||||
)
|
||||
|
||||
type ExecuteResult struct {
|
||||
Outcome Outcome
|
||||
Command string
|
||||
Err error
|
||||
}
|
||||
|
||||
type Executor struct {
|
||||
reg *Registry
|
||||
rt *Runtime
|
||||
}
|
||||
|
||||
func NewExecutor(reg *Registry, rt *Runtime) *Executor {
|
||||
return &Executor{reg: reg, rt: rt}
|
||||
}
|
||||
|
||||
// Execute implements a two-state command decision:
|
||||
// 1) handled: execute command immediately;
|
||||
// 2) passthrough: not a command or intentionally deferred to agent logic.
|
||||
func (e *Executor) Execute(ctx context.Context, req Request) ExecuteResult {
|
||||
cmdName, ok := parseCommandName(req.Text)
|
||||
if !ok {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough}
|
||||
}
|
||||
|
||||
if e == nil || e.reg == nil {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
|
||||
}
|
||||
|
||||
def, found := e.reg.Lookup(cmdName)
|
||||
if !found {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
|
||||
}
|
||||
|
||||
return e.executeDefinition(ctx, req, def)
|
||||
}
|
||||
|
||||
func (e *Executor) executeDefinition(ctx context.Context, req Request, def Definition) ExecuteResult {
|
||||
// Ensure Reply is always non-nil so handlers don't need to check.
|
||||
if req.Reply == nil {
|
||||
req.Reply = func(string) error { return nil }
|
||||
}
|
||||
|
||||
// Simple command — no sub-commands
|
||||
if len(def.SubCommands) == 0 {
|
||||
if def.Handler == nil {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
|
||||
}
|
||||
err := def.Handler(ctx, req, e.rt)
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
|
||||
// Sub-command routing
|
||||
subName := nthToken(req.Text, 1)
|
||||
if subName == "" {
|
||||
err := req.Reply("Usage: " + def.EffectiveUsage())
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
|
||||
normalized := normalizeCommandName(subName)
|
||||
for _, sc := range def.SubCommands {
|
||||
if normalizeCommandName(sc.Name) == normalized {
|
||||
if sc.Handler == nil {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
|
||||
}
|
||||
err := sc.Handler(ctx, req, e.rt)
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// Unknown sub-command
|
||||
err := req.Reply(fmt.Sprintf("Unknown option: %s. Usage: %s", subName, def.EffectiveUsage()))
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExecutor_RegisteredWithoutHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{{Name: "show"}}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/show"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_UnknownSlashCommand_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{{Name: "show"}}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/unknown"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SupportedCommandWithHandler_ReturnsHandled(t *testing.T) {
|
||||
called := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "help",
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help@my_bot"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !called {
|
||||
t.Fatalf("expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_AliasWithoutHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
Aliases: []string{"display"},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/display"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if res.Command != "show" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "show")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_AliasWithHandler_ReturnsHandled(t *testing.T) {
|
||||
called := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "clear",
|
||||
Aliases: []string{"reset"},
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/reset"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if res.Command != "clear" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "clear")
|
||||
}
|
||||
if !called {
|
||||
t.Fatalf("expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SupportedCommandWithNilHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{Name: "placeholder"},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder list"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if res.Command != "placeholder" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_NilHandlerDoesNotMaskLaterHandler(t *testing.T) {
|
||||
// With Lookup-based dispatch, the first registered definition for a name wins.
|
||||
// A definition with nil Handler and no SubCommands returns Passthrough.
|
||||
defs := []Definition{
|
||||
{Name: "placeholder"},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if res.Command != "placeholder" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_HandlerErrorIsPropagated(t *testing.T) {
|
||||
wantErr := errors.New("handler failed")
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "help",
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
return wantErr
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !errors.Is(res.Err, wantErr) {
|
||||
t.Fatalf("err=%v, want=%v", res.Err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SupportsBangPrefixAndCaseInsensitiveCommand(t *testing.T) {
|
||||
called := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "help",
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "!HELP"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !called {
|
||||
t.Fatalf("expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_RoutesToCorrectHandler(t *testing.T) {
|
||||
modelCalled := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model", Handler: func(_ context.Context, _ Request, _ *Runtime) error {
|
||||
modelCalled = true
|
||||
return nil
|
||||
}},
|
||||
{Name: "channel"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Text: "/show model"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !modelCalled {
|
||||
t.Fatal("model sub-command handler was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_NoArg_RepliesUsage(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"},
|
||||
{Name: "channel"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/show",
|
||||
Reply: func(text string) error { reply = text; return nil },
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /show [model|channel]" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_UnknownArg_RepliesError(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/show foobar",
|
||||
Reply: func(text string) error { reply = text; return nil },
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "foobar") {
|
||||
t.Fatalf("reply=%q, should mention unknown sub-command", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_NilHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"}, // nil Handler
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Text: "/show model"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// agentsHandler returns a shared handler for both /show agents and /list agents.
|
||||
func agentsHandler() Handler {
|
||||
return func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.ListAgentIDs == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
ids := rt.ListAgentIDs()
|
||||
if len(ids) == 0 {
|
||||
return req.Reply("No agents registered")
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", ")))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package commands
|
||||
|
||||
type Registry struct {
|
||||
defs []Definition
|
||||
index map[string]int
|
||||
}
|
||||
|
||||
// NewRegistry stores the canonical command set used by both dispatch and
|
||||
// optional platform registration adapters.
|
||||
func NewRegistry(defs []Definition) *Registry {
|
||||
stored := make([]Definition, len(defs))
|
||||
copy(stored, defs)
|
||||
|
||||
index := make(map[string]int, len(stored)*2)
|
||||
for i, def := range stored {
|
||||
registerCommandName(index, def.Name, i)
|
||||
for _, alias := range def.Aliases {
|
||||
registerCommandName(index, alias, i)
|
||||
}
|
||||
}
|
||||
|
||||
return &Registry{defs: stored, index: index}
|
||||
}
|
||||
|
||||
// Definitions returns all registered command definitions.
|
||||
// Command availability is global and no longer channel-scoped.
|
||||
func (r *Registry) Definitions() []Definition {
|
||||
out := make([]Definition, len(r.defs))
|
||||
copy(out, r.defs)
|
||||
return out
|
||||
}
|
||||
|
||||
// Lookup returns a command definition by normalized command name or alias.
|
||||
func (r *Registry) Lookup(name string) (Definition, bool) {
|
||||
key := normalizeCommandName(name)
|
||||
if key == "" {
|
||||
return Definition{}, false
|
||||
}
|
||||
idx, ok := r.index[key]
|
||||
if !ok {
|
||||
return Definition{}, false
|
||||
}
|
||||
return r.defs[idx], true
|
||||
}
|
||||
|
||||
func registerCommandName(index map[string]int, name string, defIndex int) {
|
||||
key := normalizeCommandName(name)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := index[key]; exists {
|
||||
return
|
||||
}
|
||||
index[key] = defIndex
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package commands
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRegistry_Definitions_ReturnsCopy(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{Name: "help", Description: "Show help"},
|
||||
{Name: "admin", Description: "Admin command"},
|
||||
}
|
||||
r := NewRegistry(defs)
|
||||
|
||||
got := r.Definitions()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("definitions len = %d, want 2", len(got))
|
||||
}
|
||||
|
||||
got[0].Name = "mutated"
|
||||
again := r.Definitions()
|
||||
if again[0].Name != "help" {
|
||||
t.Fatalf("registry should not be mutated by caller, got first name %q", again[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Lookup_MatchesByLowercaseNameAndAlias(t *testing.T) {
|
||||
r := NewRegistry([]Definition{
|
||||
{Name: "Help", Aliases: []string{"Assist"}},
|
||||
{Name: "List"},
|
||||
})
|
||||
|
||||
def, ok := r.Lookup("help")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by lowercase name failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
|
||||
def, ok = r.Lookup("HELP")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by uppercase name failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
|
||||
def, ok = r.Lookup("assist")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by lowercase alias failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
|
||||
def, ok = r.Lookup("ASSIST")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by uppercase alias failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Handler func(ctx context.Context, req Request, rt *Runtime) error
|
||||
|
||||
type Request struct {
|
||||
Channel string
|
||||
ChatID string
|
||||
SenderID string
|
||||
Text string
|
||||
Reply func(text string) error
|
||||
}
|
||||
|
||||
const unavailableMsg = "Command unavailable in current context."
|
||||
|
||||
var commandPrefixes = []string{"/", "!"}
|
||||
|
||||
// parseCommandName accepts "/name", "!name", and Telegram's "/name@bot", then
|
||||
// normalizes to lowercase command names.
|
||||
func parseCommandName(input string) (string, bool) {
|
||||
token := nthToken(input, 0)
|
||||
if token == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
name, ok := trimCommandPrefix(token)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if i := strings.Index(name, "@"); i >= 0 {
|
||||
name = name[:i]
|
||||
}
|
||||
name = normalizeCommandName(name)
|
||||
if name == "" {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
func trimCommandPrefix(token string) (string, bool) {
|
||||
for _, prefix := range commandPrefixes {
|
||||
if strings.HasPrefix(token, prefix) {
|
||||
return strings.TrimPrefix(token, prefix), true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// HasCommandPrefix returns true if the input starts with a recognized
|
||||
// command prefix (e.g. "/" or "!").
|
||||
func HasCommandPrefix(input string) bool {
|
||||
token := nthToken(input, 0)
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := trimCommandPrefix(token)
|
||||
return ok
|
||||
}
|
||||
|
||||
// nthToken returns the 0-indexed token from whitespace-split input.
|
||||
func nthToken(input string, n int) string {
|
||||
parts := strings.Fields(strings.TrimSpace(input))
|
||||
if n >= len(parts) {
|
||||
return ""
|
||||
}
|
||||
return parts[n]
|
||||
}
|
||||
|
||||
func normalizeCommandName(name string) string {
|
||||
return strings.ToLower(strings.TrimSpace(name))
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package commands
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHasCommandPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"/help", true},
|
||||
{"!help", true},
|
||||
{"/switch model to gpt-4", true},
|
||||
{"!switch model to gpt-4", true},
|
||||
{"hello", false},
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{"hello /world", false},
|
||||
{"/", true},
|
||||
{"!", true},
|
||||
{" /help", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := HasCommandPrefix(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("HasCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package commands
|
||||
|
||||
import "github.com/sipeed/picoclaw/pkg/config"
|
||||
|
||||
// Runtime provides runtime dependencies to command handlers. It is constructed
|
||||
// per-request by the agent loop so that per-request state (like session scope)
|
||||
// can coexist with long-lived callbacks (like GetModelInfo).
|
||||
type Runtime struct {
|
||||
Config *config.Config
|
||||
GetModelInfo func() (name, provider string)
|
||||
ListAgentIDs func() []string
|
||||
ListDefinitions func() []Definition
|
||||
GetEnabledChannels func() []string
|
||||
SwitchModel func(value string) (oldModel string, err error)
|
||||
SwitchChannel func(value string) error
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShowListHandlers_ChannelPolicy(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), nil)
|
||||
|
||||
var telegramReply string
|
||||
handled := ex.Execute(context.Background(), Request{
|
||||
Channel: "telegram",
|
||||
Text: "/show channel",
|
||||
Reply: func(text string) error {
|
||||
telegramReply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if handled.Outcome != OutcomeHandled {
|
||||
t.Fatalf("telegram /show outcome=%v, want=%v", handled.Outcome, OutcomeHandled)
|
||||
}
|
||||
if telegramReply != "Current Channel: telegram" {
|
||||
t.Fatalf("telegram /show reply=%q, want=%q", telegramReply, "Current Channel: telegram")
|
||||
}
|
||||
|
||||
var whatsappReply string
|
||||
handledWhatsApp := ex.Execute(context.Background(), Request{
|
||||
Channel: "whatsapp",
|
||||
Text: "/show channel",
|
||||
Reply: func(text string) error {
|
||||
whatsappReply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if handledWhatsApp.Outcome != OutcomeHandled {
|
||||
t.Fatalf("whatsapp /show outcome=%v, want=%v", handledWhatsApp.Outcome, OutcomeHandled)
|
||||
}
|
||||
if handledWhatsApp.Command != "show" {
|
||||
t.Fatalf("whatsapp /show command=%q, want=%q", handledWhatsApp.Command, "show")
|
||||
}
|
||||
if whatsappReply != "Current Channel: whatsapp" {
|
||||
t.Fatalf("whatsapp /show reply=%q, want=%q", whatsappReply, "Current Channel: whatsapp")
|
||||
}
|
||||
|
||||
passthrough := ex.Execute(context.Background(), Request{
|
||||
Channel: "whatsapp",
|
||||
Text: "/foo",
|
||||
})
|
||||
if passthrough.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("whatsapp /foo outcome=%v, want=%v", passthrough.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if passthrough.Command != "foo" {
|
||||
t.Fatalf("whatsapp /foo command=%q, want=%q", passthrough.Command, "foo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowListHandlers_ListHandledOnAllChannels(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
GetEnabledChannels: func() []string {
|
||||
return []string{"telegram"}
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Channel: "whatsapp",
|
||||
Text: "/list channels",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("whatsapp /list outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if res.Command != "list" {
|
||||
t.Fatalf("whatsapp /list command=%q, want=%q", res.Command, "list")
|
||||
}
|
||||
if !strings.Contains(reply, "telegram") {
|
||||
t.Fatalf("whatsapp /list reply=%q, expected enabled channels content", reply)
|
||||
}
|
||||
}
|
||||
+54
-15
@@ -167,22 +167,35 @@ type SessionConfig struct {
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
// RoutingConfig controls the intelligent model routing feature.
|
||||
// When enabled, each incoming message is scored against structural features
|
||||
// (message length, code blocks, tool call history, conversation depth, attachments).
|
||||
// Messages scoring below Threshold are sent to LightModel; all others use the
|
||||
// agent's primary model. This reduces cost and latency for simple tasks without
|
||||
// requiring any keyword matching — all scoring is language-agnostic.
|
||||
type RoutingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
|
||||
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
@@ -218,6 +231,7 @@ type ChannelsConfig struct {
|
||||
WeComApp WeComAppConfig `json:"wecom_app"`
|
||||
WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
|
||||
Pico PicoConfig `json:"pico"`
|
||||
IRC IRCConfig `json:"irc"`
|
||||
}
|
||||
|
||||
// GroupTriggerConfig controls when the bot responds in group chats.
|
||||
@@ -402,6 +416,25 @@ type PicoConfig struct {
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
}
|
||||
|
||||
type IRCConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_IRC_ENABLED"`
|
||||
Server string `json:"server" env:"PICOCLAW_CHANNELS_IRC_SERVER"`
|
||||
TLS bool `json:"tls" env:"PICOCLAW_CHANNELS_IRC_TLS"`
|
||||
Nick string `json:"nick" env:"PICOCLAW_CHANNELS_IRC_NICK"`
|
||||
User string `json:"user,omitempty" env:"PICOCLAW_CHANNELS_IRC_USER"`
|
||||
RealName string `json:"real_name,omitempty" env:"PICOCLAW_CHANNELS_IRC_REAL_NAME"`
|
||||
Password string `json:"password" env:"PICOCLAW_CHANNELS_IRC_PASSWORD"`
|
||||
NickServPassword string `json:"nickserv_password" env:"PICOCLAW_CHANNELS_IRC_NICKSERV_PASSWORD"`
|
||||
SASLUser string `json:"sasl_user" env:"PICOCLAW_CHANNELS_IRC_SASL_USER"`
|
||||
SASLPassword string `json:"sasl_password" env:"PICOCLAW_CHANNELS_IRC_SASL_PASSWORD"`
|
||||
Channels FlexibleStringSlice `json:"channels" env:"PICOCLAW_CHANNELS_IRC_CHANNELS"`
|
||||
RequestCaps FlexibleStringSlice `json:"request_caps,omitempty" env:"PICOCLAW_CHANNELS_IRC_REQUEST_CAPS"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_IRC_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Typing TypingConfig `json:"typing,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_IRC_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
|
||||
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
|
||||
@@ -427,6 +460,7 @@ type ProvidersConfig struct {
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
Cerebras ProviderConfig `json:"cerebras"`
|
||||
Vivgrid ProviderConfig `json:"vivgrid"`
|
||||
VolcEngine ProviderConfig `json:"volcengine"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
Antigravity ProviderConfig `json:"antigravity"`
|
||||
@@ -452,6 +486,7 @@ func (p ProvidersConfig) IsEmpty() bool {
|
||||
p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" &&
|
||||
p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" &&
|
||||
p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" &&
|
||||
p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" &&
|
||||
p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" &&
|
||||
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
|
||||
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
|
||||
@@ -595,6 +630,7 @@ type ExecConfig struct {
|
||||
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
|
||||
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
|
||||
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
|
||||
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
@@ -627,6 +663,7 @@ type ToolsConfig struct {
|
||||
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
|
||||
@@ -900,6 +937,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.Subagent.Enabled
|
||||
case "web_fetch":
|
||||
return t.WebFetch.Enabled
|
||||
case "send_file":
|
||||
return t.SendFile.Enabled
|
||||
case "write_file":
|
||||
return t.WriteFile.Enabled
|
||||
case "mcp":
|
||||
|
||||
@@ -261,6 +261,14 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Vivgrid - https://vivgrid.com
|
||||
{
|
||||
ModelName: "vivgrid-auto",
|
||||
Model: "vivgrid/auto",
|
||||
APIBase: "https://api.vivgrid.com/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Volcengine (火山引擎) - https://console.volcengine.com/ark
|
||||
{
|
||||
ModelName: "doubao-pro",
|
||||
@@ -386,6 +394,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
},
|
||||
EnableDenyPatterns: true,
|
||||
TimeoutSeconds: 60,
|
||||
},
|
||||
Skills: SkillsToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
@@ -403,6 +412,9 @@ func DefaultConfig() *Config {
|
||||
TTLSeconds: 300,
|
||||
},
|
||||
},
|
||||
SendFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: false,
|
||||
|
||||
@@ -292,6 +292,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"vivgrid"},
|
||||
protocol: "vivgrid",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "vivgrid",
|
||||
Model: "vivgrid/auto",
|
||||
APIKey: p.Vivgrid.APIKey,
|
||||
APIBase: p.Vivgrid.APIBase,
|
||||
Proxy: p.Vivgrid.Proxy,
|
||||
RequestTimeout: p.Vivgrid.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"volcengine", "doubao"},
|
||||
protocol: "volcengine",
|
||||
|
||||
@@ -155,7 +155,8 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
ShengSuanYun: ProviderConfig{APIKey: "key11"},
|
||||
DeepSeek: ProviderConfig{APIKey: "key12"},
|
||||
Cerebras: ProviderConfig{APIKey: "key13"},
|
||||
VolcEngine: ProviderConfig{APIKey: "key14"},
|
||||
Vivgrid: ProviderConfig{APIKey: "key14"},
|
||||
VolcEngine: ProviderConfig{APIKey: "key15"},
|
||||
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
|
||||
Antigravity: ProviderConfig{AuthMethod: "oauth"},
|
||||
Qwen: ProviderConfig{APIKey: "key17"},
|
||||
@@ -166,9 +167,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
// All 20 providers should be converted
|
||||
if len(result) != 20 {
|
||||
t.Errorf("len(result) = %d, want 20", len(result))
|
||||
// All 21 providers should be converted
|
||||
if len(result) != 21 {
|
||||
t.Errorf("len(result) = %d, want 21", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -190,14 +190,21 @@ func (cs *CronService) executeJobByID(jobID string) {
|
||||
cs.mu.RUnlock()
|
||||
|
||||
if callbackJob == nil {
|
||||
log.Printf("[cron] job %s not found, skipping", jobID)
|
||||
return
|
||||
}
|
||||
|
||||
// Log job execution start
|
||||
log.Printf("[cron] ▶ executing job '%s' (id: %s, schedule: %s, channel: %s)",
|
||||
callbackJob.Name, jobID, callbackJob.Schedule.Kind, callbackJob.Payload.Channel)
|
||||
|
||||
var err error
|
||||
if cs.onJob != nil {
|
||||
_, err = cs.onJob(callbackJob)
|
||||
}
|
||||
|
||||
execDuration := time.Now().UnixMilli() - startTime
|
||||
|
||||
// Now acquire lock to update state
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
@@ -220,22 +227,35 @@ func (cs *CronService) executeJobByID(jobID string) {
|
||||
if err != nil {
|
||||
job.State.LastStatus = "error"
|
||||
job.State.LastError = err.Error()
|
||||
log.Printf("[cron] ✗ job '%s' failed after %dms: %v", job.Name, execDuration, err)
|
||||
} else {
|
||||
job.State.LastStatus = "ok"
|
||||
job.State.LastError = ""
|
||||
}
|
||||
|
||||
// Compute next run time
|
||||
var nextRunStr string
|
||||
if job.Schedule.Kind == "at" {
|
||||
if job.DeleteAfterRun {
|
||||
cs.removeJobUnsafe(job.ID)
|
||||
nextRunStr = "(deleted)"
|
||||
} else {
|
||||
job.Enabled = false
|
||||
job.State.NextRunAtMS = nil
|
||||
nextRunStr = "(disabled)"
|
||||
}
|
||||
} else {
|
||||
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
|
||||
job.State.NextRunAtMS = nextRun
|
||||
if nextRun != nil {
|
||||
nextRunStr = time.UnixMilli(*nextRun).Format("2006-01-02 15:04:05")
|
||||
} else {
|
||||
nextRunStr = "(none)"
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
log.Printf("[cron] ✓ job '%s' completed in %dms, next run: %s", job.Name, execDuration, nextRunStr)
|
||||
}
|
||||
|
||||
if err := cs.saveStoreUnsafe(); err != nil {
|
||||
|
||||
@@ -23,7 +23,10 @@ type (
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
)
|
||||
|
||||
const defaultBaseURL = "https://api.anthropic.com"
|
||||
const (
|
||||
defaultBaseURL = "https://api.anthropic.com"
|
||||
anthropicBetaHeader = "oauth-2025-04-20"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
client *anthropic.Client
|
||||
@@ -80,7 +83,10 @@ func (p *Provider) Chat(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
opts = append(opts, option.WithAuthToken(tok))
|
||||
opts = append(opts,
|
||||
option.WithAuthToken(tok),
|
||||
option.WithHeader("anthropic-beta", anthropicBetaHeader),
|
||||
)
|
||||
}
|
||||
|
||||
params, err := buildParams(messages, tools, model, options)
|
||||
@@ -88,6 +94,11 @@ func (p *Provider) Chat(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// OAuth/setup-tokens require streaming; API keys use non-streaming.
|
||||
if p.tokenSource != nil {
|
||||
return p.chatStreaming(ctx, params, opts)
|
||||
}
|
||||
|
||||
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
@@ -96,6 +107,28 @@ func (p *Provider) Chat(
|
||||
return parseResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *Provider) chatStreaming(
|
||||
ctx context.Context,
|
||||
params anthropic.MessageNewParams,
|
||||
opts []option.RequestOption,
|
||||
) (*LLMResponse, error) {
|
||||
stream := p.client.Messages.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
|
||||
var msg anthropic.Message
|
||||
for stream.Next() {
|
||||
event := stream.Current()
|
||||
if err := msg.Accumulate(event); err != nil {
|
||||
return nil, fmt.Errorf("claude streaming accumulate: %w", err)
|
||||
}
|
||||
}
|
||||
if err := stream.Err(); err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
return parseResponse(&msg), nil
|
||||
}
|
||||
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return "claude-sonnet-4.6"
|
||||
}
|
||||
@@ -147,7 +180,16 @@ func buildParams(
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||
args := tc.Arguments
|
||||
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, args, tc.Name))
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||
} else {
|
||||
@@ -167,8 +209,12 @@ func buildParams(
|
||||
maxTokens = int64(mt)
|
||||
}
|
||||
|
||||
// Normalize model ID: Anthropic API uses hyphens (claude-sonnet-4-6),
|
||||
// but config may use dots (claude-sonnet-4.6).
|
||||
apiModel := strings.ReplaceAll(model, ".", "-")
|
||||
|
||||
params := anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(model),
|
||||
Model: anthropic.Model(apiModel),
|
||||
Messages: anthropicMessages,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
@@ -21,8 +21,8 @@ func TestBuildParams_BasicMessage(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if string(params.Model) != "claude-sonnet-4.6" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4.6")
|
||||
if string(params.Model) != "claude-sonnet-4-6" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-6")
|
||||
}
|
||||
if params.MaxTokens != 1024 {
|
||||
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||
@@ -262,6 +262,65 @@ func TestProvider_ChatUsesTokenSource(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ChatStreamingRoundTrip(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
|
||||
t.Errorf("Authorization = %q, want %q", got, "Bearer refreshed-token")
|
||||
}
|
||||
if got := r.Header.Get("Anthropic-Beta"); got != anthropicBetaHeader {
|
||||
t.Errorf("Anthropic-Beta = %q, want %q", got, anthropicBetaHeader)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
events := []string{
|
||||
"event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":12,\"output_tokens\":0}}}\n\n",
|
||||
"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
|
||||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
|
||||
"event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\n",
|
||||
"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n",
|
||||
}
|
||||
for _, e := range events {
|
||||
w.Write([]byte(e))
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
|
||||
return "refreshed-token", nil
|
||||
}, server.URL)
|
||||
|
||||
resp, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "Hello"}},
|
||||
nil,
|
||||
"claude-sonnet-4.6",
|
||||
map[string]any{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello world" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello world")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 5 {
|
||||
t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||
c := anthropic.NewClient(
|
||||
anthropicoption.WithAuthToken(token),
|
||||
|
||||
@@ -153,6 +153,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
}
|
||||
case "vivgrid":
|
||||
if cfg.Providers.Vivgrid.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Vivgrid.APIKey
|
||||
sel.apiBase = cfg.Providers.Vivgrid.APIBase
|
||||
sel.proxy = cfg.Providers.Vivgrid.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.vivgrid.com/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claude-code", "claudecode":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
@@ -295,6 +304,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Vivgrid.APIKey
|
||||
sel.apiBase = cfg.Providers.Vivgrid.APIBase
|
||||
sel.proxy = cfg.Providers.Vivgrid.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.vivgrid.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Ollama.APIKey
|
||||
sel.apiBase = cfg.Providers.Ollama.APIBase
|
||||
|
||||
@@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"volcengine", "vllm", "qwen", "mistral", "avian":
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -200,6 +200,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://api.deepseek.com/v1"
|
||||
case "cerebras":
|
||||
return "https://api.cerebras.ai/v1"
|
||||
case "vivgrid":
|
||||
return "https://api.vivgrid.com/v1"
|
||||
case "volcengine":
|
||||
return "https://ark.cn-beijing.volces.com/api/v3"
|
||||
case "qwen":
|
||||
|
||||
@@ -108,6 +108,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
{"groq", "groq"},
|
||||
{"openrouter", "openrouter"},
|
||||
{"cerebras", "cerebras"},
|
||||
{"vivgrid", "vivgrid"},
|
||||
{"qwen", "qwen"},
|
||||
{"vllm", "vllm"},
|
||||
{"deepseek", "deepseek"},
|
||||
|
||||
@@ -88,6 +88,17 @@ func TestResolveProviderSelection(t *testing.T) {
|
||||
wantAPIBase: "https://integrate.api.nvidia.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit vivgrid provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "vivgrid"
|
||||
cfg.Providers.Vivgrid.APIKey = "vivgrid-key"
|
||||
cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.vivgrid.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "openrouter model uses openrouter defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
@@ -183,19 +184,94 @@ func (p *Provider) Chat(
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if looksLikeHTML(body, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
responsePreview(body, 128),
|
||||
)
|
||||
}
|
||||
|
||||
return parseResponse(body)
|
||||
// Peek without consuming so the full stream reaches the JSON decoder.
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if looksLikeHTML(prefix, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
|
||||
}
|
||||
|
||||
out, err := parseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func parseResponse(body []byte) (*LLMResponse, error) {
|
||||
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := responsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
func looksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func parseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
@@ -222,8 +298,8 @@ func parseResponse(body []byte) (*LLMResponse, error) {
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
@@ -363,7 +439,8 @@ func normalizeModel(model, apiBase string) string {
|
||||
|
||||
prefix := strings.ToLower(before)
|
||||
switch prefix {
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google",
|
||||
"openrouter", "zhipu", "mistral", "vivgrid":
|
||||
return after
|
||||
default:
|
||||
return model
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -212,6 +215,132 @@ func TestProviderChat_HTTPError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_JSONHTTPErrorDoesNotReportHTML(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"bad request"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Status: 400") {
|
||||
t.Fatalf("expected status code in error, got %v", err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "returned HTML instead of JSON") {
|
||||
t.Fatalf("expected non-HTML http error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_HTMLResponsesReturnHelpfulError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
statusCode int
|
||||
body string
|
||||
}{
|
||||
{
|
||||
name: "html success response",
|
||||
contentType: "text/html; charset=utf-8",
|
||||
statusCode: http.StatusOK,
|
||||
body: "<!DOCTYPE html><html><body>gateway login</body></html>",
|
||||
},
|
||||
{
|
||||
name: "html error response",
|
||||
contentType: "text/html; charset=utf-8",
|
||||
statusCode: http.StatusBadGateway,
|
||||
body: "<!DOCTYPE html><html><body>bad gateway</body></html>",
|
||||
},
|
||||
{
|
||||
name: "mislabeled html success response",
|
||||
contentType: "application/json",
|
||||
statusCode: http.StatusOK,
|
||||
body: " \r\n\t<!DOCTYPE html><html><body>gateway login</body></html>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", tt.contentType)
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, _ = w.Write([]byte(tt.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), fmt.Sprintf("Status: %d", tt.statusCode)) {
|
||||
t.Fatalf("expected status code in error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "returned HTML instead of JSON") {
|
||||
t.Fatalf("expected helpful HTML error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "check api_base or proxy configuration") {
|
||||
t.Fatalf("expected configuration hint, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_SuccessResponseUsesStreamingDecoder(t *testing.T) {
|
||||
content := strings.Repeat("a", 1024)
|
||||
body := `{"choices":[{"message":{"content":"` + content + `"},"finish_reason":"stop"}]}`
|
||||
|
||||
p := NewProvider("key", "https://example.com/v1", "")
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: &errAfterDataReadCloser{
|
||||
data: []byte(body),
|
||||
chunkSize: 64,
|
||||
},
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if out.Content != content {
|
||||
t.Fatalf("Content = %q, want %q", out.Content, content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_LargeHTMLResponsePreviewIsTruncated(t *testing.T) {
|
||||
body := append([]byte("<!DOCTYPE html><html><body>"), bytes.Repeat([]byte("A"), 2048)...)
|
||||
body = append(body, []byte("</body></html>")...)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = w.Write(body)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "Body: <!DOCTYPE html><html><body>") {
|
||||
t.Fatalf("expected html preview in error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "...") {
|
||||
t.Fatalf("expected truncated preview, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
@@ -253,7 +382,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
@@ -279,6 +408,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
input: "deepseek/deepseek-chat",
|
||||
wantModel: "deepseek-chat",
|
||||
},
|
||||
{
|
||||
name: "strips vivgrid prefix",
|
||||
input: "vivgrid/auto",
|
||||
wantModel: "auto",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -383,6 +517,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
|
||||
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
|
||||
}
|
||||
if got := normalizeModel("vivgrid/managed", "https://api.vivgrid.com/v1"); got != "managed" {
|
||||
t.Fatalf("normalizeModel(vivgrid) = %q, want %q", got, "managed")
|
||||
}
|
||||
if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" {
|
||||
t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutDefault(t *testing.T) {
|
||||
@@ -399,6 +539,40 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return f(r)
|
||||
}
|
||||
|
||||
type errAfterDataReadCloser struct {
|
||||
data []byte
|
||||
chunkSize int
|
||||
offset int
|
||||
}
|
||||
|
||||
func (r *errAfterDataReadCloser) Read(p []byte) (int, error) {
|
||||
if r.offset >= len(r.data) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
n := r.chunkSize
|
||||
if n <= 0 || n > len(p) {
|
||||
n = len(p)
|
||||
}
|
||||
remaining := len(r.data) - r.offset
|
||||
if n > remaining {
|
||||
n = remaining
|
||||
}
|
||||
copy(p, r.data[r.offset:r.offset+n])
|
||||
r.offset += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *errAfterDataReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
|
||||
p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
|
||||
if p.maxTokensField != "max_completion_tokens" {
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package routing
|
||||
|
||||
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
|
||||
// A higher score indicates a more complex task that benefits from a heavy model.
|
||||
// The score is compared against the configured threshold: score >= threshold selects
|
||||
// the primary (heavy) model; score < threshold selects the light model.
|
||||
//
|
||||
// Classifier is an interface so that future implementations (ML-based, embedding-based,
|
||||
// or any other approach) can be swapped in without changing routing infrastructure.
|
||||
type Classifier interface {
|
||||
Score(f Features) float64
|
||||
}
|
||||
|
||||
// RuleClassifier is the v1 implementation.
|
||||
// It uses a weighted sum of structural signals with no external dependencies,
|
||||
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
|
||||
// that the returned score always falls within the [0, 1] contract.
|
||||
//
|
||||
// Individual weights (multiple signals can fire simultaneously):
|
||||
//
|
||||
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
|
||||
// token 50-200: 0.15 — medium length; may or may not be complex
|
||||
// code block present: 0.40 — coding tasks need the heavy model
|
||||
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
|
||||
// tool calls 1-3 (recent): 0.10 — some tool activity
|
||||
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
|
||||
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
|
||||
//
|
||||
// Default threshold is 0.35, so:
|
||||
// - Pure greetings / trivial Q&A: 0.00 → light ✓
|
||||
// - Medium prose message (50–200 tokens): 0.15 → light ✓
|
||||
// - Message with code block: 0.40 → heavy ✓
|
||||
// - Long message (>200 tokens): 0.35 → heavy ✓
|
||||
// - Active tool session + medium message: 0.25 → light (acceptable)
|
||||
// - Any message with an image/audio attachment: 1.00 → heavy ✓
|
||||
type RuleClassifier struct{}
|
||||
|
||||
// Score computes the complexity score for the given feature set.
|
||||
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
|
||||
func (c *RuleClassifier) Score(f Features) float64 {
|
||||
// Hard gate: multi-modal inputs always require the heavy model.
|
||||
if f.HasAttachments {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
var score float64
|
||||
|
||||
// Token estimate — primary verbosity signal
|
||||
switch {
|
||||
case f.TokenEstimate > 200:
|
||||
score += 0.35
|
||||
case f.TokenEstimate > 50:
|
||||
score += 0.15
|
||||
}
|
||||
|
||||
// Fenced code blocks — strongest indicator of a coding/technical task
|
||||
if f.CodeBlockCount > 0 {
|
||||
score += 0.40
|
||||
}
|
||||
|
||||
// Recent tool call density — indicates an ongoing agentic workflow
|
||||
switch {
|
||||
case f.RecentToolCalls > 3:
|
||||
score += 0.25
|
||||
case f.RecentToolCalls > 0:
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Conversation depth — accumulated context implies compound task
|
||||
if f.ConversationDepth > 10 {
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
|
||||
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
|
||||
if score > 1.0 {
|
||||
score = 1.0
|
||||
}
|
||||
return score
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// lookbackWindow is the number of recent history entries scanned for tool calls.
|
||||
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
|
||||
const lookbackWindow = 6
|
||||
|
||||
// Features holds the structural signals extracted from a message and its session context.
|
||||
// Every dimension is language-agnostic by construction — no keyword or pattern matching
|
||||
// against natural-language content. This ensures consistent routing for all locales.
|
||||
type Features struct {
|
||||
// TokenEstimate is a proxy for token count.
|
||||
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
|
||||
// This avoids API calls while giving accurate estimates for all scripts.
|
||||
TokenEstimate int
|
||||
|
||||
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
|
||||
// Coding tasks almost always require the heavy model.
|
||||
CodeBlockCount int
|
||||
|
||||
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
|
||||
// history entries. A high density indicates an active agentic workflow.
|
||||
RecentToolCalls int
|
||||
|
||||
// ConversationDepth is the total number of messages in the session history.
|
||||
// Deep sessions tend to carry implicit complexity built up over many turns.
|
||||
ConversationDepth int
|
||||
|
||||
// HasAttachments is true when the message appears to contain media (images,
|
||||
// audio, video). Multi-modal inputs require vision-capable heavy models.
|
||||
HasAttachments bool
|
||||
}
|
||||
|
||||
// ExtractFeatures computes the structural feature vector for a message.
|
||||
// It is a pure function with no side effects and zero allocations beyond
|
||||
// the returned struct.
|
||||
func ExtractFeatures(msg string, history []providers.Message) Features {
|
||||
return Features{
|
||||
TokenEstimate: estimateTokens(msg),
|
||||
CodeBlockCount: countCodeBlocks(msg),
|
||||
RecentToolCalls: countRecentToolCalls(history),
|
||||
ConversationDepth: len(history),
|
||||
HasAttachments: hasAttachments(msg),
|
||||
}
|
||||
}
|
||||
|
||||
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
|
||||
// CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) map to roughly one
|
||||
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
|
||||
// for English). Splitting the count this way avoids the 3x underestimation that a
|
||||
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
|
||||
func estimateTokens(msg string) int {
|
||||
total := utf8.RuneCountInString(msg)
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
cjk := 0
|
||||
for _, r := range msg {
|
||||
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
|
||||
cjk++
|
||||
}
|
||||
}
|
||||
return cjk + (total-cjk)/4
|
||||
}
|
||||
|
||||
// countCodeBlocks counts the number of complete fenced code blocks.
|
||||
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
|
||||
// An unclosed opening fence (odd count) is treated as zero complete blocks
|
||||
// since it may just be an inline code span or a typo.
|
||||
func countCodeBlocks(msg string) int {
|
||||
n := strings.Count(msg, "```")
|
||||
return n / 2
|
||||
}
|
||||
|
||||
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
|
||||
// entries of history. It examines the ToolCalls field rather than parsing
|
||||
// the content string, so it is robust to any message format.
|
||||
func countRecentToolCalls(history []providers.Message) int {
|
||||
start := len(history) - lookbackWindow
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, msg := range history[start:] {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
count += len(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// hasAttachments returns true when the message content contains embedded media.
|
||||
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
|
||||
// common image/audio URL extensions. This is intentionally conservative —
|
||||
// false negatives (missing an attachment) just mean the routing falls back to
|
||||
// the primary model anyway.
|
||||
func hasAttachments(msg string) bool {
|
||||
lower := strings.ToLower(msg)
|
||||
|
||||
// Base64 data URIs embedded directly in the message
|
||||
if strings.Contains(lower, "data:image/") ||
|
||||
strings.Contains(lower, "data:audio/") ||
|
||||
strings.Contains(lower, "data:video/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Common image/audio extensions in URLs or file references
|
||||
mediaExts := []string{
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
|
||||
".mp3", ".wav", ".ogg", ".m4a", ".flac",
|
||||
".mp4", ".avi", ".mov", ".webm",
|
||||
}
|
||||
for _, ext := range mediaExts {
|
||||
if strings.Contains(lower, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// defaultThreshold is used when the config threshold is zero or negative.
|
||||
// At 0.35 a message needs at least one strong signal (code block, long text,
|
||||
// or an attachment) before the heavy model is chosen.
|
||||
const defaultThreshold = 0.35
|
||||
|
||||
// RouterConfig holds the validated model routing settings.
|
||||
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
|
||||
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
|
||||
type RouterConfig struct {
|
||||
// LightModel is the model_name (from model_list) used for simple tasks.
|
||||
LightModel string
|
||||
|
||||
// Threshold is the complexity score cutoff in [0, 1].
|
||||
// score >= Threshold → primary (heavy) model.
|
||||
// score < Threshold → light model.
|
||||
Threshold float64
|
||||
}
|
||||
|
||||
// Router selects the appropriate model tier for each incoming message.
|
||||
// It is safe for concurrent use from multiple goroutines.
|
||||
type Router struct {
|
||||
cfg RouterConfig
|
||||
classifier Classifier
|
||||
}
|
||||
|
||||
// New creates a Router with the given config and the default RuleClassifier.
|
||||
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
|
||||
func New(cfg RouterConfig) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{
|
||||
cfg: cfg,
|
||||
classifier: &RuleClassifier{},
|
||||
}
|
||||
}
|
||||
|
||||
// newWithClassifier creates a Router with a custom Classifier.
|
||||
// Intended for unit tests that need to inject a deterministic scorer.
|
||||
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{cfg: cfg, classifier: c}
|
||||
}
|
||||
|
||||
// SelectModel returns the model to use for this conversation turn along with
|
||||
// the computed complexity score (for logging and debugging).
|
||||
//
|
||||
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
|
||||
// - Otherwise: returns (primaryModel, false, score)
|
||||
//
|
||||
// The caller is responsible for resolving the returned model name into
|
||||
// provider candidates (see AgentInstance.LightCandidates).
|
||||
func (r *Router) SelectModel(
|
||||
msg string,
|
||||
history []providers.Message,
|
||||
primaryModel string,
|
||||
) (model string, usedLight bool, score float64) {
|
||||
features := ExtractFeatures(msg, history)
|
||||
score = r.classifier.Score(features)
|
||||
if score < r.cfg.Threshold {
|
||||
return r.cfg.LightModel, true, score
|
||||
}
|
||||
return primaryModel, false, score
|
||||
}
|
||||
|
||||
// LightModel returns the configured light model name.
|
||||
func (r *Router) LightModel() string {
|
||||
return r.cfg.LightModel
|
||||
}
|
||||
|
||||
// Threshold returns the complexity threshold in use.
|
||||
func (r *Router) Threshold() float64 {
|
||||
return r.cfg.Threshold
|
||||
}
|
||||
@@ -0,0 +1,414 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ── ExtractFeatures ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractFeatures_EmptyMessage(t *testing.T) {
|
||||
f := ExtractFeatures("", nil)
|
||||
if f.TokenEstimate != 0 {
|
||||
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
|
||||
}
|
||||
if f.CodeBlockCount != 0 {
|
||||
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
|
||||
}
|
||||
if f.RecentToolCalls != 0 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
|
||||
}
|
||||
if f.ConversationDepth != 0 {
|
||||
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
|
||||
}
|
||||
if f.HasAttachments {
|
||||
t.Error("HasAttachments: got true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate(t *testing.T) {
|
||||
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
|
||||
msg := strings.Repeat("a", 30)
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 7 {
|
||||
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
|
||||
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
|
||||
// Using a rune slice literal avoids CJK string literals in source.
|
||||
msg := string([]rune{
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60,
|
||||
})
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 9 {
|
||||
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
|
||||
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
|
||||
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 6 {
|
||||
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_CodeBlocks(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"no code here", 0},
|
||||
{"```go\nfmt.Println()\n```", 1},
|
||||
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
|
||||
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.CodeBlockCount != tc.want {
|
||||
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
|
||||
// History longer than lookbackWindow — only last lookbackWindow entries count.
|
||||
history := make([]providers.Message, 10)
|
||||
// Put 2 tool calls at positions 8 and 9 (within the last 6)
|
||||
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
|
||||
history[9] = providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
|
||||
}
|
||||
// Position 3 is outside the lookback window and must NOT be counted
|
||||
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
|
||||
|
||||
f := ExtractFeatures("test", history)
|
||||
// 1 (position 8) + 2 (position 9) = 3
|
||||
if f.RecentToolCalls != 3 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_ConversationDepth(t *testing.T) {
|
||||
history := make([]providers.Message, 7)
|
||||
f := ExtractFeatures("msg", history)
|
||||
if f.ConversationDepth != 7 {
|
||||
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"plain text", false},
|
||||
{"here is an image: data:image/png;base64,abc123", true},
|
||||
{"audio: data:audio/mp3;base64,xyz", true},
|
||||
{"video: data:video/mp4;base64,xyz", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"check out photo.jpg", true},
|
||||
{"see screenshot.png", true},
|
||||
{"listen to audio.mp3", true},
|
||||
{"watch clip.mp4", true},
|
||||
{"just a .go file", false},
|
||||
{"document.pdf", false}, // pdf is not in the media list
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── RuleClassifier ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{})
|
||||
if score != 0.0 {
|
||||
t.Errorf("zero features: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{HasAttachments: true})
|
||||
if score != 1.0 {
|
||||
t.Errorf("attachments: got %f, want 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Code block alone = 0.40, above default threshold 0.35
|
||||
score := c.Score(Features{CodeBlockCount: 1})
|
||||
if score < 0.35 {
|
||||
t.Errorf("code block: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_LongMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// >200 tokens = 0.35, exactly at default threshold → heavy
|
||||
score := c.Score(Features{TokenEstimate: 250})
|
||||
if score < 0.35 {
|
||||
t.Errorf("long message: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_MediumMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// 50-200 tokens = 0.15, below threshold → light
|
||||
score := c.Score(Features{TokenEstimate: 100})
|
||||
if score >= 0.35 {
|
||||
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ShortMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// <50 tokens, no other signals = 0.0 → light
|
||||
score := c.Score(Features{TokenEstimate: 10})
|
||||
if score != 0.0 {
|
||||
t.Errorf("short message: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
|
||||
scoreNone := c.Score(Features{RecentToolCalls: 0})
|
||||
scoreLow := c.Score(Features{RecentToolCalls: 2})
|
||||
scoreHigh := c.Score(Features{RecentToolCalls: 5})
|
||||
|
||||
if scoreNone != 0.0 {
|
||||
t.Errorf("no tools: got %f, want 0.0", scoreNone)
|
||||
}
|
||||
if scoreLow <= scoreNone {
|
||||
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
|
||||
}
|
||||
if scoreHigh <= scoreLow {
|
||||
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_DeepConversation(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
shallow := c.Score(Features{ConversationDepth: 5})
|
||||
deep := c.Score(Features{ConversationDepth: 15})
|
||||
if deep <= shallow {
|
||||
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Max all signals simultaneously
|
||||
f := Features{
|
||||
TokenEstimate: 500,
|
||||
CodeBlockCount: 3,
|
||||
RecentToolCalls: 10,
|
||||
ConversationDepth: 20,
|
||||
}
|
||||
score := c.Score(f)
|
||||
if score > 1.0 {
|
||||
t.Errorf("score %f exceeds 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Router ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestRouter_DefaultThreshold(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash"})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "hi"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("simple message: expected light model to be selected")
|
||||
}
|
||||
if model != "gemini-flash" {
|
||||
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "```go\nfmt.Println(\"hello\")\n```"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("code block: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "can you analyze this? data:image/png;base64,abc123"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("attachment: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
// >200 token estimate: 210 * 3 = 630 chars
|
||||
msg := strings.Repeat("word ", 210)
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("long message: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
|
||||
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
|
||||
// Routing is conservative: only promote to heavy when the signal is unambiguous.
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
|
||||
}
|
||||
msg := "ok"
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
|
||||
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{
|
||||
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
|
||||
}},
|
||||
}
|
||||
// ~55 tokens * 3 = 165 chars
|
||||
msg := strings.Repeat("word ", 55)
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
|
||||
// Very low threshold: even a short message triggers heavy model
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
|
||||
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("low threshold: medium message should use primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
|
||||
// Very high threshold: even code blocks route to light
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
|
||||
msg := "```go\nfmt.Println()\n```"
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("very high threshold: code block (0.40) should route to light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_LightModel(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
|
||||
if r.LightModel() != "my-fast-model" {
|
||||
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
|
||||
}
|
||||
}
|
||||
|
||||
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
|
||||
|
||||
type fixedScoreClassifier struct{ score float64 }
|
||||
|
||||
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
|
||||
|
||||
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.2},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if !usedLight {
|
||||
t.Error("low score with custom classifier: expected light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.8},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("high score with custom classifier: expected primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
|
||||
// score == threshold → primary (uses >= comparison)
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.5},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("score == threshold: expected primary model (>= threshold → primary)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.42},
|
||||
)
|
||||
_, _, score := r.SelectModel("anything", nil, "heavy")
|
||||
if score != 0.42 {
|
||||
t.Errorf("score: got %f, want 0.42", score)
|
||||
}
|
||||
}
|
||||
@@ -141,6 +141,12 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
everySeconds, hasEvery := args["every_seconds"].(float64)
|
||||
cronExpr, hasCron := args["cron_expr"].(string)
|
||||
|
||||
// Fix: type assertions return true for zero values, need additional validity checks
|
||||
// This prevents LLMs that fill unused optional parameters with defaults (0) from triggering wrong type
|
||||
hasAt = hasAt && atSeconds > 0
|
||||
hasEvery = hasEvery && everySeconds > 0
|
||||
hasCron = hasCron && cronExpr != ""
|
||||
|
||||
// Priority: at_seconds > every_seconds > cron_expr
|
||||
if hasAt {
|
||||
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// SendFileTool allows the LLM to send a local file (image, document, etc.)
|
||||
// to the user on the current chat channel via the MediaStore pipeline.
|
||||
type SendFileTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
}
|
||||
|
||||
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
return &SendFileTool{
|
||||
workspace: workspace,
|
||||
restrict: restrict,
|
||||
maxFileSize: maxFileSize,
|
||||
mediaStore: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SendFileTool) Name() string { return "send_file" }
|
||||
func (t *SendFileTool) Description() string {
|
||||
return "Send a local file (image, document, etc.) to the user on the current chat channel."
|
||||
}
|
||||
|
||||
func (t *SendFileTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the local file. Relative paths are resolved from workspace.",
|
||||
},
|
||||
"filename": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional display filename. Defaults to the basename of path.",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SendFileTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SendFileTool) SetMediaStore(store media.MediaStore) {
|
||||
t.mediaStore = store
|
||||
}
|
||||
|
||||
func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
path, _ := args["path"].(string)
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
// Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values.
|
||||
channel := ToolChannel(ctx)
|
||||
if channel == "" {
|
||||
channel = t.defaultChannel
|
||||
}
|
||||
chatID := ToolChatID(ctx)
|
||||
if chatID == "" {
|
||||
chatID = t.defaultChatID
|
||||
}
|
||||
if channel == "" || chatID == "" {
|
||||
return ErrorResult("no target channel/chat available")
|
||||
}
|
||||
|
||||
if t.mediaStore == nil {
|
||||
return ErrorResult("media store not configured")
|
||||
}
|
||||
|
||||
resolved, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
|
||||
}
|
||||
|
||||
info, err := os.Stat(resolved)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("file not found: %v", err))
|
||||
}
|
||||
if info.IsDir() {
|
||||
return ErrorResult("path is a directory, expected a file")
|
||||
}
|
||||
if info.Size() > int64(t.maxFileSize) {
|
||||
return ErrorResult(fmt.Sprintf(
|
||||
"file too large: %d bytes (max %d bytes)",
|
||||
info.Size(), t.maxFileSize,
|
||||
))
|
||||
}
|
||||
|
||||
filename, _ := args["filename"].(string)
|
||||
if filename == "" {
|
||||
filename = filepath.Base(resolved)
|
||||
}
|
||||
|
||||
mediaType := detectMediaType(resolved)
|
||||
scope := fmt.Sprintf("tool:send_file:%s:%s", channel, chatID)
|
||||
|
||||
ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
|
||||
Filename: filename,
|
||||
ContentType: mediaType,
|
||||
Source: "tool:send_file",
|
||||
}, scope)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to register media: %v", err))
|
||||
}
|
||||
|
||||
return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref})
|
||||
}
|
||||
|
||||
// detectMediaType determines the MIME type of a file.
|
||||
// Uses magic-bytes detection (h2non/filetype) first, then falls back to
|
||||
// extension-based lookup via mime.TypeByExtension.
|
||||
func detectMediaType(path string) string {
|
||||
kind, err := filetype.MatchFile(path)
|
||||
if err == nil && kind != filetype.Unknown {
|
||||
return kind.MIME.Value
|
||||
}
|
||||
|
||||
if ext := filepath.Ext(path); ext != "" {
|
||||
if t := mime.TypeByExtension(ext); t != "" {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
return "application/octet-stream"
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestSendFileTool_MissingPath(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool("/tmp", false, 0, store)
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error for missing path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_NoContext(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool("/tmp", false, 0, store)
|
||||
// no SetContext call
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when no channel context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_NoMediaStore(t *testing.T) {
|
||||
tool := NewSendFileTool("/tmp", false, 0, nil)
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when no media store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_Directory(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool("/tmp", false, 0, store)
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": "/tmp"})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error for directory path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_FileTooLarge(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testFile := filepath.Join(dir, "big.bin")
|
||||
// Create a file larger than the limit
|
||||
if err := os.WriteFile(testFile, make([]byte, 1024), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(dir, false, 512, store) // 512 byte limit
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": testFile})
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error for oversized file")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "too large") {
|
||||
t.Errorf("expected 'too large' in error, got %q", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_DefaultMaxSize(t *testing.T) {
|
||||
tool := NewSendFileTool("/tmp", false, 0, nil)
|
||||
if tool.maxFileSize != config.DefaultMaxMediaSize {
|
||||
t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_Success(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testFile := filepath.Join(dir, "photo.png")
|
||||
if err := os.WriteFile(testFile, []byte("fake png"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(dir, false, 0, store)
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": testFile})
|
||||
if result.IsError {
|
||||
t.Fatalf("unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
if result.Media[0][:8] != "media://" {
|
||||
t.Errorf("expected media:// ref, got %q", result.Media[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_CustomFilename(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
testFile := filepath.Join(dir, "img.jpg")
|
||||
if err := os.WriteFile(testFile, []byte("fake jpg"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(dir, false, 0, store)
|
||||
tool.SetContext("telegram", "chat456")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": testFile,
|
||||
"filename": "my-photo.jpg",
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_MagicBytes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Minimal valid PNG header
|
||||
pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
|
||||
pngFile := filepath.Join(dir, "image.dat") // wrong extension, but valid PNG bytes
|
||||
if err := os.WriteFile(pngFile, pngHeader, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := detectMediaType(pngFile)
|
||||
if got != "image/png" {
|
||||
t.Errorf("expected image/png from magic bytes, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_FallbackToExtension(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// File with unrecognizable content but known extension
|
||||
txtFile := filepath.Join(dir, "readme.txt")
|
||||
if err := os.WriteFile(txtFile, []byte("hello world"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := detectMediaType(txtFile)
|
||||
// text/plain or similar — just verify it's not application/octet-stream
|
||||
if got == "application/octet-stream" {
|
||||
t.Errorf("expected extension-based MIME for .txt, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_UnknownFallsToOctetStream(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// File with no extension and random bytes
|
||||
unknownFile := filepath.Join(dir, "mystery")
|
||||
if err := os.WriteFile(unknownFile, []byte{0x00, 0x01, 0x02}, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got := detectMediaType(unknownFile)
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("expected application/octet-stream, got %q", got)
|
||||
}
|
||||
}
|
||||
+7
-2
@@ -59,7 +59,7 @@ var (
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bkill\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
@@ -131,9 +131,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
|
||||
timeout := 60 * time.Second
|
||||
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
|
||||
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
return &ExecTool{
|
||||
workingDir: workingDir,
|
||||
timeout: 60 * time.Second,
|
||||
timeout: timeout,
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
|
||||
@@ -151,6 +151,26 @@ func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
if err != nil {
|
||||
t.Errorf("unable to configure exec tool: %s", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"command": "kill 12345",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected kill command to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf("Expected blocked message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_MissingCommand verifies error handling for missing command
|
||||
func TestShellTool_MissingCommand(t *testing.T) {
|
||||
tool, err := NewExecTool("", false)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -42,7 +42,7 @@ func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
|
||||
|
||||
func TestSpawnTool_Execute_ValidTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -27,7 +26,6 @@ type SubagentManager struct {
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
defaultModel string
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
tools *ToolRegistry
|
||||
maxIterations int
|
||||
@@ -41,13 +39,11 @@ type SubagentManager struct {
|
||||
func NewSubagentManager(
|
||||
provider providers.LLMProvider,
|
||||
defaultModel, workspace string,
|
||||
bus *bus.MessageBus,
|
||||
) *SubagentManager {
|
||||
return &SubagentManager{
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
provider: provider,
|
||||
defaultModel: defaultModel,
|
||||
bus: bus,
|
||||
workspace: workspace,
|
||||
tools: NewToolRegistry(),
|
||||
maxIterations: 10,
|
||||
@@ -214,20 +210,6 @@ After completing the task, provide a clear summary of what was done.`
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Send announce message back to main agent
|
||||
if sm.bus != nil {
|
||||
announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result)
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
sm.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("subagent:%s", task.ID),
|
||||
// Format: "original_channel:original_chat_id" for routing back
|
||||
ChatID: fmt.Sprintf("%s:%s", task.OriginChannel, task.OriginChatID),
|
||||
Content: announceContent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -47,7 +46,7 @@ func (m *MockLLMProvider) GetContextWindow() int {
|
||||
|
||||
func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
manager.SetLLMOptions(2048, 0.6)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
@@ -73,7 +72,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
|
||||
// TestSubagentTool_Name verifies tool name
|
||||
func TestSubagentTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
if tool.Name() != "subagent" {
|
||||
@@ -84,7 +83,7 @@ func TestSubagentTool_Name(t *testing.T) {
|
||||
// TestSubagentTool_Description verifies tool description
|
||||
func TestSubagentTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
@@ -99,7 +98,7 @@ func TestSubagentTool_Description(t *testing.T) {
|
||||
// TestSubagentTool_Parameters verifies tool parameters schema
|
||||
func TestSubagentTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
@@ -149,8 +148,7 @@ func TestSubagentTool_Parameters(t *testing.T) {
|
||||
// TestSubagentTool_Execute_Success tests successful execution
|
||||
func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
|
||||
@@ -204,8 +202,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
// TestSubagentTool_Execute_NoLabel tests execution without label
|
||||
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -228,7 +225,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
|
||||
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -278,8 +275,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) {
|
||||
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
|
||||
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
channel := "test-channel"
|
||||
@@ -304,8 +300,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
// Create a mock provider that returns very long content
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
Reference in New Issue
Block a user