fix(line): classify SDK errors with HTTP status and add client timeout

Address review feedback:
- Use *WithHttpInfo SDK variants to get HTTP response status codes
- Map status codes via ClassifySendError (429→ErrRateLimit, 5xx→ErrTemporary, 4xx→ErrSendFailed)
- Fall back to ClassifyNetError for network-level failures
- Configure SDK with 30s timeout HTTP client

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
ex-takashima
2026-05-07 16:23:33 +09:00
225 changed files with 16494 additions and 4435 deletions
+6
View File
@@ -43,3 +43,9 @@ func (a *channelManagerAdapter) SendMedia(ctx context.Context, msg bus.OutboundM
func (a *channelManagerAdapter) SendPlaceholder(ctx context.Context, channel, chatID string) bool {
return a.inner.SendPlaceholder(ctx, channel, chatID)
}
func (a *channelManagerAdapter) DismissToolFeedback(
ctx context.Context, channel, chatID string, outboundCtx *bus.InboundContext,
) {
a.inner.DismissToolFeedback(ctx, channel, chatID, outboundCtx)
}
+40 -11
View File
@@ -21,6 +21,7 @@ import (
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -37,9 +38,13 @@ type AgentLoop struct {
registry *AgentRegistry
state *state.Manager
// Event system (from Incoming)
eventBus *EventBus
hooks *HookManager
// Runtime event system
runtimeEvents runtimeevents.Bus
ownsRuntimeEvents bool
runtimeEventLogMu sync.RWMutex
runtimeEventLogger *runtimeEventLogger
runtimeEventLogSub runtimeevents.Subscription
hooks *HookManager
// Runtime state
running atomic.Bool
@@ -53,6 +58,7 @@ type AgentLoop struct {
hookRuntime hookRuntime
steering *steeringQueue
pendingSkills sync.Map
pendingStops sync.Map
mu sync.RWMutex
// workerSem limits concurrent turn processing workers.
@@ -172,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
phase: TurnPhaseSetup,
}
if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded {
if al.tryHandleStopCommand(ctx, msg, sessionKey) {
continue
}
// Another turn is already active (or reserved) for this session — enqueue
if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{
Role: "user",
@@ -235,6 +245,24 @@ func (al *AgentLoop) Run(ctx context.Context) error {
defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID)
}
if al.takePendingStop(sessionKey) {
al.activeTurnStates.Delete(sessionKey)
target := &continuationTarget{
SessionKey: sessionKey,
Channel: m.Channel,
ChatID: m.ChatID,
}
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
if continueErr != nil {
al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr)
return
}
if continued != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued)
}
return
}
al.runTurnWithSteering(ctx, m)
}(msg)
@@ -285,8 +313,14 @@ func (al *AgentLoop) Close() {
if al.hooks != nil {
al.hooks.Close()
}
if al.eventBus != nil {
al.eventBus.Close()
al.closeRuntimeEventLogger()
if al.runtimeEvents != nil && al.ownsRuntimeEvents {
if err := al.runtimeEvents.Close(); err != nil {
logger.ErrorCF("agent", "Failed to close runtime event bus",
map[string]any{
"error": err.Error(),
})
}
}
}
@@ -294,12 +328,6 @@ func (al *AgentLoop) Close() {
// UnmountHook removes a previously registered in-process hook.
// SubscribeEvents registers a subscriber for agent-loop events.
// UnsubscribeEvents removes a previously registered event subscriber.
// EventDrops returns the number of dropped events for the given kind.
type turnEventScope struct {
agentID string
sessionKey string
@@ -384,6 +412,7 @@ func (al *AgentLoop) ReloadProviderAndConfig(
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL)
al.mu.Unlock()
al.refreshRuntimeEventLogger(cfg)
oldMCPManager := al.mcp.reset()
al.hookRuntime.reset(al)
+6
View File
@@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime(
return nil
},
}
rt.StopActiveTurn = func() (commands.StopResult, error) {
if opts == nil {
return commands.StopResult{}, fmt.Errorf("process options not available")
}
return al.stopActiveTurnForSession(opts.Dispatch.SessionKey)
}
if agent != nil && agent.ContextBuilder != nil {
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
}
+31 -128
View File
@@ -5,7 +5,7 @@ package agent
import (
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func (al *AgentLoop) newTurnEventScope(agentID, sessionKey string, turnCtx *TurnContext) turnEventScope {
@@ -18,8 +18,8 @@ func (al *AgentLoop) newTurnEventScope(agentID, sessionKey string, turnCtx *Turn
}
}
func (ts turnEventScope) meta(iteration int, source, tracePath string) EventMeta {
return EventMeta{
func (ts turnEventScope) meta(iteration int, source, tracePath string) HookMeta {
return HookMeta{
AgentID: ts.agentID,
TurnID: ts.turnID,
SessionKey: ts.sessionKey,
@@ -30,119 +30,24 @@ func (ts turnEventScope) meta(iteration int, source, tracePath string) EventMeta
}
}
func (al *AgentLoop) emitEvent(kind EventKind, meta EventMeta, payload any) {
clonedMeta := cloneEventMeta(meta)
evt := Event{
Kind: kind,
Meta: clonedMeta,
Context: cloneTurnContext(clonedMeta.turnContext),
Payload: payload,
func (al *AgentLoop) emitEvent(kind runtimeevents.Kind, meta HookMeta, payload any) {
clonedMeta := cloneHookMeta(meta)
eventCtx := cloneTurnContext(clonedMeta.turnContext)
evt := runtimeevents.Event{
Kind: kind,
Source: runtimeevents.Source{Component: "agent", Name: clonedMeta.AgentID},
Scope: runtimeScopeFromHookMeta(clonedMeta, eventCtx),
Correlation: runtimeCorrelationFromHookMeta(clonedMeta),
Severity: runtimeSeverityForAgentEvent(kind, payload),
Payload: payload,
Attrs: runtimeAttrsFromHookMeta(clonedMeta),
}
if al == nil || al.eventBus == nil {
if al == nil {
return
}
al.logEvent(evt)
al.eventBus.Emit(evt)
}
func (al *AgentLoop) logEvent(evt Event) {
fields := map[string]any{
"event_kind": evt.Kind.String(),
"agent_id": evt.Meta.AgentID,
"turn_id": evt.Meta.TurnID,
"session_key": evt.Meta.SessionKey,
"iteration": evt.Meta.Iteration,
}
if evt.Meta.TracePath != "" {
fields["trace"] = evt.Meta.TracePath
}
if evt.Meta.Source != "" {
fields["source"] = evt.Meta.Source
}
appendEventContextFields(fields, evt.Context)
switch payload := evt.Payload.(type) {
case TurnStartPayload:
fields["user_len"] = len(payload.UserMessage)
fields["media_count"] = payload.MediaCount
case TurnEndPayload:
fields["status"] = payload.Status
fields["iterations_total"] = payload.Iterations
fields["duration_ms"] = payload.Duration.Milliseconds()
fields["final_len"] = payload.FinalContentLen
case LLMRequestPayload:
fields["model"] = payload.Model
fields["messages"] = payload.MessagesCount
fields["tools"] = payload.ToolsCount
fields["max_tokens"] = payload.MaxTokens
case LLMDeltaPayload:
fields["content_delta_len"] = payload.ContentDeltaLen
fields["reasoning_delta_len"] = payload.ReasoningDeltaLen
case LLMResponsePayload:
fields["content_len"] = payload.ContentLen
fields["tool_calls"] = payload.ToolCalls
fields["has_reasoning"] = payload.HasReasoning
case LLMRetryPayload:
fields["attempt"] = payload.Attempt
fields["max_retries"] = payload.MaxRetries
fields["reason"] = payload.Reason
fields["error"] = payload.Error
fields["backoff_ms"] = payload.Backoff.Milliseconds()
case ContextCompressPayload:
fields["reason"] = payload.Reason
fields["dropped_messages"] = payload.DroppedMessages
fields["remaining_messages"] = payload.RemainingMessages
case SessionSummarizePayload:
fields["summarized_messages"] = payload.SummarizedMessages
fields["kept_messages"] = payload.KeptMessages
fields["summary_len"] = payload.SummaryLen
fields["omitted_oversized"] = payload.OmittedOversized
case ToolExecStartPayload:
fields["tool"] = payload.Tool
fields["args_count"] = len(payload.Arguments)
case ToolExecEndPayload:
fields["tool"] = payload.Tool
fields["duration_ms"] = payload.Duration.Milliseconds()
fields["for_llm_len"] = payload.ForLLMLen
fields["for_user_len"] = payload.ForUserLen
fields["is_error"] = payload.IsError
fields["async"] = payload.Async
case ToolExecSkippedPayload:
fields["tool"] = payload.Tool
fields["reason"] = payload.Reason
case SteeringInjectedPayload:
fields["count"] = payload.Count
fields["total_content_len"] = payload.TotalContentLen
case FollowUpQueuedPayload:
fields["source_tool"] = payload.SourceTool
fields["content_len"] = payload.ContentLen
case InterruptReceivedPayload:
fields["interrupt_kind"] = payload.Kind
fields["role"] = payload.Role
fields["content_len"] = payload.ContentLen
fields["queue_depth"] = payload.QueueDepth
fields["hint_len"] = payload.HintLen
case SubTurnSpawnPayload:
fields["child_agent_id"] = payload.AgentID
fields["label"] = payload.Label
case SubTurnEndPayload:
fields["child_agent_id"] = payload.AgentID
fields["status"] = payload.Status
case SubTurnResultDeliveredPayload:
fields["target_channel"] = payload.TargetChannel
fields["target_chat_id"] = payload.TargetChatID
fields["content_len"] = payload.ContentLen
case ErrorPayload:
fields["stage"] = payload.Stage
fields["error"] = payload.Message
}
logger.InfoCF("eventbus", fmt.Sprintf("Agent event: %s", evt.Kind.String()), fields)
al.publishRuntimeEvent(evt)
}
// MountHook registers an in-process hook on the agent loop.
@@ -161,28 +66,26 @@ func (al *AgentLoop) UnmountHook(name string) {
al.hooks.Unmount(name)
}
// SubscribeEvents registers a subscriber for agent-loop events.
func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
if al == nil || al.eventBus == nil {
ch := make(chan Event)
close(ch)
return EventSubscription{C: ch}
// RuntimeEvents returns the root runtime event channel.
func (al *AgentLoop) RuntimeEvents() runtimeevents.EventChannel {
if al == nil || al.runtimeEvents == nil {
return nil
}
return al.eventBus.Subscribe(buffer)
return al.runtimeEvents.Channel()
}
// UnsubscribeEvents removes a previously registered event subscriber.
func (al *AgentLoop) UnsubscribeEvents(id uint64) {
if al == nil || al.eventBus == nil {
return
// RuntimeEventStats returns runtime event bus counters.
func (al *AgentLoop) RuntimeEventStats() runtimeevents.Stats {
if al == nil || al.runtimeEvents == nil {
return runtimeevents.Stats{Closed: true}
}
al.eventBus.Unsubscribe(id)
return al.runtimeEvents.Stats()
}
// EventDrops returns the number of dropped events for the given kind.
func (al *AgentLoop) EventDrops(kind EventKind) int64 {
if al == nil || al.eventBus == nil {
return 0
// RuntimeEventBus returns the runtime event bus used by the agent loop.
func (al *AgentLoop) RuntimeEventBus() runtimeevents.Bus {
if al == nil {
return nil
}
return al.eventBus.Dropped(kind)
return al.runtimeEvents
}
+40 -12
View File
@@ -13,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
@@ -24,6 +25,7 @@ func NewAgentLoop(
cfg *config.Config,
msgBus *bus.MessageBus,
provider providers.LLMProvider,
opts ...AgentLoopOption,
) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
@@ -47,8 +49,6 @@ func NewAgentLoop(
stateManager = state.NewManager(defaultAgent.Workspace)
}
eventBus := NewEventBus()
// Determine worker pool size from config (default: 1 = sequential)
workerPoolSize := cfg.Agents.Defaults.MaxParallelTurns
if workerPoolSize <= 0 {
@@ -56,18 +56,28 @@ func NewAgentLoop(
}
al := &AgentLoop{
bus: msgBus,
cfg: cfg,
registry: registry,
state: stateManager,
eventBus: eventBus,
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
workerSem: make(chan struct{}, workerPoolSize),
bus: msgBus,
cfg: cfg,
registry: registry,
state: stateManager,
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
workerSem: make(chan struct{}, workerPoolSize),
ownsRuntimeEvents: true,
}
for _, opt := range opts {
if opt != nil {
opt(al)
}
}
if al.runtimeEvents == nil {
al.runtimeEvents = runtimeevents.NewBus()
al.ownsRuntimeEvents = true
}
al.refreshRuntimeEventLogger(cfg)
al.providerFactory = providers.CreateProviderFromConfig
al.hooks = NewHookManager(eventBus)
al.hooks = NewHookManager(al.runtimeEvents.Channel())
configureHookManagerFromConfig(al.hooks, cfg)
al.contextManager = al.resolveContextManager()
@@ -128,6 +138,9 @@ func registerSharedTools(
if cfg.Tools.IsToolEnabled("spi") {
agent.Tools.Register(tools.NewSPITool())
}
if cfg.Tools.IsToolEnabled("serial") {
agent.Tools.Register(tools.NewSerialTool())
}
// Message tool
if cfg.Tools.IsToolEnabled("message") {
@@ -324,5 +337,20 @@ func registerSharedTools(
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
}
// Register delegate tool for multi-agent setups.
// Auto-enabled when multiple agents exist. Delegation uses the SubTurn
// mechanism directly (not SubagentManager) and is independent of the
// subagent tool.
if len(registry.ListAgentIDs()) > 1 {
delegateTool := tools.NewDelegateTool()
delegateTool.SetSpawner(NewSubTurnSpawner(al))
currentAgentID := agentID
delegateTool.SetSelfAgentID(currentAgentID)
delegateTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(delegateTool)
}
}
}
+2 -1
View File
@@ -97,7 +97,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
al.mcp.initOnce.Do(func() {
mcpManager := mcp.NewManager()
mcpManager := mcp.NewManager(mcp.WithRuntimeEvents(al.runtimeEvents))
defaultAgent := al.registry.GetDefaultAgent()
workspacePath := al.cfg.WorkspacePath()
@@ -164,6 +164,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
mcpTool.SetWorkspace(agent.Workspace)
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
mcpTool.SetEventPublisher(al.runtimeEvents)
if registerAsHidden {
agent.Tools.RegisterHidden(mcpTool)
+134 -63
View File
@@ -11,6 +11,7 @@ import (
"encoding/base64"
"io"
"os"
"regexp"
"strings"
"github.com/h2non/filetype"
@@ -20,24 +21,59 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
// genericPlaceholderRegex matches generic media placeholders emitted by various
// channels: [image], [image: photo], [image: filename.jpg] — but NOT path tags
// like [image:/path/to/file] (path tags have no space after the colon).
var (
imagePlaceholderRegex = regexp.MustCompile(`\[image(:\s+[^\]]*)?\]`)
audioPlaceholderRegex = regexp.MustCompile(`\[audio(:\s+[^\]]*)?\]`)
videoPlaceholderRegex = regexp.MustCompile(`\[video(:\s+[^\]]*)?\]`)
filePlaceholderRegex = regexp.MustCompile(`\[file(:\s+[^\]]*)?\]`)
)
// resolveMediaRefs resolves media:// refs in messages.
// Images are base64-encoded into the Media array for multimodal LLMs.
// Non-image files (documents, audio, video) have their local path injected
// into Content so the agent can access them via file tools like read_file.
// For user messages: images get path tags only ([image:/path]) so the LLM
// can decide whether to view them via load_image or operate on the file.
// For tool messages: images are base64-encoded and appended as a synthetic
// user message only after the contiguous tool-message block ends, so we don't
// break the tool-results-must-immediately-follow-assistant constraint that
// LLM APIs enforce.
// Non-image files always get path tags regardless of role.
// Returns a new slice; original messages are not mutated.
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
if store == nil {
return messages
}
result := make([]providers.Message, len(messages))
copy(result, messages)
result := make([]providers.Message, 0, len(messages))
var pendingToolImages []string
for idx, m := range messages {
// When leaving a tool-message block, flush any accumulated images
// as a synthetic user message.
if m.Role != "tool" && len(pendingToolImages) > 0 {
result = append(result, providers.Message{
Role: "user",
Content: "[Loaded image from tool result above]",
Media: pendingToolImages,
})
pendingToolImages = nil
}
for i, m := range result {
if len(m.Media) == 0 {
result = append(result, m)
if idx == len(messages)-1 && len(pendingToolImages) > 0 {
result = append(result, providers.Message{
Role: "user",
Content: "[Loaded image from tool result above]",
Media: pendingToolImages,
})
pendingToolImages = nil
}
continue
}
msg := m
resolved := make([]string, 0, len(m.Media))
var pathTags []string
@@ -66,27 +102,77 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS
}
mime := detectMIME(localPath, meta)
pathTags = append(pathTags, buildPathTag(mime, localPath))
if strings.HasPrefix(mime, "image/") {
if m.Role == "tool" && strings.HasPrefix(mime, "image/") {
dataURL := encodeImageToDataURL(localPath, mime, info, maxSize)
if dataURL != "" {
resolved = append(resolved, dataURL)
pendingToolImages = append(pendingToolImages, dataURL)
}
continue
}
pathTags = append(pathTags, buildPathTag(mime, localPath))
}
result[i].Media = resolved
msg.Media = resolved
if len(pathTags) > 0 {
result[i].Content = injectPathTags(result[i].Content, pathTags)
msg.Content = injectPathTags(msg.Content, pathTags)
}
result = append(result, msg)
// If this is the last message and we have pending images, flush them.
if idx == len(messages)-1 && len(pendingToolImages) > 0 {
result = append(result, providers.Message{
Role: "user",
Content: "[Loaded image from tool result above]",
Media: pendingToolImages,
})
pendingToolImages = nil
}
}
return result
}
// encodeImageToDataURL base64-encodes an image file into a data URL.
// Returns empty string if the file exceeds maxSize or encoding fails.
func encodeImageToDataURL(localPath, mime string, info os.FileInfo, maxSize int) string {
if info.Size() > int64(maxSize) {
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
"path": localPath,
"size": info.Size(),
"max_size": maxSize,
})
return ""
}
f, err := os.Open(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to open media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
return ""
}
defer f.Close()
prefix := "data:" + mime + ";base64,"
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
var buf bytes.Buffer
buf.Grow(len(prefix) + encodedLen)
buf.WriteString(prefix)
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
if _, err := io.Copy(encoder, f); err != nil {
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
return ""
}
encoder.Close()
return buf.String()
}
func buildArtifactTags(store media.MediaStore, refs []string) []string {
if store == nil || len(refs) == 0 {
return nil
@@ -137,51 +223,12 @@ func detectMIME(localPath string, meta media.MediaMeta) string {
return kind.MIME.Value
}
// encodeImageToDataURL base64-encodes an image file into a data URL.
// Returns empty string if the file exceeds maxSize or encoding fails.
func encodeImageToDataURL(localPath, mime string, info os.FileInfo, maxSize int) string {
if info.Size() > int64(maxSize) {
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
"path": localPath,
"size": info.Size(),
"max_size": maxSize,
})
return ""
}
f, err := os.Open(localPath)
if err != nil {
logger.WarnCF("agent", "Failed to open media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
return ""
}
defer f.Close()
prefix := "data:" + mime + ";base64,"
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
var buf bytes.Buffer
buf.Grow(len(prefix) + encodedLen)
buf.WriteString(prefix)
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
if _, err := io.Copy(encoder, f); err != nil {
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
"path": localPath,
"error": err.Error(),
})
return ""
}
encoder.Close()
return buf.String()
}
// buildPathTag creates a structured tag exposing the local file path.
// Tag type is derived from MIME: [audio:/path], [video:/path], or [file:/path].
// Tag type is derived from MIME: [image:/path], [audio:/path], [video:/path], or [file:/path].
func buildPathTag(mime, localPath string) string {
switch {
case strings.HasPrefix(mime, "image/"):
return "[image:" + localPath + "]"
case strings.HasPrefix(mime, "audio/"):
return "[audio:" + localPath + "]"
case strings.HasPrefix(mime, "video/"):
@@ -192,22 +239,41 @@ func buildPathTag(mime, localPath string) string {
}
// injectPathTags replaces generic media tags in content with path-bearing versions,
// or appends if no matching generic tag is found.
// or appends if no matching generic tag is found. Channels emit a few different
// placeholder formats — [image], [image: photo], [image: filename.jpg] — so we
// match all of them via regex while leaving path tags ([image:/path]) untouched.
//
// When content is structured data (e.g., JSON from Feishu interactive cards or
// post messages), tags are only injected via placeholder replacement — never
// appended — to avoid corrupting the payload.
func injectPathTags(content string, tags []string) string {
isStructured := looksLikeJSON(content)
for _, tag := range tags {
var generic string
var pattern *regexp.Regexp
switch {
case strings.HasPrefix(tag, "[image:"):
pattern = imagePlaceholderRegex
case strings.HasPrefix(tag, "[audio:"):
generic = "[audio]"
pattern = audioPlaceholderRegex
case strings.HasPrefix(tag, "[video:"):
generic = "[video]"
pattern = videoPlaceholderRegex
case strings.HasPrefix(tag, "[file:"):
generic = "[file]"
pattern = filePlaceholderRegex
}
if generic != "" && strings.Contains(content, generic) {
content = strings.Replace(content, generic, tag, 1)
} else if content == "" {
if pattern != nil {
if loc := pattern.FindStringIndex(content); loc != nil {
content = content[:loc[0]] + tag + content[loc[1]:]
continue
}
}
if isStructured {
content = tag + "\n" + content
continue
}
if content == "" {
content = tag
} else {
content += " " + tag
@@ -215,3 +281,8 @@ func injectPathTags(content string, tags []string) string {
}
return content
}
func looksLikeJSON(s string) bool {
s = strings.TrimSpace(s)
return len(s) > 1 && s[0] == '{'
}
+20
View File
@@ -0,0 +1,20 @@
package agent
import runtimeevents "github.com/sipeed/picoclaw/pkg/events"
// AgentLoopOption configures an AgentLoop at construction time.
type AgentLoopOption func(*AgentLoop)
// WithRuntimeEvents injects the runtime event bus used for new observation APIs.
//
// The injected bus is treated as externally owned and will not be closed by
// AgentLoop.Close. Passing nil leaves the default owned runtime bus enabled.
func WithRuntimeEvents(bus runtimeevents.Bus) AgentLoopOption {
return func(al *AgentLoop) {
if bus == nil {
return
}
al.runtimeEvents = bus
al.ownsRuntimeEvents = false
}
}
+31 -15
View File
@@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
return
}
// Drain steering queue using existing Continue mechanism
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
if continueErr != nil {
logger.WarnCF("agent", "Failed to continue queued steering",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"error": continueErr.Error(),
})
} else if continued != "" {
finalResponse = continued
}
// Publish final response
if finalResponse != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
}
}
func (al *AgentLoop) drainQueuedSteeringContinuations(
ctx context.Context,
target *continuationTarget,
) (string, error) {
if target == nil {
return "", nil
}
finalResponse := ""
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
// Check for context cancellation between iterations
if ctx.Err() != nil {
return
if err := ctx.Err(); err != nil {
return finalResponse, err
}
logger.InfoCF("agent", "Continuing queued steering after turn end",
@@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
if continueErr != nil {
logger.WarnCF("agent", "Failed to continue queued steering",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"error": continueErr.Error(),
})
break
return finalResponse, continueErr
}
if continued == "" {
break
@@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
finalResponse = continued
}
// Publish final response
if finalResponse != "" {
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
}
return finalResponse, nil
}
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
+122
View File
@@ -0,0 +1,122 @@
package agent
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/commands"
)
func (al *AgentLoop) tryHandleStopCommand(
ctx context.Context,
msg bus.InboundMessage,
sessionKey string,
) bool {
cmdName, ok := commands.CommandName(msg.Content)
if !ok || cmdName != "stop" {
return false
}
result, err := al.stopActiveTurnForSession(sessionKey)
// This function is only called when loaded=true (another turn already
// claimed this session). If stopActiveTurnForSession found a pending
// placeholder but didn't stop it, that placeholder belongs to the other
// message's worker which hasn't started yet — arm a pending stop so the
// worker will bail when it checks before running.
if err == nil && !result.Stopped {
if ts := al.getActiveTurnState(sessionKey); ts != nil {
snap := ts.snapshot()
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
al.markPendingStop(sessionKey)
result.Stopped = true
}
}
}
reply := commands.FormatStopReply(result)
if err != nil {
reply = "Failed to stop task: " + err.Error()
}
if al.channelManager != nil {
al.channelManager.InvokeTypingStop(msg.Channel, msg.ChatID)
}
al.resetMessageToolRound(sessionKey)
al.PublishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, sessionKey, reply)
return true
}
func (al *AgentLoop) stopActiveTurnForSession(sessionKey string) (commands.StopResult, error) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return commands.StopResult{}, fmt.Errorf("session key is required")
}
result := commands.StopResult{}
cleared := al.clearSteeringMessagesForScope(sessionKey)
al.clearPendingSkills(sessionKey)
ts := al.getActiveTurnState(sessionKey)
if ts == nil {
result.Stopped = cleared > 0
return result, nil
}
snap := ts.snapshot()
result.TaskName = snap.UserMessage
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
// A pending placeholder means this session is either idle (our own
// placeholder from the /stop command) or another message is queued but
// hasn't started yet. In both cases, we don't arm a pending stop here;
// the caller (tryHandleStopCommand) handles the "another message queued"
// case explicitly, since it knows loaded=true.
return result, nil
}
if err := al.HardAbort(sessionKey); err != nil {
if al.getActiveTurnState(sessionKey) == nil {
result.Stopped = cleared > 0
return result, nil
}
return commands.StopResult{}, err
}
result.Stopped = true
return result, nil
}
func (al *AgentLoop) markPendingStop(sessionKey string) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return
}
al.pendingStops.Store(sessionKey, struct{}{})
}
func (al *AgentLoop) takePendingStop(sessionKey string) bool {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return false
}
_, ok := al.pendingStops.LoadAndDelete(sessionKey)
return ok
}
func (al *AgentLoop) resetMessageToolRound(sessionKey string) {
if strings.TrimSpace(sessionKey) == "" {
return
}
if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
if tool, ok := agent.Tools.Get("message"); ok {
if resetter, ok := tool.(interface{ ResetSentInRound(sessionKey string) }); ok {
resetter.ResetSentInRound(sessionKey)
}
}
}
}
}
+238 -30
View File
@@ -19,6 +19,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -1781,17 +1782,22 @@ func (m *artifactThenSendProvider) Chat(
if messages[i].Role != "tool" {
continue
}
start := strings.Index(messages[i].Content, "[file:")
if start < 0 {
continue
for _, prefix := range []string{"[image:", "[file:", "[audio:", "[video:"} {
start := strings.Index(messages[i].Content, prefix)
if start < 0 {
continue
}
rest := messages[i].Content[start+len(prefix):]
end := strings.Index(rest, "]")
if end < 0 {
continue
}
artifactPath = rest[:end]
break
}
rest := messages[i].Content[start+len("[file:"):]
end := strings.Index(rest, "]")
if end < 0 {
continue
if artifactPath != "" {
break
}
artifactPath = rest[:end]
break
}
if artifactPath == "" {
return nil, fmt.Errorf("provider did not receive artifact path in tool result")
@@ -4656,7 +4662,7 @@ func TestRun_PicoToolFeedbackSuppressesDuplicateInterimAssistantContent(t *testi
}
}
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
func TestResolveMediaRefs_ImageInjectsPathTag(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
@@ -4684,15 +4690,110 @@ func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 {
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (images use path tags), got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
localPath, _, _ := store.ResolveWithMeta(ref)
expectedContent := "describe this [image:" + localPath + "]"
if result[0].Content != expectedContent {
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
}
}
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
func TestResolveMediaRefs_ToolRoleImageAppendedAsUserMessage(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
pngPath := filepath.Join(dir, "tool-result.png")
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
0x00, 0x00, 0x00, 0x0D, // IHDR length
0x49, 0x48, 0x44, 0x52, // "IHDR"
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
0x00, 0x00, 0x00, // no interlace
0x90, 0x77, 0x53, 0xDE, // CRC
}
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
t.Fatal(err)
}
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
messages := []providers.Message{
{Role: "tool", Content: "Image loaded", Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
// Tool message should have path tag but no base64
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media in tool message, got %d", len(result[0].Media))
}
localPath, _, _ := store.ResolveWithMeta(ref)
if !strings.Contains(result[0].Content, "[image:"+localPath+"]") {
t.Fatalf("expected image path tag in tool content, got %q", result[0].Content)
}
// A synthetic user message with base64 should follow
if len(result) != 2 {
t.Fatalf("expected 2 messages (tool + synthetic user), got %d", len(result))
}
if result[1].Role != "user" {
t.Fatalf("expected synthetic message role=user, got %q", result[1].Role)
}
if len(result[1].Media) != 1 {
t.Fatalf("expected 1 base64 media in synthetic user message, got %d", len(result[1].Media))
}
if !strings.HasPrefix(result[1].Media[0], "data:image/png;base64,") {
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[1].Media[0][:40])
}
}
func TestResolveMediaRefs_MultiToolCallPreservesOrdering(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
// Create image for tool #1
pngPath := filepath.Join(dir, "loaded.png")
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
}
os.WriteFile(pngPath, pngHeader, 0o644)
imgRef, _ := store.Store(pngPath, media.MediaMeta{}, "test")
// Simulate: assistant called load_image + read_file, two tool results follow
messages := []providers.Message{
{Role: "assistant", Content: "Let me load the image and read the file."},
{Role: "tool", Content: "Image loaded [image: photo]", Media: []string{imgRef}},
{Role: "tool", Content: "file contents here"},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
// assistant, tool#1, tool#2 must remain contiguous — no user in between
if result[0].Role != "assistant" {
t.Fatalf("result[0] expected assistant, got %q", result[0].Role)
}
if result[1].Role != "tool" {
t.Fatalf("result[1] expected tool, got %q", result[1].Role)
}
if result[2].Role != "tool" {
t.Fatalf("result[2] expected tool, got %q", result[2].Role)
}
// Synthetic user message should come AFTER the tool block
if len(result) != 4 {
t.Fatalf("expected 4 messages (assistant + 2 tool + synthetic user), got %d", len(result))
}
if result[3].Role != "user" {
t.Fatalf("result[3] expected user, got %q", result[3].Role)
}
if len(result[3].Media) != 1 || !strings.HasPrefix(result[3].Media[0], "data:image/png;base64,") {
t.Fatal("expected synthetic user message to contain base64 image")
}
}
func TestResolveMediaRefs_OversizedImageSkipsBase64KeepsPathTag(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
@@ -4714,6 +4815,11 @@ func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
}
localPath, _, _ := store.ResolveWithMeta(ref)
expected := "hi [image:" + localPath + "]"
if result[0].Content != expected {
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
}
}
func TestResolveMediaRefs_UnknownTypeInjectsPath(t *testing.T) {
@@ -4791,11 +4897,13 @@ func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 {
t.Fatalf("expected 1 media, got %d", len(result[0].Media))
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (images use path tags), got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
localPath, _, _ := store.ResolveWithMeta(ref)
expectedContent := "hi [image:" + localPath + "]"
if result[0].Content != expectedContent {
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
}
}
@@ -4885,6 +4993,98 @@ func TestResolveMediaRefs_NoGenericTagAppendsPath(t *testing.T) {
}
}
func TestInjectPathTags_HandlesVariousChannelPlaceholders(t *testing.T) {
cases := []struct {
name string
content string
tag string
want string
}{
// Telegram / Feishu format
{"image_photo", "[image: photo]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
// WeCom / WeChat / Line format
{"bare_image", "[image]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
// QQ / Discord format with filename
{"image_filename", "[image: pic.jpg]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
{"audio_with_filename", "[audio: voice.m4a]", "[audio:/tmp/a.m4a]", "[audio:/tmp/a.m4a]"},
{"bare_audio", "[audio]", "[audio:/tmp/a.m4a]", "[audio:/tmp/a.m4a]"},
{"bare_video", "[video]", "[video:/tmp/v.mp4]", "[video:/tmp/v.mp4]"},
{"bare_file", "[file]", "[file:/tmp/f.pdf]", "[file:/tmp/f.pdf]"},
// Mixed surrounding text
{
"with_text",
"hello [image] world",
"[image:/tmp/p.png]",
"hello [image:/tmp/p.png] world",
},
// No placeholder — append
{"no_placeholder", "hello world", "[image:/tmp/p.png]", "hello world [image:/tmp/p.png]"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := injectPathTags(tc.content, []string{tc.tag})
if got != tc.want {
t.Errorf("expected %q, got %q", tc.want, got)
}
})
}
}
func TestInjectPathTags_DoesNotReplacePathTag(t *testing.T) {
// If content already contains a path tag, we must not touch it.
content := "see [image:/already/placed.png] thanks"
got := injectPathTags(content, []string{"[image:/new/path.png]"})
want := "see [image:/already/placed.png] thanks [image:/new/path.png]"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}
func TestInjectPathTags_PrependsForJSONContent(t *testing.T) {
jsonContent := `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`
got := injectPathTags(jsonContent, []string{"[image:/tmp/photo.png]"})
want := "[image:/tmp/photo.png]\n" + jsonContent
if got != want {
t.Fatalf("expected tag prepended to JSON, got %q", got)
}
}
func TestInjectPathTags_BracketTextNotTreatedAsJSON(t *testing.T) {
content := "[update] see attached report"
got := injectPathTags(content, []string{"[file:/tmp/report.pdf]"})
want := "[update] see attached report [file:/tmp/report.pdf]"
if got != want {
t.Fatalf("expected tag appended to bracket text, got %q", got)
}
}
func TestResolveMediaRefs_JSONContentPrependsPathTag(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
pngPath := filepath.Join(dir, "card_img.png")
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
}
os.WriteFile(pngPath, pngHeader, 0o644)
ref, _ := store.Store(pngPath, media.MediaMeta{ContentType: "image/png"}, "test")
jsonContent := `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`
messages := []providers.Message{
{Role: "user", Content: jsonContent, Media: []string{ref}},
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
want := "[image:" + pngPath + "]\n" + jsonContent
if result[0].Content != want {
t.Fatalf("expected path tag prepended to JSON content, got %q", result[0].Content)
}
}
func TestResolveMediaRefs_EmptyContentGetsPathTag(t *testing.T) {
store := media.NewFileMediaStore()
dir := t.TempDir()
@@ -4928,13 +5128,12 @@ func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) {
}
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
if len(result[0].Media) != 1 {
t.Fatalf("expected 1 media (image only), got %d", len(result[0].Media))
if len(result[0].Media) != 0 {
t.Fatalf("expected 0 media (all types use path tags), got %d", len(result[0].Media))
}
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
t.Fatal("expected image to be base64 encoded")
}
expectedContent := "check these [file:" + pdfPath + "]"
imgLocalPath, _, _ := store.ResolveWithMeta(imgRef)
pdfLocalPath, _, _ := store.ResolveWithMeta(fileRef)
expectedContent := "check these [file:" + pdfLocalPath + "] [image:" + imgLocalPath + "]"
if result[0].Content != expectedContent {
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
}
@@ -5258,6 +5457,7 @@ func TestParallelMessageProcessing_SameSessionProcessedSequentially(t *testing.T
var mu sync.Mutex
turnIDs := make(map[string]bool)
var wg sync.WaitGroup
var firstResponse sync.Once
wg.Add(1) // Only 1 turn should be created for same session
cfg := &config.Config{
@@ -5280,19 +5480,27 @@ func TestParallelMessageProcessing_SameSessionProcessedSequentially(t *testing.T
al := NewAgentLoop(cfg, msgBus, &concurrentMockProvider{
responseFunc: func(callID int) string {
wg.Done()
firstResponse.Do(func() {
wg.Done()
})
return "ok"
},
})
defer al.Close()
sub := al.SubscribeEvents(64)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
64,
runtimeevents.KindAgentTurnStart,
)
defer closeRuntimeEvents()
go func() {
for evt := range sub.C {
if evt.Kind == EventKindTurnStart {
for evt := range runtimeCh {
if evt.Kind == runtimeevents.KindAgentTurnStart {
mu.Lock()
turnIDs[evt.Meta.TurnID] = true
turnIDs[evt.Scope.TurnID] = true
mu.Unlock()
}
}
+9 -11
View File
@@ -4,7 +4,6 @@ package agent
import (
"context"
"encoding/json"
"fmt"
"maps"
"path/filepath"
@@ -171,15 +170,8 @@ func toolFeedbackExplanationFromMessages(messages []providers.Message) string {
}
func toolFeedbackArgsPreview(args map[string]any, maxLen int) string {
if args == nil {
args = map[string]any{}
}
argsJSON, err := json.MarshalIndent(args, "", " ")
if err != nil {
return utils.Truncate(fmt.Sprintf("%v", args), maxLen)
}
return utils.Truncate(string(argsJSON), maxLen)
argsJSON := utils.FormatArgsJSON(args, true, false)
return utils.Truncate(argsJSON, maxLen)
}
func shouldPublishToolFeedback(cfg *config.Config, ts *turnState) bool {
@@ -293,6 +285,12 @@ func inferMediaType(filename, contentType string) string {
ct := strings.ToLower(contentType)
fn := strings.ToLower(filename)
// SVG is an image MIME type, but raster-only delivery endpoints such as
// Telegram SendPhoto reject it. Treat it as a file/document instead.
if strings.HasPrefix(ct, "image/svg") || filepath.Ext(fn) == ".svg" {
return "file"
}
if strings.HasPrefix(ct, "image/") {
return "image"
}
@@ -306,7 +304,7 @@ func inferMediaType(filename, contentType string) string {
// Fallback: infer from extension
ext := filepath.Ext(fn)
switch ext {
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg":
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp":
return "image"
case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus":
return "audio"
+76
View File
@@ -0,0 +1,76 @@
package agent
import "testing"
func TestInferMediaType(t *testing.T) {
tests := []struct {
name string
filename string
contentType string
want string
}{
{
name: "png content type",
filename: "diagram",
contentType: "image/png",
want: "image",
},
{
name: "jpeg extension fallback",
filename: "photo.JPG",
contentType: "",
want: "image",
},
{
name: "svg content type is file",
filename: "diagram",
contentType: "image/svg+xml",
want: "file",
},
{
name: "svg content type with parameters is file",
filename: "diagram",
contentType: "image/svg+xml; charset=utf-8",
want: "file",
},
{
name: "svg extension fallback is file",
filename: "diagram.SVG",
contentType: "",
want: "file",
},
{
name: "audio content type",
filename: "voice",
contentType: "audio/ogg",
want: "audio",
},
{
name: "ogg application content type",
filename: "voice.ogg",
contentType: "application/ogg",
want: "audio",
},
{
name: "video extension fallback",
filename: "clip.MP4",
contentType: "",
want: "video",
},
{
name: "unknown type",
filename: "archive.bin",
contentType: "application/octet-stream",
want: "file",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := inferMediaType(tt.filename, tt.contentType)
if got != tt.want {
t.Fatalf("inferMediaType(%q, %q) = %q, want %q", tt.filename, tt.contentType, got, tt.want)
}
})
}
}
+3 -2
View File
@@ -7,6 +7,7 @@ import (
"sync"
"time"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -41,7 +42,7 @@ func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) e
// Sync emergency compression — budget exceeded.
if result, ok := m.forceCompression(req.SessionKey); ok {
m.al.emitEvent(
EventKindContextCompress,
runtimeevents.KindAgentContextCompress,
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
ContextCompressPayload{
Reason: req.Reason,
@@ -246,7 +247,7 @@ func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey
agent.Sessions.TruncateHistory(sessionKey, keepCount)
agent.Sessions.Save(sessionKey)
m.al.emitEvent(
EventKindSessionSummarize,
runtimeevents.KindAgentSessionSummarize,
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
SessionSummarizePayload{
SummarizedMessages: len(validMessages),
+29 -14
View File
@@ -12,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -305,8 +306,13 @@ func TestLegacyCompact_Overflow(t *testing.T) {
}
defaultAgent.Sessions.SetHistory("session-overflow", history)
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentContextCompress,
)
defer closeRuntimeEvents()
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-overflow",
@@ -329,8 +335,8 @@ func TestLegacyCompact_Overflow(t *testing.T) {
}
// Event should carry the proactive reason
events := collectEventStream(sub.C)
compressEvt, ok := findEvent(events, EventKindContextCompress)
events := collectRuntimeEventStream(runtimeCh)
compressEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentContextCompress)
if !ok {
t.Fatal("expected context compress event")
}
@@ -361,8 +367,13 @@ func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
}
defaultAgent.Sessions.SetHistory("session-proactive", history)
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentContextCompress,
)
defer closeRuntimeEvents()
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-proactive",
@@ -372,8 +383,8 @@ func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
events := collectEventStream(sub.C)
compressEvt, ok := findEvent(events, EventKindContextCompress)
events := collectRuntimeEventStream(runtimeCh)
compressEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentContextCompress)
if !ok {
t.Fatal("expected context compress event")
}
@@ -483,6 +494,14 @@ func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
}
defaultAgent.Sessions.SetHistory("session-threshold", history)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentSessionSummarize,
)
defer closeRuntimeEvents()
err := al.contextManager.Compact(context.Background(), &CompactRequest{
SessionKey: "session-threshold",
Reason: ContextCompressReasonSummarize,
@@ -491,12 +510,8 @@ func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
// Wait for async summarization to complete via event
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
waitForEvent(t, sub.C, 5*time.Second, func(evt Event) bool {
return evt.Kind == EventKindSessionSummarize
waitForRuntimeEvent(t, runtimeCh, 5*time.Second, func(evt runtimeevents.Event) bool {
return evt.Kind == runtimeevents.KindAgentSessionSummarize
})
newHistory := defaultAgent.Sessions.GetHistory("session-threshold")
+171
View File
@@ -0,0 +1,171 @@
package agent
import "time"
// TurnEndStatus describes the terminal state of a turn.
type TurnEndStatus string
const (
// TurnEndStatusCompleted indicates the turn finished normally.
TurnEndStatusCompleted TurnEndStatus = "completed"
// TurnEndStatusError indicates the turn ended because of an error.
TurnEndStatusError TurnEndStatus = "error"
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
TurnEndStatusAborted TurnEndStatus = "aborted"
)
// TurnStartPayload describes the start of a turn.
type TurnStartPayload struct {
UserMessage string
MediaCount int
}
// TurnEndPayload describes the completion of a turn.
type TurnEndPayload struct {
Status TurnEndStatus
Iterations int
Duration time.Duration
FinalContentLen int
}
// LLMRequestPayload describes an outbound LLM request.
type LLMRequestPayload struct {
Model string
MessagesCount int
ToolsCount int
MaxTokens int
Temperature float64
}
// LLMResponsePayload describes an inbound LLM response.
type LLMResponsePayload struct {
ContentLen int
ToolCalls int
HasReasoning bool
}
// LLMDeltaPayload describes a streamed LLM delta.
type LLMDeltaPayload struct {
ContentDeltaLen int
ReasoningDeltaLen int
}
// LLMRetryPayload describes a retry of an LLM request.
type LLMRetryPayload struct {
Attempt int
MaxRetries int
Reason string
Error string
Backoff time.Duration
}
// ContextCompressReason identifies why emergency compression ran.
type ContextCompressReason string
const (
// ContextCompressReasonProactive indicates compression before the first LLM call.
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
// ContextCompressReasonRetry indicates compression during context-error retry handling.
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
// ContextCompressReasonSummarize indicates post-turn async summarization.
ContextCompressReasonSummarize ContextCompressReason = "summarize"
)
// ContextCompressPayload describes a forced history compression.
type ContextCompressPayload struct {
Reason ContextCompressReason
DroppedMessages int
RemainingMessages int
}
// SessionSummarizePayload describes a completed async session summarization.
type SessionSummarizePayload struct {
SummarizedMessages int
KeptMessages int
SummaryLen int
OmittedOversized bool
}
// ToolExecStartPayload describes a tool execution request.
type ToolExecStartPayload struct {
Tool string
Arguments map[string]any
}
// ToolExecEndPayload describes the outcome of a tool execution.
type ToolExecEndPayload struct {
Tool string
Duration time.Duration
ForLLMLen int
ForUserLen int
IsError bool
Async bool
}
// ToolExecSkippedPayload describes a skipped tool call.
type ToolExecSkippedPayload struct {
Tool string
Reason string
}
// SteeringInjectedPayload describes steering messages appended before the next LLM call.
type SteeringInjectedPayload struct {
Count int
TotalContentLen int
}
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
type FollowUpQueuedPayload struct {
SourceTool string
ContentLen int
}
type InterruptKind string
const (
InterruptKindSteering InterruptKind = "steering"
InterruptKindGraceful InterruptKind = "graceful"
InterruptKindHard InterruptKind = "hard_abort"
)
// InterruptReceivedPayload describes accepted turn-control input.
type InterruptReceivedPayload struct {
Kind InterruptKind
Role string
ContentLen int
QueueDepth int
HintLen int
}
// SubTurnSpawnPayload describes the creation of a child turn.
type SubTurnSpawnPayload struct {
AgentID string
Label string
ParentTurnID string
}
// SubTurnEndPayload describes the completion of a child turn.
type SubTurnEndPayload struct {
AgentID string
Status string
}
// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
type SubTurnResultDeliveredPayload struct {
TargetChannel string
TargetChatID string
ContentLen int
}
// SubTurnOrphanPayload describes a sub-turn result that could not be delivered.
type SubTurnOrphanPayload struct {
ParentTurnID string
ChildTurnID string
Reason string
}
// ErrorPayload describes an execution error inside the agent loop.
type ErrorPayload struct {
Stage string
Message string
}
-121
View File
@@ -1,121 +0,0 @@
package agent
import (
"sync"
"sync/atomic"
"time"
)
const defaultEventSubscriberBuffer = 16
// EventSubscription identifies a subscriber channel returned by EventBus.Subscribe.
type EventSubscription struct {
ID uint64
C <-chan Event
}
type eventSubscriber struct {
ch chan Event
}
// EventBus is a lightweight multi-subscriber broadcaster for agent-loop events.
type EventBus struct {
mu sync.RWMutex
subs map[uint64]eventSubscriber
nextID uint64
closed bool
dropped [eventKindCount]atomic.Int64
}
// NewEventBus creates a new in-process event broadcaster.
func NewEventBus() *EventBus {
return &EventBus{
subs: make(map[uint64]eventSubscriber),
}
}
// Subscribe registers a new subscriber with the requested channel buffer size.
// A non-positive buffer uses the default size.
func (b *EventBus) Subscribe(buffer int) EventSubscription {
if buffer <= 0 {
buffer = defaultEventSubscriberBuffer
}
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
ch := make(chan Event)
close(ch)
return EventSubscription{C: ch}
}
b.nextID++
id := b.nextID
ch := make(chan Event, buffer)
b.subs[id] = eventSubscriber{ch: ch}
return EventSubscription{ID: id, C: ch}
}
// Unsubscribe removes a subscriber and closes its channel.
func (b *EventBus) Unsubscribe(id uint64) {
b.mu.Lock()
defer b.mu.Unlock()
sub, ok := b.subs[id]
if !ok {
return
}
delete(b.subs, id)
close(sub.ch)
}
// Emit broadcasts an event to all current subscribers without blocking.
// When a subscriber channel is full, the event is dropped for that subscriber.
func (b *EventBus) Emit(evt Event) {
if evt.Time.IsZero() {
evt.Time = time.Now()
}
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return
}
for _, sub := range b.subs {
select {
case sub.ch <- evt:
default:
if evt.Kind < eventKindCount {
b.dropped[evt.Kind].Add(1)
}
}
}
}
// Dropped returns the number of dropped events for a given kind.
func (b *EventBus) Dropped(kind EventKind) int64 {
if kind >= eventKindCount {
return 0
}
return b.dropped[kind].Load()
}
// Close closes all subscriber channels and stops future broadcasts.
func (b *EventBus) Close() {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return
}
b.closed = true
for id, sub := range b.subs {
close(sub.ch)
delete(b.subs, id)
}
}
+155 -132
View File
@@ -9,61 +9,94 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
eventBus := NewEventBus()
sub := eventBus.Subscribe(1)
eventBus.Emit(Event{
Kind: EventKindTurnStart,
Meta: EventMeta{TurnID: "turn-1"},
})
select {
case evt := <-sub.C:
if evt.Kind != EventKindTurnStart {
t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
func TestAgentLoop_PublishesRuntimeEvents(t *testing.T) {
runtimeBus := runtimeevents.NewBus()
al := &AgentLoop{
runtimeEvents: runtimeBus,
}
defer func() {
if err := runtimeBus.Close(); err != nil {
t.Errorf("runtime bus close failed: %v", err)
}
if evt.Meta.TurnID != "turn-1" {
t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
}()
runtimeSub, runtimeCh, err := al.RuntimeEvents().OfKind(runtimeevents.KindAgentToolExecStart).SubscribeChan(
context.Background(),
runtimeevents.SubscribeOptions{Name: "runtime", Buffer: 1},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
defer func() {
if err := runtimeSub.Close(); err != nil {
t.Errorf("runtime subscription close failed: %v", err)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}()
al.emitEvent(
runtimeevents.KindAgentToolExecStart,
HookMeta{
AgentID: "main",
TurnID: "turn-1",
ParentTurnID: "parent-turn",
SessionKey: "session-1",
Iteration: 2,
TracePath: "trace/root",
Source: "pipeline_execute",
turnContext: &TurnContext{
Inbound: &bus.InboundContext{
Channel: "cli",
Account: "default",
ChatID: "direct",
ChatType: "direct",
SenderID: "tester",
MessageID: "msg-1",
TopicID: "topic-1",
},
},
},
ToolExecStartPayload{Tool: "mock_custom", Arguments: map[string]any{"task": "ping"}},
)
runtimeEvt := receiveRuntimeEvent(t, runtimeCh)
if runtimeEvt.Kind != runtimeevents.KindAgentToolExecStart {
t.Fatalf("runtime kind = %q, want %q", runtimeEvt.Kind, runtimeevents.KindAgentToolExecStart)
}
eventBus.Unsubscribe(sub.ID)
if _, ok := <-sub.C; ok {
t.Fatal("expected subscriber channel to be closed after unsubscribe")
if runtimeEvt.Source != (runtimeevents.Source{Component: "agent", Name: "main"}) {
t.Fatalf("runtime source = %+v", runtimeEvt.Source)
}
eventBus.Close()
closedSub := eventBus.Subscribe(1)
if _, ok := <-closedSub.C; ok {
t.Fatal("expected closed bus to return a closed subscriber channel")
if runtimeEvt.Scope.AgentID != "main" ||
runtimeEvt.Scope.SessionKey != "session-1" ||
runtimeEvt.Scope.TurnID != "turn-1" ||
runtimeEvt.Scope.Channel != "cli" ||
runtimeEvt.Scope.Account != "default" ||
runtimeEvt.Scope.ChatID != "direct" ||
runtimeEvt.Scope.TopicID != "topic-1" ||
runtimeEvt.Scope.ChatType != "direct" ||
runtimeEvt.Scope.SenderID != "tester" ||
runtimeEvt.Scope.MessageID != "msg-1" {
t.Fatalf("runtime scope = %+v", runtimeEvt.Scope)
}
}
func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
eventBus := NewEventBus()
sub := eventBus.Subscribe(1)
defer eventBus.Unsubscribe(sub.ID)
start := time.Now()
for i := 0; i < 1000; i++ {
eventBus.Emit(Event{Kind: EventKindLLMRequest})
if runtimeEvt.Correlation.TraceID != "trace/root" ||
runtimeEvt.Correlation.ParentTurnID != "parent-turn" {
t.Fatalf("runtime correlation = %+v", runtimeEvt.Correlation)
}
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
if runtimeEvt.Attrs["agent_source"] != "pipeline_execute" || runtimeEvt.Attrs["iteration"] != 2 {
t.Fatalf("runtime attrs = %+v", runtimeEvt.Attrs)
}
if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
t.Fatalf("expected 999 dropped events, got %d", got)
payload, ok := runtimeEvt.Payload.(ToolExecStartPayload)
if !ok {
t.Fatalf("runtime payload = %T, want ToolExecStartPayload", runtimeEvt.Payload)
}
if payload.Tool != "mock_custom" {
t.Fatalf("runtime payload tool = %q, want mock_custom", payload.Tool)
}
}
@@ -127,8 +160,18 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
t.Fatal("expected default agent")
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
expectedKinds := []runtimeevents.Kind{
runtimeevents.KindAgentTurnStart,
runtimeevents.KindAgentLLMRequest,
runtimeevents.KindAgentLLMResponse,
runtimeevents.KindAgentToolExecStart,
runtimeevents.KindAgentToolExecEnd,
runtimeevents.KindAgentLLMRequest,
runtimeevents.KindAgentLLMResponse,
runtimeevents.KindAgentTurnEnd,
}
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(t, al, 16, expectedKinds...)
defer closeRuntimeEvents()
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
@@ -171,49 +214,36 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
t.Fatalf("expected final response 'done', got %q", response)
}
events := collectEventStream(sub.C)
events := collectRuntimeEventStream(runtimeCh)
if len(events) != 8 {
t.Fatalf("expected 8 events, got %d", len(events))
}
kinds := make([]EventKind, 0, len(events))
kinds := make([]runtimeevents.Kind, 0, len(events))
for _, evt := range events {
kinds = append(kinds, evt.Kind)
}
expectedKinds := []EventKind{
EventKindTurnStart,
EventKindLLMRequest,
EventKindLLMResponse,
EventKindToolExecStart,
EventKindToolExecEnd,
EventKindLLMRequest,
EventKindLLMResponse,
EventKindTurnEnd,
}
if !slices.Equal(kinds, expectedKinds) {
t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
}
turnID := events[0].Meta.TurnID
turnID := events[0].Scope.TurnID
if turnID == "" {
t.Fatal("expected runtime events to include turn id")
}
for i, evt := range events {
if evt.Meta.TurnID != turnID {
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
if evt.Scope.TurnID != turnID {
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Scope.TurnID, turnID)
}
if evt.Meta.SessionKey != "session-1" {
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
if evt.Scope.SessionKey != "session-1" {
t.Fatalf("event %d has session key %q, want session-1", i, evt.Scope.SessionKey)
}
if evt.Context == nil || evt.Context.Inbound == nil {
t.Fatalf("event %d missing inbound turn context", i)
if evt.Scope.Channel != "cli" || evt.Scope.ChatID != "direct" || evt.Scope.SenderID != "tester" {
t.Fatalf("event %d scope = %+v", i, evt.Scope)
}
if evt.Context.Inbound.Channel != "cli" || evt.Context.Inbound.SenderID != "tester" {
t.Fatalf("event %d inbound context = %+v", i, evt.Context.Inbound)
}
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
t.Fatalf("event %d missing route context: %+v", i, evt.Context.Route)
}
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "tester" {
t.Fatalf("event %d missing session scope: %+v", i, evt.Context.Scope)
if evt.Scope.AgentID != "main" {
t.Fatalf("event %d has agent id %q, want main", i, evt.Scope.AgentID)
}
}
@@ -309,8 +339,15 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
32,
runtimeevents.KindAgentSteeringInjected,
runtimeevents.KindAgentToolExecSkipped,
runtimeevents.KindAgentInterruptReceived,
)
defer closeRuntimeEvents()
resultCh := make(chan string, 1)
go func() {
@@ -337,8 +374,8 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
t.Fatal("timeout waiting for steered response")
}
events := collectEventStream(sub.C)
steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
events := collectRuntimeEventStream(runtimeCh)
steeringEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentSteeringInjected)
if !ok {
t.Fatal("expected steering injected event")
}
@@ -350,7 +387,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
}
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
skippedEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecSkipped)
if !ok {
t.Fatal("expected skipped tool event")
}
@@ -362,7 +399,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
}
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
interruptEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
@@ -420,8 +457,14 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
{Role: "user", Content: "Trigger message"},
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentLLMRetry,
runtimeevents.KindAgentContextCompress,
)
defer closeRuntimeEvents()
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
@@ -439,8 +482,8 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
t.Fatalf("expected retry success, got %q", resp)
}
events := collectEventStream(sub.C)
retryEvt, ok := findEvent(events, EventKindLLMRetry)
events := collectRuntimeEventStream(runtimeCh)
retryEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentLLMRetry)
if !ok {
t.Fatal("expected llm retry event")
}
@@ -455,7 +498,7 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
}
compressEvt, ok := findEvent(events, EventKindContextCompress)
compressEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentContextCompress)
if !ok {
t.Fatal("expected context compress event")
}
@@ -508,14 +551,19 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
{Role: "assistant", Content: "Answer three"},
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentSessionSummarize,
)
defer closeRuntimeEvents()
lcm := &legacyContextManager{al: al}
lcm.summarizeSession(defaultAgent, "session-1")
events := collectEventStream(sub.C)
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
events := collectRuntimeEventStream(runtimeCh)
summaryEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentSessionSummarize)
if !ok {
t.Fatal("expected session summarize event")
}
@@ -575,8 +623,13 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
t.Fatal("expected default agent")
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
32,
runtimeevents.KindAgentFollowUpQueued,
)
defer closeRuntimeEvents()
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
@@ -600,8 +653,8 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
t.Fatal("timeout waiting for async tool completion")
}
followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
return evt.Kind == EventKindFollowUpQueued
followUpEvt := waitForRuntimeEvent(t, runtimeCh, 2*time.Second, func(evt runtimeevents.Event) bool {
return evt.Kind == runtimeevents.KindAgentFollowUpQueued
})
payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
if !ok {
@@ -613,59 +666,29 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
if payload.ContentLen != len("background result") {
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
}
if followUpEvt.Meta.SessionKey != "session-1" {
t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
if followUpEvt.Scope.SessionKey != "session-1" {
t.Fatalf("expected session key session-1, got %q", followUpEvt.Scope.SessionKey)
}
if followUpEvt.Meta.TurnID == "" {
if followUpEvt.Scope.TurnID == "" {
t.Fatal("expected follow-up event to include turn id")
}
}
func collectEventStream(ch <-chan Event) []Event {
var events []Event
for {
select {
case evt, ok := <-ch:
if !ok {
return events
}
events = append(events, evt)
default:
return events
}
}
}
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
func receiveRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
t.Helper()
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case evt, ok := <-ch:
if !ok {
t.Fatal("event stream closed before expected event arrived")
}
if match(evt) {
return evt
}
case <-timer.C:
t.Fatal("timed out waiting for expected event")
select {
case evt, ok := <-ch:
if !ok {
t.Fatal("runtime event stream closed before expected event arrived")
}
return evt
case <-time.After(time.Second):
t.Fatal("timed out waiting for runtime event")
return runtimeevents.Event{}
}
}
func findEvent(events []Event, kind EventKind) (Event, bool) {
for _, evt := range events {
if evt.Kind == kind {
return evt, true
}
}
return Event{}, false
}
type stringError string
func (e stringError) Error() string {
+3 -260
View File
@@ -1,97 +1,8 @@
package agent
import (
"fmt"
"time"
)
// EventKind identifies a structured agent-loop event.
type EventKind uint8
const (
// EventKindTurnStart is emitted when a turn begins processing.
EventKindTurnStart EventKind = iota
// EventKindTurnEnd is emitted when a turn finishes, successfully or with an error.
EventKindTurnEnd
// EventKindLLMRequest is emitted before a provider chat request is made.
EventKindLLMRequest
// EventKindLLMDelta is emitted when a streaming provider yields a partial delta.
EventKindLLMDelta
// EventKindLLMResponse is emitted after a provider chat response is received.
EventKindLLMResponse
// EventKindLLMRetry is emitted when an LLM request is retried.
EventKindLLMRetry
// EventKindContextCompress is emitted when session history is forcibly compressed.
EventKindContextCompress
// EventKindSessionSummarize is emitted when asynchronous summarization completes.
EventKindSessionSummarize
// EventKindToolExecStart is emitted immediately before a tool executes.
EventKindToolExecStart
// EventKindToolExecEnd is emitted immediately after a tool finishes executing.
EventKindToolExecEnd
// EventKindToolExecSkipped is emitted when a queued tool call is skipped.
EventKindToolExecSkipped
// EventKindSteeringInjected is emitted when queued steering is injected into context.
EventKindSteeringInjected
// EventKindFollowUpQueued is emitted when an async tool queues a follow-up system message.
EventKindFollowUpQueued
// EventKindInterruptReceived is emitted when a soft interrupt message is accepted.
EventKindInterruptReceived
// EventKindSubTurnSpawn is emitted when a sub-turn is spawned.
EventKindSubTurnSpawn
// EventKindSubTurnEnd is emitted when a sub-turn finishes.
EventKindSubTurnEnd
// EventKindSubTurnResultDelivered is emitted when a sub-turn result is delivered.
EventKindSubTurnResultDelivered
// EventKindSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
EventKindSubTurnOrphan
// EventKindError is emitted when a turn encounters an execution error.
EventKindError
eventKindCount
)
var eventKindNames = [...]string{
"turn_start",
"turn_end",
"llm_request",
"llm_delta",
"llm_response",
"llm_retry",
"context_compress",
"session_summarize",
"tool_exec_start",
"tool_exec_end",
"tool_exec_skipped",
"steering_injected",
"follow_up_queued",
"interrupt_received",
"subturn_spawn",
"subturn_end",
"subturn_result_delivered",
"subturn_orphan",
"error",
}
// String returns the stable string form of an EventKind.
func (k EventKind) String() string {
if k >= eventKindCount {
return fmt.Sprintf("event_kind(%d)", k)
}
return eventKindNames[k]
}
// Event is the structured envelope broadcast by the agent EventBus.
type Event struct {
Kind EventKind
Time time.Time
Meta EventMeta
Context *TurnContext
Payload any
}
// EventMeta contains correlation fields shared by all agent-loop events.
type EventMeta struct {
// HookMeta contains correlation fields shared by agent hook requests and
// runtime events emitted from turn processing.
type HookMeta struct {
AgentID string
TurnID string
ParentTurnID string
@@ -101,171 +12,3 @@ type EventMeta struct {
Source string
turnContext *TurnContext
}
// TurnEndStatus describes the terminal state of a turn.
type TurnEndStatus string
const (
// TurnEndStatusCompleted indicates the turn finished normally.
TurnEndStatusCompleted TurnEndStatus = "completed"
// TurnEndStatusError indicates the turn ended because of an error.
TurnEndStatusError TurnEndStatus = "error"
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
TurnEndStatusAborted TurnEndStatus = "aborted"
)
// TurnStartPayload describes the start of a turn.
type TurnStartPayload struct {
UserMessage string
MediaCount int
}
// TurnEndPayload describes the completion of a turn.
type TurnEndPayload struct {
Status TurnEndStatus
Iterations int
Duration time.Duration
FinalContentLen int
}
// LLMRequestPayload describes an outbound LLM request.
type LLMRequestPayload struct {
Model string
MessagesCount int
ToolsCount int
MaxTokens int
Temperature float64
}
// LLMResponsePayload describes an inbound LLM response.
type LLMResponsePayload struct {
ContentLen int
ToolCalls int
HasReasoning bool
}
// LLMDeltaPayload describes a streamed LLM delta.
type LLMDeltaPayload struct {
ContentDeltaLen int
ReasoningDeltaLen int
}
// LLMRetryPayload describes a retry of an LLM request.
type LLMRetryPayload struct {
Attempt int
MaxRetries int
Reason string
Error string
Backoff time.Duration
}
// ContextCompressReason identifies why emergency compression ran.
type ContextCompressReason string
const (
// ContextCompressReasonProactive indicates compression before the first LLM call.
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
// ContextCompressReasonRetry indicates compression during context-error retry handling.
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
// ContextCompressReasonSummarize indicates post-turn async summarization.
ContextCompressReasonSummarize ContextCompressReason = "summarize"
)
// ContextCompressPayload describes a forced history compression.
type ContextCompressPayload struct {
Reason ContextCompressReason
DroppedMessages int
RemainingMessages int
}
// SessionSummarizePayload describes a completed async session summarization.
type SessionSummarizePayload struct {
SummarizedMessages int
KeptMessages int
SummaryLen int
OmittedOversized bool
}
// ToolExecStartPayload describes a tool execution request.
type ToolExecStartPayload struct {
Tool string
Arguments map[string]any
}
// ToolExecEndPayload describes the outcome of a tool execution.
type ToolExecEndPayload struct {
Tool string
Duration time.Duration
ForLLMLen int
ForUserLen int
IsError bool
Async bool
}
// ToolExecSkippedPayload describes a skipped tool call.
type ToolExecSkippedPayload struct {
Tool string
Reason string
}
// SteeringInjectedPayload describes steering messages appended before the next LLM call.
type SteeringInjectedPayload struct {
Count int
TotalContentLen int
}
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
type FollowUpQueuedPayload struct {
SourceTool string
ContentLen int
}
type InterruptKind string
const (
InterruptKindSteering InterruptKind = "steering"
InterruptKindGraceful InterruptKind = "graceful"
InterruptKindHard InterruptKind = "hard_abort"
)
// InterruptReceivedPayload describes accepted turn-control input.
type InterruptReceivedPayload struct {
Kind InterruptKind
Role string
ContentLen int
QueueDepth int
HintLen int
}
// SubTurnSpawnPayload describes the creation of a child turn.
type SubTurnSpawnPayload struct {
AgentID string
Label string
ParentTurnID string
}
// SubTurnEndPayload describes the completion of a child turn.
type SubTurnEndPayload struct {
AgentID string
Status string
}
// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
type SubTurnResultDeliveredPayload struct {
TargetChannel string
TargetChatID string
ContentLen int
}
// SubTurnOrphanPayload describes a sub-turn result that could not be delivered.
type SubTurnOrphanPayload struct {
ParentTurnID string
ChildTurnID string
Reason string
}
// ErrorPayload describes an execution error inside the agent loop.
type ErrorPayload struct {
Stage string
Message string
}
+88
View File
@@ -0,0 +1,88 @@
package agent
import runtimeevents "github.com/sipeed/picoclaw/pkg/events"
func (al *AgentLoop) publishRuntimeEvent(evt runtimeevents.Event) {
if al == nil || al.runtimeEvents == nil {
return
}
al.runtimeEvents.PublishNonBlocking(evt)
}
func runtimeScopeFromHookMeta(meta HookMeta, eventCtx *TurnContext) runtimeevents.Scope {
scope := runtimeevents.Scope{
AgentID: meta.AgentID,
SessionKey: meta.SessionKey,
TurnID: meta.TurnID,
}
if eventCtx == nil || eventCtx.Inbound == nil {
return scope
}
inbound := eventCtx.Inbound
scope.Channel = inbound.Channel
scope.Account = inbound.Account
scope.ChatID = inbound.ChatID
scope.TopicID = inbound.TopicID
scope.SpaceID = inbound.SpaceID
scope.SpaceType = inbound.SpaceType
scope.ChatType = inbound.ChatType
scope.SenderID = inbound.SenderID
scope.MessageID = inbound.MessageID
return scope
}
func runtimeCorrelationFromHookMeta(meta HookMeta) runtimeevents.Correlation {
return runtimeevents.Correlation{
TraceID: meta.TracePath,
ParentTurnID: meta.ParentTurnID,
}
}
func runtimeSeverityForAgentEvent(kind runtimeevents.Kind, payload any) runtimeevents.Severity {
switch kind {
case runtimeevents.KindAgentError, runtimeevents.KindAgentSubTurnOrphan:
return runtimeevents.SeverityError
case runtimeevents.KindAgentLLMRetry,
runtimeevents.KindAgentContextCompress,
runtimeevents.KindAgentToolExecSkipped:
return runtimeevents.SeverityWarn
case runtimeevents.KindAgentTurnEnd:
payload, ok := payload.(TurnEndPayload)
if !ok {
return runtimeevents.SeverityInfo
}
switch payload.Status {
case TurnEndStatusError:
return runtimeevents.SeverityError
case TurnEndStatusAborted:
return runtimeevents.SeverityWarn
default:
return runtimeevents.SeverityInfo
}
case runtimeevents.KindAgentToolExecEnd:
payload, ok := payload.(ToolExecEndPayload)
if ok && payload.IsError {
return runtimeevents.SeverityWarn
}
return runtimeevents.SeverityInfo
default:
return runtimeevents.SeverityInfo
}
}
func runtimeAttrsFromHookMeta(meta HookMeta) map[string]any {
attrs := make(map[string]any, 2)
if meta.Source != "" {
attrs["agent_source"] = meta.Source
}
if meta.Iteration != 0 {
attrs["iteration"] = meta.Iteration
}
if len(attrs) == 0 {
return nil
}
return attrs
}
+28 -6
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
type hookRuntime struct {
@@ -295,10 +296,11 @@ func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error)
case "", "*", "all":
return nil, true, nil
default:
if _, ok := validKinds[kind]; !ok {
normalizedKind, ok := validKinds[kind]
if !ok {
return nil, false, fmt.Errorf("unsupported observe event %q", kind)
}
normalized = append(normalized, kind)
normalized = append(normalized, normalizedKind)
}
}
@@ -308,10 +310,30 @@ func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error)
return normalized, true, nil
}
func validHookEventKinds() map[string]struct{} {
kinds := make(map[string]struct{}, int(eventKindCount))
for kind := EventKind(0); kind < eventKindCount; kind++ {
kinds[kind.String()] = struct{}{}
func validHookEventKinds() map[string]string {
runtimeKinds := runtimeevents.KnownKinds()
kinds := make(map[string]string, len(runtimeKinds)*2)
for _, kind := range runtimeKinds {
kinds[kind.String()] = kind.String()
}
kinds["turn_start"] = runtimeevents.KindAgentTurnStart.String()
kinds["turn_end"] = runtimeevents.KindAgentTurnEnd.String()
kinds["llm_request"] = runtimeevents.KindAgentLLMRequest.String()
kinds["llm_delta"] = runtimeevents.KindAgentLLMDelta.String()
kinds["llm_response"] = runtimeevents.KindAgentLLMResponse.String()
kinds["llm_retry"] = runtimeevents.KindAgentLLMRetry.String()
kinds["context_compress"] = runtimeevents.KindAgentContextCompress.String()
kinds["session_summarize"] = runtimeevents.KindAgentSessionSummarize.String()
kinds["tool_exec_start"] = runtimeevents.KindAgentToolExecStart.String()
kinds["tool_exec_end"] = runtimeevents.KindAgentToolExecEnd.String()
kinds["tool_exec_skipped"] = runtimeevents.KindAgentToolExecSkipped.String()
kinds["steering_injected"] = runtimeevents.KindAgentSteeringInjected.String()
kinds["follow_up_queued"] = runtimeevents.KindAgentFollowUpQueued.String()
kinds["interrupt_received"] = runtimeevents.KindAgentInterruptReceived.String()
kinds["subturn_spawn"] = runtimeevents.KindAgentSubTurnSpawn.String()
kinds["subturn_end"] = runtimeevents.KindAgentSubTurnEnd.String()
kinds["subturn_result_delivered"] = runtimeevents.KindAgentSubTurnResultDelivered.String()
kinds["subturn_orphan"] = runtimeevents.KindAgentSubTurnOrphan.String()
kinds["error"] = runtimeevents.KindAgentError.String()
return kinds
}
+22 -1
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"path/filepath"
"slices"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -155,7 +156,27 @@ func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T)
t.Fatalf("expected process model, got %q", lastModel)
}
waitForFileContains(t, eventLog, "turn_end")
waitForFileContains(t, eventLog, "agent.turn.end")
}
func TestProcessHookObserveKindsFromConfigAcceptsRuntimeNames(t *testing.T) {
kinds, enabled, err := processHookObserveKindsFromConfig([]string{
"tool_exec_start",
"agent.tool.exec_end",
"gateway.ready",
"mcp.server.failed",
})
if err != nil {
t.Fatalf("processHookObserveKindsFromConfig failed: %v", err)
}
if !enabled {
t.Fatal("expected observe to be enabled")
}
want := []string{"agent.tool.exec_start", "agent.tool.exec_end", "gateway.ready", "mcp.server.failed"}
if !slices.Equal(kinds, want) {
t.Fatalf("observe kinds = %v, want %v", kinds, want)
}
}
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
+3 -2
View File
@@ -12,6 +12,7 @@ import (
"sync/atomic"
"time"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -183,7 +184,7 @@ func (ph *ProcessHook) Close() error {
return ph.closeErr
}
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
func (ph *ProcessHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
if ph == nil || !ph.opts.Observe {
return nil
}
@@ -192,7 +193,7 @@ func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
return nil
}
}
return ph.notify(ctx, "hook.event", evt)
return ph.notify(ctx, "hook.runtime_event", evt)
}
func (ph *ProcessHook) BeforeLLM(
+14 -9
View File
@@ -13,6 +13,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -66,7 +67,7 @@ func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
t.Fatalf("expected process model, got %q", lastModel)
}
waitForFileContains(t, eventLog, "turn_end")
waitForFileContains(t, eventLog, "agent.turn.end")
}
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
@@ -146,8 +147,13 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
t.Fatalf("MountProcessHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentToolExecSkipped,
)
defer closeRuntimeEvents()
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
@@ -167,8 +173,8 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
t.Fatalf("expected %q, got %q", expected, resp)
}
events := collectEventStream(sub.C)
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
events := collectRuntimeEventStream(runtimeCh)
skippedEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecSkipped)
if !ok {
t.Fatal("expected tool skipped event")
}
@@ -350,12 +356,11 @@ func runProcessHookHelper() error {
}
if msg.ID == 0 {
if msg.Method == "hook.event" && eventLog != "" {
if msg.Method == "hook.runtime_event" && eventLog != "" {
var evt map[string]any
if err := json.Unmarshal(msg.Params, &evt); err == nil {
if rawKind, ok := evt["Kind"].(float64); ok {
kind := EventKind(rawKind)
_ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
if kind, ok := evt["kind"].(string); ok {
_ = os.WriteFile(eventLog, []byte(kind+"\n"), 0o644)
}
}
}
+56 -36
View File
@@ -9,6 +9,7 @@ import (
"sync"
"time"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -71,8 +72,8 @@ func NamedHook(name string, hook any) HookRegistration {
}
}
type EventObserver interface {
OnEvent(ctx context.Context, evt Event) error
type RuntimeEventObserver interface {
OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error
}
type LLMInterceptor interface {
@@ -90,7 +91,7 @@ type ToolApprover interface {
}
type LLMHookRequest struct {
Meta EventMeta `json:"meta"`
Meta HookMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Model string `json:"model"`
Messages []providers.Message `json:"messages,omitempty"`
@@ -104,7 +105,7 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Meta = cloneHookMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Messages = cloneProviderMessages(r.Messages)
cloned.Tools = cloneToolDefinitions(r.Tools)
@@ -113,7 +114,7 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
}
type LLMHookResponse struct {
Meta EventMeta `json:"meta"`
Meta HookMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Model string `json:"model"`
Response *providers.LLMResponse `json:"response,omitempty"`
@@ -124,14 +125,14 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Meta = cloneHookMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Response = cloneLLMResponse(r.Response)
return &cloned
}
type ToolCallHookRequest struct {
Meta EventMeta `json:"meta"`
Meta HookMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
@@ -145,7 +146,7 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Meta = cloneHookMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
cloned.HookResult = cloneToolResult(r.HookResult)
@@ -153,7 +154,7 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
}
type ToolApprovalRequest struct {
Meta EventMeta `json:"meta"`
Meta HookMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
@@ -164,14 +165,14 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Meta = cloneHookMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
return &cloned
}
type ToolResultHookResponse struct {
Meta EventMeta `json:"meta"`
Meta HookMeta `json:"meta"`
Context *TurnContext `json:"context,omitempty"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
@@ -184,7 +185,7 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
return nil
}
cloned := *r
cloned.Meta = cloneEventMeta(r.Meta)
cloned.Meta = cloneHookMeta(r.Meta)
cloned.Context = cloneTurnContext(r.Context)
cloned.Arguments = cloneStringAnyMap(r.Arguments)
cloned.Result = cloneToolResult(r.Result)
@@ -192,7 +193,7 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
}
type HookManager struct {
eventBus *EventBus
runtimeEvents runtimeevents.EventChannel
observerTimeout time.Duration
interceptorTimeout time.Duration
approvalTimeout time.Duration
@@ -201,28 +202,39 @@ type HookManager struct {
hooks map[string]HookRegistration
ordered []HookRegistration
sub EventSubscription
done chan struct{}
closeOnce sync.Once
runtimeSub runtimeevents.Subscription
runtimeDone chan struct{}
closeOnce sync.Once
}
func NewHookManager(eventBus *EventBus) *HookManager {
func NewHookManager(runtimeEvents runtimeevents.EventChannel) *HookManager {
hm := &HookManager{
eventBus: eventBus,
runtimeEvents: runtimeEvents,
observerTimeout: defaultHookObserverTimeout,
interceptorTimeout: defaultHookInterceptorTimeout,
approvalTimeout: defaultHookApprovalTimeout,
hooks: make(map[string]HookRegistration),
done: make(chan struct{}),
runtimeDone: make(chan struct{}),
}
if eventBus == nil {
close(hm.done)
return hm
if runtimeEvents != nil {
sub, ch, err := runtimeEvents.SubscribeChan(context.Background(), runtimeevents.SubscribeOptions{
Name: "hook-manager-observer",
Buffer: hookObserverBufferSize,
})
if err != nil {
logger.WarnCF("hooks", "Failed to subscribe runtime events for hooks", map[string]any{
"error": err.Error(),
})
close(hm.runtimeDone)
} else {
hm.runtimeSub = sub
go hm.dispatchRuntimeEvents(ch)
}
} else {
close(hm.runtimeDone)
}
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
go hm.dispatchEvents()
return hm
}
@@ -232,10 +244,14 @@ func (hm *HookManager) Close() {
}
hm.closeOnce.Do(func() {
if hm.eventBus != nil {
hm.eventBus.Unsubscribe(hm.sub.ID)
if hm.runtimeSub != nil {
if err := hm.runtimeSub.Close(); err != nil {
logger.WarnCF("hooks", "Failed to close runtime event hook subscription", map[string]any{
"error": err.Error(),
})
}
}
<-hm.done
<-hm.runtimeDone
hm.closeAllHooks()
})
}
@@ -292,16 +308,16 @@ func (hm *HookManager) Unmount(name string) {
hm.rebuildOrdered()
}
func (hm *HookManager) dispatchEvents() {
defer close(hm.done)
func (hm *HookManager) dispatchRuntimeEvents(ch <-chan runtimeevents.Event) {
defer close(hm.runtimeDone)
for evt := range hm.sub.C {
for evt := range ch {
for _, reg := range hm.snapshotHooks() {
observer, ok := reg.Hook.(EventObserver)
observer, ok := reg.Hook.(RuntimeEventObserver)
if !ok {
continue
}
hm.runObserver(reg.Name, observer, evt)
hm.runRuntimeObserver(reg.Name, observer, evt)
}
}
}
@@ -581,26 +597,30 @@ func (hm *HookManager) closeAllHooks() {
hm.ordered = nil
}
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
func (hm *HookManager) runRuntimeObserver(
name string,
observer RuntimeEventObserver,
evt runtimeevents.Event,
) {
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
done <- observer.OnEvent(ctx, evt)
done <- observer.OnRuntimeEvent(ctx, evt)
}()
select {
case err := <-done:
if err != nil {
logger.WarnCF("hooks", "Event observer failed", map[string]any{
logger.WarnCF("hooks", "Runtime event observer failed", map[string]any{
"hook": name,
"event": evt.Kind.String(),
"error": err.Error(),
})
}
case <-ctx.Done():
logger.WarnCF("hooks", "Event observer timed out", map[string]any{
logger.WarnCF("hooks", "Runtime event observer timed out", map[string]any{
"hook": name,
"event": evt.Kind.String(),
"timeout_ms": hm.observerTimeout.Milliseconds(),
+137 -51
View File
@@ -12,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
@@ -111,14 +112,14 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
}
type llmObserverHook struct {
eventCh chan Event
eventCh chan runtimeevents.Event
lastInbound *bus.InboundContext
lastRoute *routing.ResolvedRoute
lastScope *session.SessionScope
}
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
if evt.Kind == EventKindTurnEnd {
func (h *llmObserverHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
if evt.Kind == runtimeevents.KindAgentTurnEnd {
select {
case h.eventCh <- evt:
default:
@@ -150,6 +151,20 @@ func (h *llmObserverHook) AfterLLM(
return next, HookDecision{Action: HookActionModify}, nil
}
type dualRuntimeObserverHook struct {
runtimeCh chan runtimeevents.Event
}
func (h *dualRuntimeObserverHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
if evt.Kind == runtimeevents.KindAgentTurnEnd {
select {
case h.runtimeCh <- evt:
default:
}
}
return nil
}
type llmSystemRewriteHook struct{}
func (h *llmSystemRewriteHook) BeforeLLM(
@@ -417,7 +432,7 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
hook := &llmObserverHook{eventCh: make(chan runtimeevents.Event, 1)}
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
@@ -481,30 +496,80 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
select {
case evt := <-hook.eventCh:
if evt.Kind != EventKindTurnEnd {
if evt.Kind != runtimeevents.KindAgentTurnEnd {
t.Fatalf("expected turn end event, got %v", evt.Kind)
}
if evt.Context == nil || evt.Context.Inbound == nil {
t.Fatal("expected observer event to carry inbound context")
}
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route)
}
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" {
t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope)
if evt.Scope.AgentID != "main" ||
evt.Scope.SessionKey != "session-1" ||
evt.Scope.Channel != "cli" ||
evt.Scope.ChatID != "direct" ||
evt.Scope.SenderID != "hook-user" {
t.Fatalf("runtime observer scope = %+v", evt.Scope)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for hook observer event")
}
}
func TestAgentLoop_Hooks_RuntimeObserverReceivesEvents(t *testing.T) {
provider := &llmHookTestProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
hook := &dualRuntimeObserverHook{
runtimeCh: make(chan runtimeevents.Event, 1),
}
if err := al.MountHook(NamedHook("runtime-observer", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "hello",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
InboundContext: &bus.InboundContext{
Channel: "cli",
Account: "default",
ChatID: "direct",
ChatType: "direct",
SenderID: "hook-user",
MessageID: "msg-1",
},
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "provider content" {
t.Fatalf("expected provider content, got %q", resp)
}
select {
case evt := <-hook.runtimeCh:
if evt.Kind != runtimeevents.KindAgentTurnEnd {
t.Fatalf("runtime observer kind = %q", evt.Kind)
}
if evt.Scope.SessionKey != "session-1" ||
evt.Scope.Channel != "cli" ||
evt.Scope.ChatID != "direct" ||
evt.Scope.MessageID != "msg-1" {
t.Fatalf("runtime observer scope = %+v", evt.Scope)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for runtime observer event")
}
}
func TestAgentLoop_BtwCommand_UsesLLMHooks(t *testing.T) {
provider := &llmHookTestProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
useTestSideQuestionProvider(al, provider)
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
hook := &llmObserverHook{eventCh: make(chan runtimeevents.Event, 1)}
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
@@ -800,8 +865,13 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentToolExecSkipped,
)
defer closeRuntimeEvents()
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
@@ -820,8 +890,8 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
t.Fatalf("expected %q, got %q", expected, resp)
}
events := collectEventStream(sub.C)
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
events := collectRuntimeEventStream(runtimeCh)
skippedEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecSkipped)
if !ok {
t.Fatal("expected tool skipped event")
}
@@ -876,8 +946,13 @@ func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentToolExecEnd,
)
defer closeRuntimeEvents()
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
@@ -899,8 +974,8 @@ func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
}
// Verify event stream has ToolExecEnd, not actual tool execution
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
events := collectRuntimeEventStream(runtimeCh)
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
if !ok {
t.Fatal("expected tool exec end event")
}
@@ -1065,8 +1140,13 @@ func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
sendErr: errors.New("channel unavailable"),
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentToolExecEnd,
)
defer closeRuntimeEvents()
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-media-err",
@@ -1081,8 +1161,8 @@ func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
t.Fatalf("runAgentLoop failed: %v", err)
}
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
events := collectRuntimeEventStream(runtimeCh)
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
if !ok {
t.Fatal("expected ToolExecEnd event")
}
@@ -1120,8 +1200,13 @@ func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentToolExecEnd,
)
defer closeRuntimeEvents()
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-bus-fallback",
@@ -1136,8 +1221,8 @@ func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
t.Fatalf("runAgentLoop failed: %v", err)
}
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
events := collectRuntimeEventStream(runtimeCh)
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
if !ok {
t.Fatal("expected ToolExecEnd event")
}
@@ -1282,8 +1367,13 @@ func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
32,
runtimeevents.KindAgentToolExecSkipped,
)
defer closeRuntimeEvents()
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
@@ -1322,9 +1412,9 @@ func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
t.Fatal("timeout waiting for result")
}
events := collectEventStream(sub.C)
events := collectRuntimeEventStream(runtimeCh)
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
skippedEvts := filterRuntimeEvents(events, runtimeevents.KindAgentToolExecSkipped)
if len(skippedEvts) < 1 {
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
}
@@ -1362,8 +1452,14 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
32,
runtimeevents.KindAgentToolExecEnd,
runtimeevents.KindAgentToolExecSkipped,
)
defer closeRuntimeEvents()
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
@@ -1383,14 +1479,14 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
resultCh <- result{resp: resp, err: err}
}()
collectedEvents := make([]Event, 0, 8)
collectedEvents := make([]runtimeevents.Event, 0, 8)
steered := false
deadline := time.After(3 * time.Second)
for !steered {
select {
case evt := <-sub.C:
case evt := <-runtimeCh:
collectedEvents = append(collectedEvents, evt)
if evt.Kind != EventKindToolExecEnd {
if evt.Kind != runtimeevents.KindAgentToolExecEnd {
continue
}
payload, ok := evt.Payload.(ToolExecEndPayload)
@@ -1413,9 +1509,9 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
t.Fatal("timeout waiting for result")
}
events := append(collectedEvents, collectEventStream(sub.C)...)
events := append(collectedEvents, collectRuntimeEventStream(runtimeCh)...)
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
skippedEvts := filterRuntimeEvents(events, runtimeevents.KindAgentToolExecSkipped)
if len(skippedEvts) < 1 {
t.Fatal("expected at least one ToolExecSkipped event after steering")
}
@@ -1480,13 +1576,3 @@ func TestCloneStringAnyMap_EmptyMapReturnsNonNil(t *testing.T) {
}
})
}
func filterEvents(events []Event, kind EventKind) []Event {
var result []Event
for _, evt := range events {
if evt.Kind == kind {
result = append(result, evt)
}
}
return result
}
+7
View File
@@ -44,4 +44,11 @@ type ChannelManager interface {
// SendPlaceholder sends a placeholder message (e.g., for audio transcription).
SendPlaceholder(ctx context.Context, channel, chatID string) bool
// DismissToolFeedback clears any tracked tool feedback animation for the
// given channel/chat. Call this when a turn ends without a final response
// (e.g., ResponseHandled tools) to avoid orphaned animation goroutines.
// outboundCtx carries topic/thread info needed for channels that use
// scoped tracker keys (e.g., Telegram forum topics); may be nil.
DismissToolFeedback(ctx context.Context, channel, chatID string, outboundCtx *bus.InboundContext)
}
+7
View File
@@ -56,5 +56,12 @@ func isVisionUnsupportedError(err error) bool {
return true
}
// DeepSeek and other strict providers reject the image_url field at the
// JSON schema level with an "unknown variant" error rather than a semantic
// "not supported" message.
if strings.Contains(msg, "unknown variant") && strings.Contains(msg, "image_url") {
return true
}
return false
}
+13 -9
View File
@@ -10,6 +10,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -72,7 +73,7 @@ toolLoop:
})
al.emitEvent(
EventKindToolExecStart,
runtimeevents.KindAgentToolExecStart,
ts.eventMeta("runTurn", "turn.tool.start"),
ToolExecStartPayload{
Tool: toolName,
@@ -191,7 +192,7 @@ toolLoop:
}
al.emitEvent(
EventKindToolExecEnd,
runtimeevents.KindAgentToolExecEnd,
ts.eventMeta("runTurn", "turn.tool.end"),
ToolExecEndPayload{
Tool: toolName,
@@ -237,7 +238,7 @@ toolLoop:
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
al.emitEvent(
EventKindToolExecSkipped,
runtimeevents.KindAgentToolExecSkipped,
ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: skippedTC.Name,
@@ -284,7 +285,7 @@ toolLoop:
exec.allResponsesHandled = false
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
al.emitEvent(
EventKindToolExecSkipped,
runtimeevents.KindAgentToolExecSkipped,
ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: toolName,
@@ -323,7 +324,7 @@ toolLoop:
exec.allResponsesHandled = false
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
al.emitEvent(
EventKindToolExecSkipped,
runtimeevents.KindAgentToolExecSkipped,
ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: toolName,
@@ -353,7 +354,7 @@ toolLoop:
"iteration": iteration,
})
al.emitEvent(
EventKindToolExecStart,
runtimeevents.KindAgentToolExecStart,
ts.eventMeta("runTurn", "turn.tool.start"),
ToolExecStartPayload{
Tool: toolName,
@@ -401,7 +402,7 @@ toolLoop:
"channel": ts.channel,
})
al.emitEvent(
EventKindFollowUpQueued,
runtimeevents.KindAgentFollowUpQueued,
ts.scope.meta(iteration, "runTurn", "turn.follow_up.queued"),
FollowUpQueuedPayload{
SourceTool: asyncToolName,
@@ -567,7 +568,7 @@ toolLoop:
toolResultMsg.Media = append(toolResultMsg.Media, toolResult.Media...)
}
al.emitEvent(
EventKindToolExecEnd,
runtimeevents.KindAgentToolExecEnd,
ts.eventMeta("runTurn", "turn.tool.end"),
ToolExecEndPayload{
Tool: toolName,
@@ -612,7 +613,7 @@ toolLoop:
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
al.emitEvent(
EventKindToolExecSkipped,
runtimeevents.KindAgentToolExecSkipped,
ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: skippedTC.Name,
@@ -704,6 +705,9 @@ toolLoop:
}
ts.setPhase(TurnPhaseCompleted)
ts.setFinalContent("")
if al.channelManager != nil && ts.channel != "" {
al.channelManager.DismissToolFeedback(ctx, ts.channel, ts.chatID, ts.opts.InboundContext)
}
logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM",
map[string]any{
"agent_id": ts.agent.ID,
+2 -1
View File
@@ -6,6 +6,7 @@ import (
"context"
"github.com/sipeed/picoclaw/pkg/bus"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -50,7 +51,7 @@ func (p *Pipeline) Finalize(
ts.ingestMessage(turnCtx, al, finalMsg)
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
al.emitEvent(
EventKindError,
runtimeevents.KindAgentError,
ts.eventMeta("runTurn", "turn.error"),
ErrorPayload{
Stage: "session_save",
+54 -8
View File
@@ -11,6 +11,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/constants"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -113,7 +114,7 @@ func (p *Pipeline) CallLLM(
}
al.emitEvent(
EventKindLLMRequest,
runtimeevents.KindAgentLLMRequest,
ts.eventMeta("runTurn", "turn.llm.request"),
LLMRequestPayload{
Model: exec.llmModel,
@@ -184,7 +185,14 @@ func (p *Pipeline) CallLLM(
// Retry loop
var err error
maxRetries := 2
maxRetries := p.Cfg.Agents.Defaults.MaxLLMRetries
if maxRetries <= 0 {
maxRetries = 2
}
backoffSecs := p.Cfg.Agents.Defaults.LLMRetryBackoffSecs
if backoffSecs <= 0 {
backoffSecs = 2
}
for retry := 0; retry <= maxRetries; retry++ {
exec.response, err = callLLM(exec.callMessages, exec.providerToolDefs)
if err == nil {
@@ -199,7 +207,7 @@ func (p *Pipeline) CallLLM(
// Retry without media if vision is unsupported
if hasMediaRefs(exec.callMessages) && isVisionUnsupportedError(err) && retry < maxRetries {
al.emitEvent(
EventKindLLMRetry,
runtimeevents.KindAgentLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
@@ -232,6 +240,15 @@ func (p *Pipeline) CallLLM(
strings.Contains(errMsg, "timed out") ||
strings.Contains(errMsg, "timeout exceeded")
isNetworkError := !isTimeoutError && (strings.Contains(errMsg, "connection reset") ||
strings.Contains(errMsg, "connection refused") ||
strings.Contains(errMsg, "broken pipe") ||
strings.Contains(errMsg, "no such host") ||
strings.Contains(errMsg, "network is unreachable") ||
strings.Contains(errMsg, "read tcp") ||
strings.Contains(errMsg, "write tcp") ||
strings.Contains(errMsg, "eof"))
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
strings.Contains(errMsg, "context window") ||
strings.Contains(errMsg, "context_window") ||
@@ -244,9 +261,9 @@ func (p *Pipeline) CallLLM(
strings.Contains(errMsg, "request too large"))
if isTimeoutError && retry < maxRetries {
backoff := time.Duration(retry+1) * 5 * time.Second
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
al.emitEvent(
EventKindLLMRetry,
runtimeevents.KindAgentLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
@@ -272,9 +289,38 @@ func (p *Pipeline) CallLLM(
continue
}
if isNetworkError && retry < maxRetries {
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
al.emitEvent(
runtimeevents.KindAgentLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
MaxRetries: maxRetries,
Reason: "network",
Error: err.Error(),
Backoff: backoff,
},
)
logger.WarnCF("agent", "Network error, retrying after backoff", map[string]any{
"error": err.Error(),
"retry": retry,
"backoff": backoff.String(),
})
if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil {
if ts.hardAbortRequested() {
_ = ts.requestHardAbort()
return ControlBreak, nil
}
err = sleepErr
break
}
continue
}
if isContextError && retry < maxRetries && !ts.opts.NoHistory {
al.emitEvent(
EventKindLLMRetry,
runtimeevents.KindAgentLLMRetry,
ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
@@ -333,7 +379,7 @@ func (p *Pipeline) CallLLM(
if err != nil {
al.emitEvent(
EventKindError,
runtimeevents.KindAgentError,
ts.eventMeta("runTurn", "turn.error"),
ErrorPayload{
Stage: "llm",
@@ -397,7 +443,7 @@ func (p *Pipeline) CallLLM(
)
}
al.emitEvent(
EventKindLLMResponse,
runtimeevents.KindAgentLLMResponse,
ts.eventMeta("runTurn", "turn.llm.response"),
LLMResponsePayload{
ContentLen: len(exec.response.Content),
+408
View File
@@ -0,0 +1,408 @@
package agent
import (
"context"
"fmt"
"path"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
)
const (
runtimeEventLoggerBuffer = 256
runtimeEventLoggerDrainTimeout = 2 * time.Second
)
type runtimeEventLogger struct {
mu sync.RWMutex
cfg config.EventLoggingConfig
}
func (al *AgentLoop) refreshRuntimeEventLogger(cfg *config.Config) {
if al == nil {
return
}
logCfg := config.EffectiveEventLoggingConfig(cfg)
al.runtimeEventLogMu.Lock()
if !logCfg.Enabled {
oldSub := al.runtimeEventLogSub
al.runtimeEventLogger = nil
al.runtimeEventLogSub = nil
al.runtimeEventLogMu.Unlock()
closeRuntimeEventLoggerSubscription(oldSub)
return
}
if al.runtimeEventLogger != nil && al.runtimeEventLogSub != nil {
al.runtimeEventLogger.updateConfig(logCfg)
al.runtimeEventLogMu.Unlock()
return
}
al.runtimeEventLogMu.Unlock()
eventLogger := newRuntimeEventLoggerFromConfig(logCfg)
sub, err := eventLogger.subscribe(context.Background(), al.runtimeEvents)
if err != nil {
logger.WarnCF("events", "Failed to subscribe runtime event logger", map[string]any{"error": err.Error()})
return
}
al.runtimeEventLogMu.Lock()
oldSub := al.runtimeEventLogSub
al.runtimeEventLogger = eventLogger
al.runtimeEventLogSub = sub
al.runtimeEventLogMu.Unlock()
closeRuntimeEventLoggerSubscription(oldSub)
}
func (al *AgentLoop) closeRuntimeEventLogger() {
if al == nil {
return
}
al.runtimeEventLogMu.Lock()
oldSub := al.runtimeEventLogSub
al.runtimeEventLogger = nil
al.runtimeEventLogSub = nil
al.runtimeEventLogMu.Unlock()
closeRuntimeEventLoggerSubscription(oldSub)
}
func closeRuntimeEventLoggerSubscription(sub runtimeevents.Subscription) {
if sub == nil {
return
}
if err := sub.Close(); err != nil {
logger.WarnCF("events", "Failed to close runtime event logger subscription", map[string]any{
"error": err.Error(),
})
}
timer := time.NewTimer(runtimeEventLoggerDrainTimeout)
defer timer.Stop()
select {
case <-sub.Done():
case <-timer.C:
logger.WarnCF("events", "Timed out waiting for runtime event logger to drain", map[string]any{
"timeout": runtimeEventLoggerDrainTimeout.String(),
})
}
}
func newRuntimeEventLogger(cfg *config.Config) *runtimeEventLogger {
logCfg := config.EffectiveEventLoggingConfig(cfg)
if !logCfg.Enabled {
return nil
}
return newRuntimeEventLoggerFromConfig(logCfg)
}
func newRuntimeEventLoggerFromConfig(logCfg config.EventLoggingConfig) *runtimeEventLogger {
return &runtimeEventLogger{cfg: logCfg}
}
func (l *runtimeEventLogger) updateConfig(cfg config.EventLoggingConfig) {
if l == nil {
return
}
l.mu.Lock()
l.cfg = cfg
l.mu.Unlock()
}
func (l *runtimeEventLogger) configSnapshot() config.EventLoggingConfig {
if l == nil {
return config.EventLoggingConfig{}
}
l.mu.RLock()
defer l.mu.RUnlock()
return l.cfg
}
func (l *runtimeEventLogger) subscribe(
ctx context.Context,
eventBus runtimeevents.Bus,
) (runtimeevents.Subscription, error) {
if l == nil || eventBus == nil {
return nil, nil
}
return eventBus.Channel().Subscribe(ctx, runtimeevents.SubscribeOptions{
Name: "runtime-event-logger",
Buffer: runtimeEventLoggerBuffer,
Concurrency: runtimeevents.Locked,
Backpressure: runtimeevents.DropNewest,
PanicPolicy: runtimeevents.RecoverAndLog,
}, l.handle)
}
func (l *runtimeEventLogger) handle(_ context.Context, evt runtimeevents.Event) error {
if l == nil || !l.shouldLog(evt) {
return nil
}
fields := runtimeEventLogFields(evt)
if l.configSnapshot().IncludePayload && evt.Payload != nil {
fields["payload"] = evt.Payload
}
logRuntimeEvent(evt, fields)
return nil
}
func (l *runtimeEventLogger) shouldLog(evt runtimeevents.Event) bool {
if l == nil {
return false
}
cfg := l.configSnapshot()
if !cfg.Enabled {
return false
}
if runtimeEventSeverityRank(evt.Severity) < runtimeEventSeverityRank(parseRuntimeEventSeverity(cfg.MinSeverity)) {
return false
}
kind := evt.Kind.String()
if !matchAnyRuntimeEventPattern(cfg.Include, kind, true) {
return false
}
return !matchAnyRuntimeEventPattern(cfg.Exclude, kind, false)
}
func logRuntimeEvent(evt runtimeevents.Event, fields map[string]any) {
message := fmt.Sprintf("Runtime event: %s", evt.Kind.String())
switch normalizeRuntimeEventSeverity(evt.Severity) {
case runtimeevents.SeverityDebug:
logger.DebugCF("events", message, fields)
case runtimeevents.SeverityWarn:
logger.WarnCF("events", message, fields)
case runtimeevents.SeverityError:
logger.ErrorCF("events", message, fields)
default:
logger.InfoCF("events", message, fields)
}
}
func runtimeEventLogFields(evt runtimeevents.Event) map[string]any {
fields := map[string]any{
"event_id": evt.ID,
"event_kind": evt.Kind.String(),
"severity": string(normalizeRuntimeEventSeverity(evt.Severity)),
}
if !evt.Time.IsZero() {
fields["event_time"] = evt.Time.Format(time.RFC3339Nano)
}
appendRuntimeEventSourceFields(fields, evt.Source)
appendRuntimeEventScopeFields(fields, evt.Scope)
appendRuntimeEventCorrelationFields(fields, evt.Correlation)
appendRuntimeEventAttrs(fields, evt.Attrs)
appendRuntimeEventPayloadSummary(fields, evt.Payload)
return fields
}
func appendRuntimeEventSourceFields(fields map[string]any, source runtimeevents.Source) {
if source.Component != "" {
fields["source_component"] = source.Component
}
if source.Name != "" {
fields["source_name"] = source.Name
}
}
func appendRuntimeEventScopeFields(fields map[string]any, scope runtimeevents.Scope) {
setStringField(fields, "runtime_id", scope.RuntimeID)
setStringField(fields, "agent_id", scope.AgentID)
setStringField(fields, "session_key", scope.SessionKey)
setStringField(fields, "turn_id", scope.TurnID)
setStringField(fields, "channel", scope.Channel)
setStringField(fields, "account", scope.Account)
setStringField(fields, "chat_id", scope.ChatID)
setStringField(fields, "topic_id", scope.TopicID)
setStringField(fields, "space_id", scope.SpaceID)
setStringField(fields, "space_type", scope.SpaceType)
setStringField(fields, "chat_type", scope.ChatType)
setStringField(fields, "sender_id", scope.SenderID)
setStringField(fields, "message_id", scope.MessageID)
}
func appendRuntimeEventCorrelationFields(fields map[string]any, correlation runtimeevents.Correlation) {
setStringField(fields, "trace_id", correlation.TraceID)
setStringField(fields, "parent_turn_id", correlation.ParentTurnID)
setStringField(fields, "request_id", correlation.RequestID)
setStringField(fields, "reply_to_id", correlation.ReplyToID)
}
func appendRuntimeEventAttrs(fields map[string]any, attrs map[string]any) {
for key, value := range attrs {
if key == "" || value == nil {
continue
}
if _, exists := fields[key]; exists {
fields["attr_"+key] = value
continue
}
fields[key] = value
}
}
func appendRuntimeEventPayloadSummary(fields map[string]any, payload any) {
switch payload := payload.(type) {
case TurnStartPayload:
fields["user_len"] = len(payload.UserMessage)
fields["media_count"] = payload.MediaCount
case TurnEndPayload:
fields["status"] = payload.Status
fields["iterations_total"] = payload.Iterations
fields["duration_ms"] = payload.Duration.Milliseconds()
fields["final_len"] = payload.FinalContentLen
case LLMRequestPayload:
fields["model"] = payload.Model
fields["messages"] = payload.MessagesCount
fields["tools"] = payload.ToolsCount
fields["max_tokens"] = payload.MaxTokens
case LLMDeltaPayload:
fields["content_delta_len"] = payload.ContentDeltaLen
fields["reasoning_delta_len"] = payload.ReasoningDeltaLen
case LLMResponsePayload:
fields["content_len"] = payload.ContentLen
fields["tool_calls"] = payload.ToolCalls
fields["has_reasoning"] = payload.HasReasoning
case LLMRetryPayload:
fields["attempt"] = payload.Attempt
fields["max_retries"] = payload.MaxRetries
fields["reason"] = payload.Reason
fields["error"] = payload.Error
fields["backoff_ms"] = payload.Backoff.Milliseconds()
case ContextCompressPayload:
fields["reason"] = payload.Reason
fields["dropped_messages"] = payload.DroppedMessages
fields["remaining_messages"] = payload.RemainingMessages
case SessionSummarizePayload:
fields["summarized_messages"] = payload.SummarizedMessages
fields["kept_messages"] = payload.KeptMessages
fields["summary_len"] = payload.SummaryLen
fields["omitted_oversized"] = payload.OmittedOversized
case ToolExecStartPayload:
fields["tool"] = payload.Tool
fields["args_count"] = len(payload.Arguments)
case ToolExecEndPayload:
fields["tool"] = payload.Tool
fields["duration_ms"] = payload.Duration.Milliseconds()
fields["for_llm_len"] = payload.ForLLMLen
fields["for_user_len"] = payload.ForUserLen
fields["is_error"] = payload.IsError
fields["async"] = payload.Async
case ToolExecSkippedPayload:
fields["tool"] = payload.Tool
fields["reason"] = payload.Reason
case SteeringInjectedPayload:
fields["count"] = payload.Count
fields["total_content_len"] = payload.TotalContentLen
case FollowUpQueuedPayload:
fields["source_tool"] = payload.SourceTool
fields["content_len"] = payload.ContentLen
case InterruptReceivedPayload:
fields["interrupt_kind"] = payload.Kind
fields["role"] = payload.Role
fields["content_len"] = payload.ContentLen
fields["queue_depth"] = payload.QueueDepth
fields["hint_len"] = payload.HintLen
case SubTurnSpawnPayload:
fields["child_agent_id"] = payload.AgentID
fields["label"] = payload.Label
case SubTurnEndPayload:
fields["child_agent_id"] = payload.AgentID
fields["status"] = payload.Status
case SubTurnResultDeliveredPayload:
fields["target_channel"] = payload.TargetChannel
fields["target_chat_id"] = payload.TargetChatID
fields["content_len"] = payload.ContentLen
case SubTurnOrphanPayload:
fields["parent_turn_id"] = payload.ParentTurnID
fields["child_turn_id"] = payload.ChildTurnID
fields["reason"] = payload.Reason
case ErrorPayload:
fields["stage"] = payload.Stage
fields["error"] = payload.Message
}
}
func setStringField(fields map[string]any, key, value string) {
if value != "" {
fields[key] = value
}
}
func matchAnyRuntimeEventPattern(patterns []string, kind string, emptyMatches bool) bool {
if len(patterns) == 0 {
return emptyMatches
}
for _, pattern := range patterns {
if matchRuntimeEventPattern(pattern, kind) {
return true
}
}
return false
}
func matchRuntimeEventPattern(pattern, kind string) bool {
pattern = strings.TrimSpace(pattern)
if pattern == "" {
return false
}
if pattern == "*" {
return true
}
if strings.HasSuffix(pattern, ".*") {
return strings.HasPrefix(kind, strings.TrimSuffix(pattern, "*"))
}
matched, err := path.Match(pattern, kind)
if err == nil {
return matched
}
return pattern == kind
}
func parseRuntimeEventSeverity(severity string) runtimeevents.Severity {
switch strings.ToLower(strings.TrimSpace(severity)) {
case "debug":
return runtimeevents.SeverityDebug
case "warn", "warning":
return runtimeevents.SeverityWarn
case "error":
return runtimeevents.SeverityError
default:
return runtimeevents.SeverityInfo
}
}
func normalizeRuntimeEventSeverity(severity runtimeevents.Severity) runtimeevents.Severity {
switch severity {
case runtimeevents.SeverityDebug,
runtimeevents.SeverityInfo,
runtimeevents.SeverityWarn,
runtimeevents.SeverityError:
return severity
default:
return runtimeevents.SeverityInfo
}
}
func runtimeEventSeverityRank(severity runtimeevents.Severity) int {
switch normalizeRuntimeEventSeverity(severity) {
case runtimeevents.SeverityDebug:
return 0
case runtimeevents.SeverityInfo:
return 1
case runtimeevents.SeverityWarn:
return 2
case runtimeevents.SeverityError:
return 3
default:
return 1
}
}
+259
View File
@@ -0,0 +1,259 @@
package agent
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func TestRuntimeEventLoggerFiltering(t *testing.T) {
cfg := config.DefaultConfig()
eventLogger := newRuntimeEventLogger(cfg)
if eventLogger == nil {
t.Fatal("default runtime event logger is nil")
}
if !eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindAgentTurnStart,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("default config should log agent events")
}
if eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindChannelLifecycleStarted,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("default config should not log non-agent events")
}
cfg.Events.Logging.Include = []string{"*"}
cfg.Events.Logging.Exclude = []string{"mcp.*"}
eventLogger = newRuntimeEventLogger(cfg)
if !eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindGatewayReady,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("include * should log gateway events")
}
if eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindMCPServerConnected,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("exclude mcp.* should suppress MCP events")
}
cfg.Events.Logging.Exclude = nil
cfg.Events.Logging.MinSeverity = "warn"
eventLogger = newRuntimeEventLogger(cfg)
if eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindGatewayReady,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("min severity warn should suppress info events")
}
if !eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindGatewayReloadFailed,
Severity: runtimeevents.SeverityError,
}) {
t.Fatal("min severity warn should allow error events")
}
cfg.Events.Logging.Enabled = false
if newRuntimeEventLogger(cfg) != nil {
t.Fatal("disabled config should not create runtime event logger")
}
}
func TestRuntimeEventLogFieldsSummarizeAgentPayload(t *testing.T) {
fields := runtimeEventLogFields(runtimeevents.Event{
ID: "evt-test",
Kind: runtimeevents.KindAgentToolExecStart,
Severity: runtimeevents.SeverityInfo,
Source: runtimeevents.Source{
Component: "agent",
Name: "main",
},
Scope: runtimeevents.Scope{
AgentID: "main",
SessionKey: "session-1",
TurnID: "turn-1",
},
Payload: ToolExecStartPayload{
Tool: "exec",
Arguments: map[string]any{
"secret": "should-not-be-logged-by-default",
},
},
})
if fields["event_id"] != "evt-test" || fields["source_component"] != "agent" {
t.Fatalf("missing common event fields: %#v", fields)
}
if fields["tool"] != "exec" || fields["args_count"] != 1 {
t.Fatalf("missing safe agent payload summary fields: %#v", fields)
}
if _, ok := fields["payload"]; ok {
t.Fatalf("raw payload should not be included by runtimeEventLogFields: %#v", fields)
}
}
func TestRuntimeEventLogFieldsIncludeSafeAttrs(t *testing.T) {
fields := runtimeEventLogFields(runtimeevents.Event{
ID: "evt-gateway",
Kind: runtimeevents.KindGatewayReady,
Severity: runtimeevents.SeverityInfo,
Attrs: map[string]any{
"duration_ms": 42,
"error": "startup failed",
"event_kind": "conflict",
},
})
if fields["duration_ms"] != 42 || fields["error"] != "startup failed" {
t.Fatalf("missing safe attrs: %#v", fields)
}
if fields["event_kind"] != runtimeevents.KindGatewayReady.String() {
t.Fatalf("event_kind overwritten by attrs: %#v", fields)
}
if fields["attr_event_kind"] != "conflict" {
t.Fatalf("conflicting attr not preserved with prefix: %#v", fields)
}
if _, ok := fields["payload"]; ok {
t.Fatalf("raw payload should not be included by runtimeEventLogFields: %#v", fields)
}
}
func runtimeEventLoggerStateForTest(
al *AgentLoop,
) (*runtimeEventLogger, runtimeevents.Subscription) {
al.runtimeEventLogMu.RLock()
defer al.runtimeEventLogMu.RUnlock()
return al.runtimeEventLogger, al.runtimeEventLogSub
}
func TestReloadProviderAndConfigRefreshesRuntimeEventLogger(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = t.TempDir()
cfg.Events.Logging.Include = []string{"agent.*"}
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
defer al.Close()
eventLogger, logSub := runtimeEventLoggerStateForTest(al)
if eventLogger == nil || logSub == nil {
t.Fatal("expected initial runtime event logger subscription")
}
if eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindGatewayReloadCompleted,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("initial agent-only logging should not log gateway reload events")
}
reloaded := config.DefaultConfig()
reloaded.Agents.Defaults.Workspace = cfg.Agents.Defaults.Workspace
reloaded.Events.Logging.Include = []string{"gateway.*"}
if err := al.ReloadProviderAndConfig(context.Background(), &mockProvider{}, reloaded); err != nil {
t.Fatalf("ReloadProviderAndConfig() error = %v", err)
}
eventLogger, logSub = runtimeEventLoggerStateForTest(al)
if eventLogger == nil || logSub == nil {
t.Fatal("expected runtime event logger subscription after reload")
}
if !eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindGatewayReloadCompleted,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("reloaded gateway logging should log gateway reload events")
}
if eventLogger.shouldLog(runtimeevents.Event{
Kind: runtimeevents.KindAgentTurnStart,
Severity: runtimeevents.SeverityInfo,
}) {
t.Fatal("reloaded gateway-only logging should not log agent events")
}
disabled := config.DefaultConfig()
disabled.Agents.Defaults.Workspace = cfg.Agents.Defaults.Workspace
disabled.Events.Logging.Enabled = false
if err := al.ReloadProviderAndConfig(context.Background(), &mockProvider{}, disabled); err != nil {
t.Fatalf("ReloadProviderAndConfig() with disabled logging error = %v", err)
}
eventLogger, logSub = runtimeEventLoggerStateForTest(al)
if eventLogger != nil || logSub != nil {
t.Fatal("expected runtime event logger to be disabled after reload")
}
}
func TestCloseRuntimeEventLoggerSubscriptionWaitsForDrain(t *testing.T) {
eventBus := runtimeevents.NewBus()
defer func() {
if err := eventBus.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
}()
var handled atomic.Uint64
firstStarted := make(chan struct{})
releaseFirst := make(chan struct{})
sub, err := eventBus.Channel().Subscribe(
context.Background(),
runtimeevents.SubscribeOptions{
Name: "runtime-event-logger",
Buffer: 2,
Concurrency: runtimeevents.Locked,
},
func(context.Context, runtimeevents.Event) error {
if handled.Add(1) == 1 {
close(firstStarted)
<-releaseFirst
}
return nil
},
)
if err != nil {
t.Fatalf("Subscribe failed: %v", err)
}
first := eventBus.Publish(context.Background(), runtimeevents.Event{Kind: runtimeevents.Kind("test.first")})
if first.Delivered != 1 {
t.Fatalf("first Publish = %+v, want one delivered event", first)
}
select {
case <-firstStarted:
case <-time.After(time.Second):
t.Fatal("timed out waiting for first handler to start")
}
second := eventBus.Publish(context.Background(), runtimeevents.Event{Kind: runtimeevents.Kind("test.second")})
if second.Delivered != 1 {
t.Fatalf("second Publish = %+v, want one delivered event", second)
}
closeReturned := make(chan struct{})
go func() {
closeRuntimeEventLoggerSubscription(sub)
close(closeReturned)
}()
select {
case <-closeReturned:
t.Fatal("runtime event logger close returned before buffered events drained")
case <-time.After(50 * time.Millisecond):
}
close(releaseFirst)
select {
case <-closeReturned:
case <-time.After(time.Second):
t.Fatal("timed out waiting for runtime event logger close to return")
}
if got := handled.Load(); got != 2 {
t.Fatalf("handled = %d, want 2", got)
}
}
+103
View File
@@ -0,0 +1,103 @@
package agent
import (
"testing"
"time"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func subscribeRuntimeEventsForTest(
t *testing.T,
al *AgentLoop,
buffer int,
kinds ...runtimeevents.Kind,
) (<-chan runtimeevents.Event, func()) {
t.Helper()
if al == nil {
t.Fatal("agent loop is nil")
}
channel := al.RuntimeEvents()
if channel == nil {
t.Fatal("runtime event channel is nil")
}
if len(kinds) > 0 {
channel = channel.OfKind(kinds...)
}
sub, ch, err := channel.SubscribeChan(
t.Context(),
runtimeevents.SubscribeOptions{Name: "agent-runtime-test", Buffer: buffer},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
return ch, func() {
if err := sub.Close(); err != nil {
t.Errorf("runtime subscription close failed: %v", err)
}
}
}
func waitForRuntimeEvent(
t *testing.T,
ch <-chan runtimeevents.Event,
timeout time.Duration,
match func(runtimeevents.Event) bool,
) runtimeevents.Event {
t.Helper()
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case evt, ok := <-ch:
if !ok {
t.Fatal("runtime event stream closed before expected event arrived")
}
if match(evt) {
return evt
}
case <-timer.C:
t.Fatal("timed out waiting for expected runtime event")
}
}
}
func collectRuntimeEventStream(ch <-chan runtimeevents.Event) []runtimeevents.Event {
var events []runtimeevents.Event
for {
select {
case evt, ok := <-ch:
if !ok {
return events
}
events = append(events, evt)
default:
return events
}
}
}
func findRuntimeEvent(
events []runtimeevents.Event,
kind runtimeevents.Kind,
) (runtimeevents.Event, bool) {
for _, evt := range events {
if evt.Kind == kind {
return evt, true
}
}
return runtimeevents.Event{}, false
}
func filterRuntimeEvents(events []runtimeevents.Event, kind runtimeevents.Kind) []runtimeevents.Event {
var filtered []runtimeevents.Event
for _, evt := range events {
if evt.Kind == kind {
filtered = append(filtered, evt)
}
}
return filtered
}
+28 -4
View File
@@ -8,6 +8,7 @@ import (
"sync"
"github.com/sipeed/picoclaw/pkg/bus"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
@@ -155,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int {
return len(sq.queues[normalizeSteeringScope(scope)])
}
func (sq *steeringQueue) clearScope(scope string) int {
sq.mu.Lock()
defer sq.mu.Unlock()
scope = normalizeSteeringScope(scope)
count := len(sq.queues[scope])
if count > 0 {
delete(sq.queues, scope)
}
return count
}
// setMode updates the steering mode.
func (sq *steeringQueue) setMode(mode SteeringMode) {
sq.mu.Lock()
@@ -206,7 +219,7 @@ func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers
"scope": normalizeSteeringScope(scope),
})
meta := EventMeta{
meta := HookMeta{
Source: "Steer",
TracePath: "turn.interrupt.received",
}
@@ -230,7 +243,7 @@ func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers
}
al.emitEvent(
EventKindInterruptReceived,
runtimeevents.KindAgentInterruptReceived,
meta,
InterruptReceivedPayload{
Kind: InterruptKindSteering,
@@ -289,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
return al.steering.lenScope(scope)
}
func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int {
if al.steering == nil {
return 0
}
return al.steering.clearScope(scope)
}
func (al *AgentLoop) continueWithSteeringMessages(
ctx context.Context,
agent *AgentInstance,
@@ -410,7 +430,7 @@ func (al *AgentLoop) InterruptGraceful(hint string) error {
}
al.emitEvent(
EventKindInterruptReceived,
runtimeevents.KindAgentInterruptReceived,
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
InterruptReceivedPayload{
Kind: InterruptKindGraceful,
@@ -438,7 +458,7 @@ func (al *AgentLoop) InterruptHard() error {
}
al.emitEvent(
EventKindInterruptReceived,
runtimeevents.KindAgentInterruptReceived,
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
InterruptReceivedPayload{
Kind: InterruptKindHard,
@@ -510,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
"initial_history_length": ts.initialHistoryLength,
})
// Cancel the active provider/tool turn contexts immediately so long-running
// execution stops as soon as possible on the root turn.
_ = ts.requestHardAbort()
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
// from adding more messages to the session. This prevents race conditions
// where rollback happens while children are still writing.
+354 -13
View File
@@ -14,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -839,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
}
}
func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
MaxParallelTurns: 1,
},
},
}
msgBus := bus.NewMessageBus()
provider := &lateSteeringProvider{
firstCallStarted: make(chan struct{}),
releaseFirstCall: make(chan struct{}),
}
al := NewAgentLoop(cfg, msgBus, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
defer func() {
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run() error = %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}()
blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker")
targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target")
blockerCtx := bus.InboundContext{
Channel: "test",
ChatID: "blocker-chat",
ChatType: "direct",
SenderID: "user1",
}
targetCtx := bus.InboundContext{
Channel: "test",
ChatID: "target-chat",
ChatType: "direct",
SenderID: "user1",
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: blockerCtx,
Content: "block worker pool",
SessionKey: blockerSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(blocker) error = %v", err)
}
select {
case <-provider.firstCallStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for blocker turn to start")
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: targetCtx,
Content: "skip this turn",
SessionKey: targetSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(target start) error = %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for {
ts := al.getActiveTurnState(targetSessionKey)
if ts != nil && strings.HasPrefix(ts.turnID, pendingTurnPrefix) {
break
}
if time.Now().After(deadline) {
t.Fatal("timeout waiting for pending placeholder")
}
time.Sleep(10 * time.Millisecond)
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: targetCtx,
Content: "/stop",
SessionKey: targetSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(/stop) error = %v", err)
}
deadline = time.Now().Add(2 * time.Second)
stopSeen := false
for !stopSeen {
select {
case outbound := <-msgBus.OutboundChan():
if outbound.ChatID == "target-chat" && outbound.Content == "Task stopped. Current task was canceled." {
stopSeen = true
}
case <-time.After(10 * time.Millisecond):
if time.Now().After(deadline) {
t.Fatal("timeout waiting for /stop reply")
}
}
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: targetCtx,
Content: "run this instead",
SessionKey: targetSessionKey,
}); err != nil {
t.Fatalf("PublishInbound(follow-up) error = %v", err)
}
deadline = time.Now().Add(2 * time.Second)
for al.pendingSteeringCountForScope(targetSessionKey) == 0 {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for follow-up to enter scoped steering queue")
}
time.Sleep(10 * time.Millisecond)
}
close(provider.releaseFirstCall)
deadline = time.Now().Add(5 * time.Second)
followUpSeen := false
for !followUpSeen {
select {
case outbound := <-msgBus.OutboundChan():
if outbound.ChatID == "target-chat" && outbound.Content == "continued response" {
followUpSeen = true
}
case <-time.After(10 * time.Millisecond):
if time.Now().After(deadline) {
t.Fatal("timeout waiting for queued follow-up continuation")
}
}
}
deadline = time.Now().Add(2 * time.Second)
for {
if al.GetActiveTurnBySession(targetSessionKey) == nil &&
al.pendingSteeringCountForScope(targetSessionKey) == 0 {
break
}
if time.Now().After(deadline) {
t.Fatal("timeout waiting for target session to go idle")
}
time.Sleep(10 * time.Millisecond)
}
provider.mu.Lock()
calls := provider.calls
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", calls)
}
foundFollowUp := false
for _, msg := range secondMessages {
if msg.Role == "user" && msg.Content == "run this instead" {
foundFollowUp = true
}
if msg.Role == "user" && msg.Content == "skip this turn" {
t.Fatalf("unexpected canceled message in continuation context: %q", msg.Content)
}
}
if !foundFollowUp {
t.Fatal("expected queued follow-up to be processed after pending stop")
}
}
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -1051,16 +1237,16 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
foundResolvedMedia := false
for _, msg := range msgs {
if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 {
if msg.Role != "user" || !strings.Contains(msg.Content, "describe this image") {
continue
}
if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
if strings.Contains(msg.Content, "[image:") {
foundResolvedMedia = true
break
}
}
if !foundResolvedMedia {
t.Fatal("expected continue path to inject steering media into the provider request")
t.Fatal("expected continue path to inject image path tag into the provider request")
}
defaultAgent := al.registry.GetDefaultAgent()
@@ -1134,8 +1320,14 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
al.RegisterTool(tool2)
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
32,
runtimeevents.KindAgentInterruptReceived,
runtimeevents.KindAgentTurnEnd,
)
defer closeRuntimeEvents()
type result struct {
resp string
@@ -1222,8 +1414,8 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
}
events := collectEventStream(sub.C)
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
events := collectRuntimeEventStream(runtimeCh)
interruptEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
@@ -1235,7 +1427,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
}
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
turnEndEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentTurnEnd)
if !ok {
t.Fatal("expected turn end event")
}
@@ -1299,8 +1491,14 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
}
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentInterruptReceived,
runtimeevents.KindAgentTurnEnd,
)
defer closeRuntimeEvents()
type result struct {
resp string
@@ -1353,8 +1551,8 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
}
events := collectEventStream(sub.C)
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
events := collectRuntimeEventStream(runtimeCh)
interruptEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
@@ -1366,7 +1564,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
}
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
turnEndEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentTurnEnd)
if !ok {
t.Fatal("expected turn end event")
}
@@ -1379,6 +1577,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
}
}
func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "cancel_tool",
Function: &providers.FunctionCall{
Name: "cancel_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "should not continue",
}
al := NewAgentLoop(cfg, msgBus, provider)
started := make(chan struct{})
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
defer func() {
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run() error = %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}()
baseMsg := testInboundMessage(bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
SessionKey: sessionKey,
})
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "do work",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(start) error = %v", err)
}
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for interruptible tool to start")
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "follow up after cancel",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(follow-up) error = %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for al.pendingSteeringCountForScope(sessionKey) == 0 {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for follow-up message to enter steering queue")
}
time.Sleep(10 * time.Millisecond)
}
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
Context: baseMsg.Context,
Content: "/stop",
SessionKey: sessionKey,
}); err != nil {
t.Fatalf("PublishInbound(/stop) error = %v", err)
}
select {
case outbound := <-msgBus.OutboundChan():
want := "Task stopped. \"do work\" was canceled."
if outbound.Content != want {
t.Fatalf("stop reply = %q, want %q", outbound.Content, want)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /stop reply")
}
deadline = time.Now().Add(5 * time.Second)
for al.GetActiveTurnBySession(sessionKey) != nil {
if time.Now().After(deadline) {
t.Fatal("timeout waiting for active turn to stop")
}
time.Sleep(10 * time.Millisecond)
}
if got := al.pendingSteeringCountForScope(sessionKey); got != 0 {
t.Fatalf("expected cleared steering queue, got %d pending message(s)", got)
}
select {
case outbound := <-msgBus.OutboundChan():
t.Fatalf("unexpected outbound after stop: %q", outbound.Content)
case <-time.After(300 * time.Millisecond):
}
provider.mu.Lock()
calls := provider.calls
provider.mu.Unlock()
if calls != 1 {
t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls)
}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
+35 -19
View File
@@ -8,6 +8,7 @@ import (
"sync/atomic"
"time"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/providers/messageutil"
@@ -173,7 +174,10 @@ type SubTurnConfig struct {
// Used by team tool to enforce token limits across all team members.
InitialTokenBudget *atomic.Int64
// Can be extended with temperature, topP, etc.
// TargetAgentID, when set, runs the sub-turn as the specified agent.
// The target agent's workspace, model, tools, and system prompt are used
// instead of the caller's. If empty, the sub-turn runs as the parent agent.
TargetAgentID string
}
// ====================== Context Keys ======================
@@ -231,6 +235,7 @@ func (s *AgentLoopSpawner) SpawnSubTurn(
Critical: cfg.Critical,
Timeout: cfg.Timeout,
MaxContextRunes: cfg.MaxContextRunes,
TargetAgentID: cfg.TargetAgentID,
}
return spawnSubTurn(ctx, s.al, parentTS, agentCfg)
@@ -313,8 +318,9 @@ func spawnSubTurn(
return nil, ErrDepthLimitExceeded
}
// 2. Config validation
if cfg.Model == "" {
// 2. Config validation: Model is required unless TargetAgentID is set
// (the target agent provides its own model).
if cfg.Model == "" && cfg.TargetAgentID == "" {
return nil, ErrInvalidSubTurnConfig
}
@@ -332,12 +338,22 @@ func spawnSubTurn(
childID := al.generateSubTurnID()
// Get the agent instance from parent, falling back to the default agent.
// Wrap it in a shallow copy that uses an ephemeral (in-memory only) session store
// so that child turns never pollute or persist to the parent's session history.
baseAgent := parentTS.agent
if baseAgent == nil {
baseAgent = al.registry.GetDefaultAgent()
// Resolve the agent instance for the child turn.
// When TargetAgentID is set, look up that agent from the registry so the
// child runs with the target's workspace, model, tools, and system prompt.
// Otherwise fall back to the parent's agent (existing behavior).
var baseAgent *AgentInstance
if cfg.TargetAgentID != "" {
var ok bool
baseAgent, ok = al.registry.GetAgent(cfg.TargetAgentID)
if !ok {
return nil, fmt.Errorf("target agent %q not found in registry", cfg.TargetAgentID)
}
} else {
baseAgent = parentTS.agent
if baseAgent == nil {
baseAgent = al.registry.GetDefaultAgent()
}
}
if baseAgent == nil {
return nil, errors.New("parent turnState has no agent instance")
@@ -422,7 +438,7 @@ func spawnSubTurn(
parentTS.mu.Unlock()
// 6. Emit Spawn event
al.emitEvent(EventKindSubTurnSpawn,
al.emitEvent(runtimeevents.KindAgentSubTurnSpawn,
childTS.eventMeta("spawnSubTurn", "subturn.spawn"),
SubTurnSpawnPayload{
AgentID: childTS.agentID,
@@ -453,7 +469,7 @@ func spawnSubTurn(
if err != nil {
status = "error"
}
al.emitEvent(EventKindSubTurnEnd,
al.emitEvent(runtimeevents.KindAgentSubTurnEnd,
childTS.eventMeta("spawnSubTurn", "subturn.end"),
SubTurnEndPayload{
AgentID: childTS.agentID,
@@ -504,16 +520,16 @@ func spawnSubTurn(
//
// Delivery behavior:
// - If parent turn is still running: attempts to deliver to pendingResults channel
// - If channel is full: emits SubTurnOrphanResultEvent (result is lost from channel but tracked)
// - If parent turn has finished: emits SubTurnOrphanResultEvent (late arrival)
// - If channel is full: emits agent.subturn.orphan (result is lost from channel but tracked)
// - If parent turn has finished: emits agent.subturn.orphan (late arrival)
//
// Thread safety:
// - Reads parent state under lock, then releases lock before channel send
// - Small race window exists but is acceptable (worst case: result becomes orphan)
//
// Event emissions:
// - SubTurnResultDeliveredEvent: successful delivery to channel
// - SubTurnOrphanResultEvent: delivery failed (parent finished or channel full)
// - agent.subturn.result_delivered: successful delivery to channel
// - agent.subturn.orphan: delivery failed (parent finished or channel full)
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
// Let GC clean up the pendingResults channel; parent Finish will no longer close it.
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
@@ -526,7 +542,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
"recover": r,
})
if result != nil && al != nil {
al.emitEvent(EventKindSubTurnOrphan,
al.emitEvent(runtimeevents.KindAgentSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "panic"},
)
@@ -541,7 +557,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
// If parent turn has already finished, treat this as an orphan result
if isFinished || resultChan == nil {
if result != nil && al != nil {
al.emitEvent(EventKindSubTurnOrphan,
al.emitEvent(runtimeevents.KindAgentSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "parent_finished"},
)
@@ -557,7 +573,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
case resultChan <- result:
// Successfully delivered
if al != nil {
al.emitEvent(EventKindSubTurnResultDelivered,
al.emitEvent(runtimeevents.KindAgentSubTurnResultDelivered,
parentTS.eventMeta("deliverSubTurnResult", "subturn.result_delivered"),
SubTurnResultDeliveredPayload{ContentLen: len(result.ForLLM)},
)
@@ -571,7 +587,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
})
if result != nil && al != nil {
al.emitEvent(
EventKindSubTurnOrphan,
runtimeevents.KindAgentSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{
ParentTurnID: parentTS.turnID,
+258 -27
View File
@@ -4,12 +4,16 @@ import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -22,30 +26,38 @@ const (
// ====================== Test Helper: Event Collector ======================
type eventCollector struct {
mu sync.Mutex
events []Event
events []runtimeevents.Event
}
func newEventCollector(t *testing.T, al *AgentLoop) (*eventCollector, func()) {
t.Helper()
c := &eventCollector{}
sub := al.SubscribeEvents(16)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentSubTurnSpawn,
runtimeevents.KindAgentSubTurnEnd,
runtimeevents.KindAgentSubTurnResultDelivered,
runtimeevents.KindAgentSubTurnOrphan,
)
done := make(chan struct{})
go func() {
defer close(done)
for evt := range sub.C {
for evt := range runtimeCh {
c.mu.Lock()
c.events = append(c.events, evt)
c.mu.Unlock()
}
}()
cleanup := func() {
al.UnsubscribeEvents(sub.ID)
closeRuntimeEvents()
<-done
}
return c, cleanup
}
func (c *eventCollector) hasEventOfKind(kind EventKind) bool {
func (c *eventCollector) hasEventOfKind(kind runtimeevents.Kind) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, e := range c.events {
@@ -131,7 +143,7 @@ func TestSpawnSubTurn(t *testing.T) {
agent: al.registry.GetDefaultAgent(),
}
// Subscribe to real EventBus to capture events
// Subscribe to runtime events to capture sub-turn lifecycle.
collector, collectCleanup := newEventCollector(t, al)
defer collectCleanup()
@@ -158,12 +170,12 @@ func TestSpawnSubTurn(t *testing.T) {
// Verify event emission
time.Sleep(10 * time.Millisecond) // let event goroutine flush
if tt.wantSpawn {
if !collector.hasEventOfKind(EventKindSubTurnSpawn) {
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnSpawn) {
t.Error("SubTurnSpawnEvent not emitted")
}
}
if tt.wantEnd {
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnEnd) {
t.Error("SubTurnEndEvent not emitted")
}
}
@@ -316,8 +328,8 @@ func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
time.Sleep(10 * time.Millisecond) // let event goroutine flush
// Verify Orphan event is emitted
if !collector.hasEventOfKind(EventKindSubTurnOrphan) {
t.Error("SubTurnOrphanResultEvent not emitted for finished parent")
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnOrphan) {
t.Error("agent.subturn.orphan not emitted for finished parent")
}
// Verify history is NOT polluted
@@ -591,12 +603,16 @@ func TestNestedSubTurnHierarchy(t *testing.T) {
var spawnedTurns []turnInfo
var mu sync.Mutex
// Subscribe to real EventBus to capture spawn events
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
16,
runtimeevents.KindAgentSubTurnSpawn,
)
defer closeRuntimeEvents()
go func() {
for evt := range sub.C {
if evt.Kind == EventKindSubTurnSpawn {
for evt := range runtimeCh {
if evt.Kind == runtimeevents.KindAgentSubTurnSpawn {
p, _ := evt.Payload.(SubTurnSpawnPayload)
mu.Lock()
spawnedTurns = append(spawnedTurns, turnInfo{
@@ -879,7 +895,7 @@ func TestSpawnSubTurn_PanicRecovery(t *testing.T) {
time.Sleep(10 * time.Millisecond) // let event goroutine flush
// SubTurnEndEvent should still be emitted
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnEnd) {
t.Error("SubTurnEndEvent not emitted after panic")
}
@@ -1229,18 +1245,23 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
al, _, _, _, cleanup := newTestAgentLoop(t) //nolint:dogsled
defer cleanup()
// Collect events via real EventBus
var mu sync.Mutex
var deliveredCount, orphanCount int
sub := al.SubscribeEvents(64)
defer al.UnsubscribeEvents(sub.ID)
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
64,
runtimeevents.KindAgentSubTurnResultDelivered,
runtimeevents.KindAgentSubTurnOrphan,
)
defer closeRuntimeEvents()
go func() {
for evt := range sub.C {
for evt := range runtimeCh {
mu.Lock()
switch evt.Kind {
case EventKindSubTurnResultDelivered:
case runtimeevents.KindAgentSubTurnResultDelivered:
deliveredCount++
case EventKindSubTurnOrphan:
case runtimeevents.KindAgentSubTurnOrphan:
orphanCount++
}
mu.Unlock()
@@ -1795,13 +1816,20 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
provider := &slowMockProvider{delay: 5 * time.Second} // SubTurn takes 5 seconds
al := NewAgentLoop(cfg, msgBus, provider)
// Capture events via real EventBus
var mu sync.Mutex
var events []Event
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
var events []runtimeevents.Event
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
t,
al,
32,
runtimeevents.KindAgentSubTurnSpawn,
runtimeevents.KindAgentSubTurnEnd,
runtimeevents.KindAgentSubTurnResultDelivered,
runtimeevents.KindAgentSubTurnOrphan,
)
defer closeRuntimeEvents()
go func() {
for evt := range sub.C {
for evt := range runtimeCh {
mu.Lock()
events = append(events, evt)
mu.Unlock()
@@ -2097,3 +2125,206 @@ func TestSubTurn_IndependentContext(t *testing.T) {
t.Log("✓ SubTurn completed successfully (independent context)")
}
}
// ====================== TargetAgentID Tests ======================
// modelRecordingProvider captures the model passed to Chat for test assertions.
type modelRecordingProvider struct {
mu sync.Mutex
lastModel string
}
func (rp *modelRecordingProvider) Chat(
_ context.Context,
_ []providers.Message,
_ []providers.ToolDefinition,
model string,
_ map[string]any,
) (*providers.LLMResponse, error) {
rp.mu.Lock()
rp.lastModel = model
rp.mu.Unlock()
return &providers.LLMResponse{Content: "Mock response"}, nil
}
func (rp *modelRecordingProvider) GetDefaultModel() string { return "mock-model" }
func (rp *modelRecordingProvider) getLastModel() string {
rp.mu.Lock()
defer rp.mu.Unlock()
return rp.lastModel
}
// newMultiAgentLoop creates an AgentLoop with two named agents for testing
// cross-agent delegation via TargetAgentID.
func newMultiAgentLoop(t *testing.T, provider providers.LLMProvider) (*AgentLoop, func()) {
t.Helper()
tmpDir, err := os.MkdirTemp("", "multiagent-test-*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
alphaDir := filepath.Join(tmpDir, "alpha")
betaDir := filepath.Join(tmpDir, "beta")
os.MkdirAll(alphaDir, 0o755)
os.MkdirAll(betaDir, 0o755)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "default-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
List: []config.AgentConfig{
{
ID: "alpha",
Workspace: alphaDir,
Model: &config.AgentModelConfig{Primary: "model-alpha"},
},
{
ID: "beta",
Workspace: betaDir,
Model: &config.AgentModelConfig{Primary: "model-beta"},
},
},
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
return al, func() { os.RemoveAll(tmpDir) }
}
func TestSpawnSubTurn_TargetAgentID_UsesTargetAgent(t *testing.T) {
rp := &modelRecordingProvider{}
al, cleanup := newMultiAgentLoop(t, rp)
defer cleanup()
alphaAgent, ok := al.registry.GetAgent("alpha")
if !ok {
t.Fatal("alpha agent not in registry")
}
// Parent is alpha, target is beta
parent := &turnState{
ctx: context.Background(),
turnID: "parent-alpha",
depth: 0,
childTurnIDs: []string{},
pendingResults: make(chan *tools.ToolResult, 4),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: &ephemeralSessionStore{},
agent: alphaAgent,
}
result, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
TargetAgentID: "beta",
SystemPrompt: "task for beta",
})
if err != nil {
t.Fatalf("spawnSubTurn failed: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
// The recording provider captures the model passed to Chat().
// If TargetAgentID works correctly, the child turn should have
// used beta's model, not alpha's.
if got := rp.getLastModel(); got != "model-beta" {
t.Errorf("child turn used model %q, want %q", got, "model-beta")
}
}
func TestSpawnSubTurn_TargetAgentID_NotFound(t *testing.T) {
al, cleanup := newMultiAgentLoop(t, &mockProvider{})
defer cleanup()
alphaAgent, _ := al.registry.GetAgent("alpha")
parent := &turnState{
ctx: context.Background(),
turnID: "parent-alpha",
depth: 0,
childTurnIDs: []string{},
pendingResults: make(chan *tools.ToolResult, 4),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: &ephemeralSessionStore{},
agent: alphaAgent,
}
_, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
TargetAgentID: "nonexistent",
SystemPrompt: "task",
})
if err == nil {
t.Fatal("expected error for nonexistent agent")
}
if !strings.Contains(err.Error(), "not found") {
t.Errorf("error should mention 'not found', got: %v", err)
}
}
func TestSpawnSubTurn_TargetAgentID_EmptyModelAccepted(t *testing.T) {
al, cleanup := newMultiAgentLoop(t, &mockProvider{})
defer cleanup()
alphaAgent, _ := al.registry.GetAgent("alpha")
parent := &turnState{
ctx: context.Background(),
turnID: "parent-alpha",
depth: 0,
childTurnIDs: []string{},
pendingResults: make(chan *tools.ToolResult, 4),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: &ephemeralSessionStore{},
agent: alphaAgent,
}
// Model is empty but TargetAgentID is set — should NOT fail validation
result, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
Model: "", // intentionally empty
TargetAgentID: "beta",
SystemPrompt: "task for beta",
})
if err != nil {
t.Fatalf("should accept empty Model when TargetAgentID is set, got: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
}
func TestDelegateToolNotRegistered_SingleAgent(t *testing.T) {
// Single-agent setup: delegate should not be registered
al, _, _, provider, cleanup := newTestAgentLoop(t)
_ = provider
defer cleanup()
agent := al.registry.GetDefaultAgent()
if agent == nil {
t.Fatal("default agent should exist")
}
if _, has := agent.Tools.Get("delegate"); has {
t.Error("delegate tool should not be registered in single-agent setup")
}
}
func TestDelegateToolRegistered_MultiAgent(t *testing.T) {
al, cleanup := newMultiAgentLoop(t, &mockProvider{})
defer cleanup()
// Both agents should have the delegate tool
for _, id := range []string{"alpha", "beta"} {
agent, ok := al.registry.GetAgent(id)
if !ok {
t.Fatalf("agent %q not found", id)
}
if _, has := agent.Tools.Get("delegate"); !has {
t.Errorf("agent %q should have delegate tool in multi-agent setup", id)
}
}
}
+1 -1
View File
@@ -61,7 +61,7 @@ func cloneStringMap(src map[string]string) map[string]string {
return cloned
}
func cloneEventMeta(meta EventMeta) EventMeta {
func cloneHookMeta(meta HookMeta) HookMeta {
meta.turnContext = cloneTurnContext(meta.turnContext)
return meta
}
+18 -8
View File
@@ -9,6 +9,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -25,10 +26,14 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
al.registerActiveTurn(ts)
defer al.clearActiveTurn(ts)
if al.takePendingStop(ts.sessionKey) {
_ = ts.requestHardAbort()
}
turnStatus := TurnEndStatusCompleted
defer func() {
al.emitEvent(
EventKindTurnEnd,
runtimeevents.KindAgentTurnEnd,
ts.eventMeta("runTurn", "turn.end"),
TurnEndPayload{
Status: turnStatus,
@@ -39,8 +44,13 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
)
}()
if ts.hardAbortRequested() {
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
}
al.emitEvent(
EventKindTurnStart,
runtimeevents.KindAgentTurnStart,
ts.eventMeta("runTurn", "turn.start"),
TurnStartPayload{
UserMessage: ts.userMessage,
@@ -140,7 +150,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
})
}
al.emitEvent(
EventKindSteeringInjected,
runtimeevents.KindAgentSteeringInjected,
ts.eventMeta("runTurn", "turn.steering.injected"),
SteeringInjectedPayload{
Count: len(pendingMessages),
@@ -249,7 +259,7 @@ func (al *AgentLoop) abortTurn(ts *turnState) (turnResult, error) {
if !ts.opts.NoHistory {
if err := ts.restoreSession(ts.agent); err != nil {
al.emitEvent(
EventKindError,
runtimeevents.KindAgentError,
ts.eventMeta("abortTurn", "turn.error"),
ErrorPayload{
Stage: "session_restore",
@@ -414,7 +424,7 @@ func (al *AgentLoop) askSideQuestion(
llmModel := activeModel
if al.hooks != nil {
llmReq, decision := al.hooks.BeforeLLM(ctx, &LLMHookRequest{
Meta: EventMeta{
Meta: HookMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.request",
turnContext: cloneTurnContext(turnCtx),
@@ -494,8 +504,8 @@ func (al *AgentLoop) askSideQuestion(
resp, err = callSideLLM(messages)
if err != nil && hasMediaRefs(messages) && isVisionUnsupportedError(err) {
al.emitEvent(
EventKindLLMRetry,
EventMeta{
runtimeevents.KindAgentLLMRetry,
HookMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.retry",
turnContext: cloneTurnContext(turnCtx),
@@ -521,7 +531,7 @@ func (al *AgentLoop) askSideQuestion(
// Apply after_llm hooks
if al.hooks != nil {
llmResp, decision := al.hooks.AfterLLM(ctx, &LLMHookResponse{
Meta: EventMeta{
Meta: HookMeta{
Source: "askSideQuestion",
TracePath: "turn.llm.response",
turnContext: cloneTurnContext(turnCtx),
+167
View File
@@ -135,6 +135,16 @@ func (p *errorProvider) Chat(
return nil, errors.New("context_length_exceeded")
case "vision":
return nil, errors.New("vision_unsupported")
case "connection_reset":
return nil, errors.New("connection reset by peer")
case "broken_pipe":
return nil, errors.New("broken pipe")
case "read_tcp":
return nil, errors.New("read tcp 127.0.0.1:8080: connection reset")
case "eof":
return nil, errors.New("EOF")
case "connection_refused":
return nil, errors.New("connection refused")
default:
return nil, errors.New("unknown error")
}
@@ -366,6 +376,163 @@ func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
t.Logf("CallLLM result after context error: err=%v", err)
}
func TestPipeline_CallLLM_NetworkErrorRetry(t *testing.T) {
testCases := []struct {
name string
errType string
}{
{"connection_reset", "connection_reset"},
{"broken_pipe", "broken_pipe"},
{"read_tcp", "read_tcp"},
{"eof", "eof"},
{"connection_refused", "connection_refused"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
errorPrv := &errorProvider{errType: tc.errType}
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
defer cleanup()
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err == nil {
t.Error("expected error after network error retries")
}
})
}
}
func TestPipeline_CallLLM_RetryConfigRespected(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
MaxLLMRetries: 3,
LLMRetryBackoffSecs: 1,
},
},
}
msgBus := bus.NewMessageBus()
provider := &errorProvider{errType: "connection_reset"}
al := NewAgentLoop(cfg, msgBus, provider)
defer al.Close()
agent := al.registry.GetDefaultAgent()
if agent == nil {
t.Fatal("expected default agent")
}
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
start := time.Now()
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
elapsed := time.Since(start)
if err == nil {
t.Error("expected error after retries")
}
expectedMinTime := 3 * time.Second
if elapsed < expectedMinTime {
t.Errorf("expected at least %v of backoff, got %v", expectedMinTime, elapsed)
}
}
func TestPipeline_CallLLM_RetryCountLimit(t *testing.T) {
tmpDir := t.TempDir()
counterPrv := &countingErrorProvider{errType: "connection_reset", targetCalls: 5}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
MaxLLMRetries: 2,
LLMRetryBackoffSecs: 0,
},
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, counterPrv)
defer al.Close()
agent := al.registry.GetDefaultAgent()
if agent == nil {
t.Fatal("expected default agent")
}
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err == nil {
t.Error("expected error after retries")
}
if counterPrv.callCount != 3 {
t.Errorf("expected exactly 3 calls (1 initial + 2 retries), got %d", counterPrv.callCount)
}
}
type countingErrorProvider struct {
errType string
targetCalls int
callCount int
mu sync.Mutex
}
func (p *countingErrorProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.callCount++
p.mu.Unlock()
return nil, errors.New("connection reset by peer")
}
func (p *countingErrorProvider) GetDefaultModel() string {
return "counting-error-model"
}
// =============================================================================
// Pipeline Method Tests: ExecuteTools
// =============================================================================
+6 -3
View File
@@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
// Bind session store and capture initial history length for rollback logic
if agent != nil && agent.Sessions != nil {
ts.session = agent.Sessions
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
ts.initialHistoryLength = len(history)
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
}
return ts
@@ -442,9 +445,9 @@ func (ts *turnState) hardAbortRequested() bool {
return ts.hardAbort
}
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
func (ts *turnState) eventMeta(source, tracePath string) HookMeta {
snap := ts.snapshot()
return EventMeta{
return HookMeta{
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
+4 -3
View File
@@ -82,7 +82,8 @@ Notes:
"model_list": [
{
"model_name": "elevenlabs-asr",
"model": "elevenlabs/scribe_v1"
"provider": "elevenlabs",
"model": "scribe_v1"
}
]
}
@@ -130,7 +131,7 @@ PicoClaw currently supports three main ASR routes:
| Route | Example models | Behavior |
| --- | --- | --- |
| ElevenLabs ASR | `elevenlabs/scribe_v1` | Uses the ElevenLabs transcription API. |
| ElevenLabs ASR | `provider: elevenlabs`, `model: scribe_v1` | Uses the ElevenLabs transcription API. |
| Whisper endpoint models | `openai/whisper-1`, `groq/whisper-large-v3` | Uses an OpenAI-compatible `/audio/transcriptions` endpoint. |
| Audio-capable chat models **(Under construction)** | `openai/gpt-4o-audio-preview`, `gemini/gemini-2.5-flash` | Sends audio to a multimodal chat model and asks it to transcribe. |
@@ -142,7 +143,7 @@ If you are unsure which one to pick, choose Groq Whisper or ElevenLabs first.
1. **Preferred path**: resolve `voice.model_name` against `model_list`.
2. If that resolved model is:
- `elevenlabs/...`, PicoClaw uses the ElevenLabs transcriber.
- an `elevenlabs` provider model, PicoClaw uses the ElevenLabs transcriber.
- an OpenAI-compatible Whisper model, PicoClaw uses the Whisper transcriber.
- an audio-capable chat model, PicoClaw uses `AudioModelTranscriber`.
3. **Fallback path**: if `voice.model_name` is not set, PicoClaw performs a compatibility scan through `model_list` for legacy auto-detected ASR entries.
+4 -3
View File
@@ -82,7 +82,8 @@ model_list:
"model_list": [
{
"model_name": "elevenlabs-asr",
"model": "elevenlabs/scribe_v1"
"provider": "elevenlabs",
"model": "scribe_v1"
}
]
}
@@ -130,7 +131,7 @@ PicoClaw 目前主要支持三种 ASR 路径:
| 路径 | 示例模型 | 行为说明 |
| --- | --- | --- |
| ElevenLabs ASR | `elevenlabs/scribe_v1` | 使用 ElevenLabs 的语音转录接口。 |
| ElevenLabs ASR | `provider: elevenlabs``model: scribe_v1` | 使用 ElevenLabs 的语音转录接口。 |
| Whisper 接口模型 | `openai/whisper-1``groq/whisper-large-v3` | 使用 OpenAI 兼容的 `/audio/transcriptions` 接口。 |
| 支持音频的聊天模型 **(重构中)** | `openai/gpt-4o-audio-preview``gemini/gemini-2.5-flash` | 把音频发给多模态聊天模型,并要求它返回转录结果。 |
@@ -142,7 +143,7 @@ PicoClaw 目前主要支持三种 ASR 路径:
1. **首选路径**:根据 `voice.model_name``model_list` 中找到对应模型。
2. 如果找到的模型属于以下类型:
- `elevenlabs/...`,则使用 ElevenLabs transcriber。
- `provider=elevenlabs` 的模型,则使用 ElevenLabs transcriber。
- OpenAI 兼容的 Whisper 模型,则使用 Whisper transcriber。
- 支持音频输入的聊天模型,则使用 `AudioModelTranscriber`
3. **回退路径**:如果没有设置 `voice.model_name`,PicoClaw 会为了兼容旧配置,扫描 `model_list` 中可自动识别的 ASR 条目。
+21 -6
View File
@@ -8,6 +8,12 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
const elevenLabsSupportedModelID = "scribe_v1"
func ElevenLabsSupportedModelID() string {
return elevenLabsSupportedModelID
}
type Transcriber interface {
Name() string
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
@@ -72,14 +78,23 @@ func whisperModelID(modelCfg *config.ModelConfig) string {
return ""
}
func isElevenLabsTranscriptionModel(modelCfg *config.ModelConfig) bool {
if modelCfg == nil || modelCfg.APIKey() == "" {
return false
}
protocol, _ := providers.ExtractProtocol(modelCfg)
return protocol == "elevenlabs"
}
func transcriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
if modelCfg == nil {
return nil
}
protocol, _ := providers.ExtractProtocol(modelCfg)
if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
if isElevenLabsTranscriptionModel(modelCfg) {
_, modelID := providers.ExtractProtocol(modelCfg)
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase, modelID)
}
if modelID := whisperModelID(modelCfg); modelID != "" {
return NewWhisperTranscriber(modelCfg)
@@ -95,9 +110,9 @@ func fallbackTranscriberFromModelConfig(modelCfg *config.ModelConfig) Transcribe
return nil
}
protocol, _ := providers.ExtractProtocol(modelCfg)
if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
if isElevenLabsTranscriptionModel(modelCfg) {
_, modelID := providers.ExtractProtocol(modelCfg)
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase, modelID)
}
if modelID := whisperModelID(modelCfg); modelID != "" {
return NewWhisperTranscriber(modelCfg)
+15
View File
@@ -46,6 +46,21 @@ func TestDetectTranscriber(t *testing.T) {
},
wantName: "elevenlabs",
},
{
name: "explicit elevenlabs provider selects elevenlabs transcriber",
cfg: &config.Config{
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
ModelList: []*config.ModelConfig{
{
ModelName: "my-asr-model",
Provider: "elevenlabs",
Model: "scribe_v1",
APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test"),
},
},
},
wantName: "elevenlabs",
},
{
name: "voice model name alias selects whisper transcriber for groq",
cfg: &config.Config{
+7 -2
View File
@@ -20,19 +20,24 @@ import (
type ElevenLabsTranscriber struct {
apiKey string
apiBase string
modelID string
httpClient *http.Client
}
func NewElevenLabsTranscriber(apiKey, apiBase string) *ElevenLabsTranscriber {
func NewElevenLabsTranscriber(apiKey, apiBase, modelID string) *ElevenLabsTranscriber {
logger.DebugCF("voice", "Creating ElevenLabs transcriber", map[string]any{"has_api_key": apiKey != ""})
if apiBase == "" {
apiBase = "https://api.elevenlabs.io"
}
if modelID == "" || modelID != ElevenLabsSupportedModelID() {
modelID = ElevenLabsSupportedModelID()
}
return &ElevenLabsTranscriber{
apiKey: apiKey,
apiBase: apiBase,
modelID: modelID,
httpClient: &http.Client{
Timeout: 120 * time.Second,
},
@@ -74,7 +79,7 @@ func (t *ElevenLabsTranscriber) Transcribe(ctx context.Context, audioFilePath st
return nil, fmt.Errorf("failed to copy file content: %w", err)
}
if err = writer.WriteField("model_id", "scribe_v1"); err != nil {
if err = writer.WriteField("model_id", t.modelID); err != nil {
return nil, fmt.Errorf("failed to write model_id field: %w", err)
}
+81 -4
View File
@@ -3,10 +3,14 @@ package asr
import (
"context"
"encoding/json"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
@@ -14,7 +18,7 @@ import (
var _ Transcriber = (*ElevenLabsTranscriber)(nil)
func TestElevenLabsTranscriberName(t *testing.T) {
tr := NewElevenLabsTranscriber("sk_test", "")
tr := NewElevenLabsTranscriber("sk_test", "", "scribe_v1")
if got := tr.Name(); got != "elevenlabs" {
t.Errorf("Name() = %q, want %q", got, "elevenlabs")
}
@@ -35,6 +39,35 @@ func TestElevenLabsTranscribe(t *testing.T) {
if r.Header.Get("Xi-Api-Key") != "sk_test" {
t.Errorf("unexpected xi-api-key header: %s", r.Header.Get("Xi-Api-Key"))
}
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
t.Fatalf("ParseMediaType() error = %v", err)
}
if mediaType != "multipart/form-data" {
t.Fatalf("content-type = %q, want multipart/form-data", mediaType)
}
reader := multipart.NewReader(r.Body, params["boundary"])
var gotModelID string
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("NextPart() error = %v", err)
}
if part.FormName() != "model_id" {
continue
}
body, err := io.ReadAll(part)
if err != nil {
t.Fatalf("ReadAll(part) error = %v", err)
}
gotModelID = strings.TrimSpace(string(body))
}
if gotModelID != "scribe_v1" {
t.Fatalf("model_id = %q, want %q", gotModelID, "scribe_v1")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
Text: "hello from elevenlabs",
@@ -43,7 +76,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
}))
defer srv.Close()
tr := NewElevenLabsTranscriber("sk_test", "")
tr := NewElevenLabsTranscriber("sk_test", "", "scribe_v1")
tr.apiBase = srv.URL
resp, err := tr.Transcribe(context.Background(), audioPath)
@@ -64,7 +97,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
}))
defer srv.Close()
tr := NewElevenLabsTranscriber("sk_bad", "")
tr := NewElevenLabsTranscriber("sk_bad", "", "scribe_v1")
tr.apiBase = srv.URL
_, err := tr.Transcribe(context.Background(), audioPath)
@@ -74,10 +107,54 @@ func TestElevenLabsTranscribe(t *testing.T) {
})
t.Run("missing file", func(t *testing.T) {
tr := NewElevenLabsTranscriber("sk_test", "")
tr := NewElevenLabsTranscriber("sk_test", "", "scribe_v1")
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
})
t.Run("unsupported model falls back to scribe_v1", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
t.Fatalf("ParseMediaType() error = %v", err)
}
if mediaType != "multipart/form-data" {
t.Fatalf("content-type = %q, want multipart/form-data", mediaType)
}
reader := multipart.NewReader(r.Body, params["boundary"])
var gotModelID string
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("NextPart() error = %v", err)
}
if part.FormName() != "model_id" {
continue
}
body, err := io.ReadAll(part)
if err != nil {
t.Fatalf("ReadAll(part) error = %v", err)
}
gotModelID = strings.TrimSpace(string(body))
}
if gotModelID != "scribe_v1" {
t.Fatalf("model_id = %q, want runtime fallback to %q", gotModelID, "scribe_v1")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TranscriptionResponse{Text: "ok"})
}))
defer srv.Close()
tr := NewElevenLabsTranscriber("sk_test", "", "unsupported-model")
tr.apiBase = srv.URL
if _, err := tr.Transcribe(context.Background(), audioPath); err != nil {
t.Fatalf("Transcribe() error: %v", err)
}
})
}
+44 -5
View File
@@ -6,6 +6,7 @@ import (
"sync"
"sync/atomic"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -48,6 +49,13 @@ type MessageBus struct {
closed atomic.Bool
wg sync.WaitGroup
streamDelegate atomic.Value // stores StreamDelegate
eventPublisher atomic.Value // stores EventPublisher
}
// EventPublisher is the minimal runtime event publisher used by MessageBus.
type EventPublisher interface {
Publish(ctx context.Context, evt runtimeevents.Event) runtimeevents.PublishResult
PublishNonBlocking(evt runtimeevents.Event) runtimeevents.PublishResult
}
func NewMessageBus() *MessageBus {
@@ -92,9 +100,14 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
msg = NormalizeInboundMessage(msg)
if msg.Context.isZero() {
mb.publishFailure("inbound", runtimeScopeFromInboundContext(msg.Context), ErrMissingInboundContext)
return ErrMissingInboundContext
}
return publish(ctx, mb, mb.inbound, msg)
if err := publish(ctx, mb, mb.inbound, msg); err != nil {
mb.publishFailure("inbound", runtimeScopeFromInboundContext(msg.Context), err)
return err
}
return nil
}
func (mb *MessageBus) InboundChan() <-chan InboundMessage {
@@ -104,9 +117,14 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
msg = NormalizeOutboundMessage(msg)
if msg.Context.isZero() {
mb.publishFailure("outbound", runtimeScopeFromInboundContext(msg.Context), ErrMissingOutboundContext)
return ErrMissingOutboundContext
}
return publish(ctx, mb, mb.outbound, msg)
if err := publish(ctx, mb, mb.outbound, msg); err != nil {
mb.publishFailure("outbound", runtimeScopeFromInboundContext(msg.Context), err)
return err
}
return nil
}
func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
@@ -116,9 +134,14 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
msg = NormalizeOutboundMediaMessage(msg)
if msg.Context.isZero() {
mb.publishFailure("outbound_media", runtimeScopeFromInboundContext(msg.Context), ErrMissingOutboundMediaContext)
return ErrMissingOutboundMediaContext
}
return publish(ctx, mb, mb.outboundMedia, msg)
if err := publish(ctx, mb, mb.outboundMedia, msg); err != nil {
mb.publishFailure("outbound_media", runtimeScopeFromInboundContext(msg.Context), err)
return err
}
return nil
}
func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
@@ -126,7 +149,11 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
}
func (mb *MessageBus) PublishAudioChunk(ctx context.Context, chunk AudioChunk) error {
return publish(ctx, mb, mb.audioChunks, chunk)
if err := publish(ctx, mb, mb.audioChunks, chunk); err != nil {
mb.publishFailure("audio_chunk", runtimeScopeFromAudioChunk(chunk), err)
return err
}
return nil
}
func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
@@ -134,7 +161,11 @@ func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
}
func (mb *MessageBus) PublishVoiceControl(ctx context.Context, ctrl VoiceControl) error {
return publish(ctx, mb, mb.voiceControls, ctrl)
if err := publish(ctx, mb, mb.voiceControls, ctrl); err != nil {
mb.publishFailure("voice_control", runtimeScopeFromVoiceControl(ctrl), err)
return err
}
return nil
}
func (mb *MessageBus) VoiceControlsChan() <-chan VoiceControl {
@@ -146,6 +177,11 @@ func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) {
mb.streamDelegate.Store(d)
}
// SetEventPublisher registers a runtime event publisher for bus errors and lifecycle events.
func (mb *MessageBus) SetEventPublisher(p EventPublisher) {
mb.eventPublisher.Store(p)
}
// GetStreamer returns a Streamer for the given channel+chatID via the delegate.
func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (Streamer, bool) {
if d, ok := mb.streamDelegate.Load().(StreamDelegate); ok && d != nil {
@@ -156,6 +192,7 @@ func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (
func (mb *MessageBus) Close() {
mb.closeOnce.Do(func() {
mb.publishCloseEvent(runtimeevents.KindBusCloseStarted, 0)
// notify all blocked publishers to exit
close(mb.done)
@@ -195,6 +232,8 @@ func (mb *MessageBus) Close() {
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
"count": drained,
})
mb.publishCloseEvent(runtimeevents.KindBusCloseDrained, drained)
}
mb.publishCloseEvent(runtimeevents.KindBusCloseCompleted, drained)
})
}
+82
View File
@@ -5,6 +5,8 @@ import (
"sync"
"testing"
"time"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func TestPublishConsume(t *testing.T) {
@@ -171,6 +173,86 @@ func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) {
}
}
func TestMessageBusPublishesRuntimeFailureAndCloseEvents(t *testing.T) {
eventBus := runtimeevents.NewBus()
defer func() {
if err := eventBus.Close(); err != nil {
t.Errorf("event bus close failed: %v", err)
}
}()
_, eventsCh, err := eventBus.Channel().OfKind(
runtimeevents.KindBusPublishFailed,
runtimeevents.KindBusCloseStarted,
runtimeevents.KindBusCloseDrained,
runtimeevents.KindBusCloseCompleted,
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "bus-events", Buffer: 4})
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
mb := NewMessageBus()
mb.SetEventPublisher(eventBus)
if err := mb.PublishInbound(context.Background(), InboundMessage{}); err == nil {
t.Fatal("expected PublishInbound to fail")
}
failed := receiveBusRuntimeEvent(t, eventsCh)
if failed.Kind != runtimeevents.KindBusPublishFailed ||
failed.Source.Name != "inbound" ||
failed.Severity != runtimeevents.SeverityError {
t.Fatalf("publish failed event = %+v", failed)
}
if failed.Attrs["stream"] != "inbound" || failed.Attrs["error"] == "" {
t.Fatalf("publish failed attrs = %#v, want stream and error", failed.Attrs)
}
if err := mb.PublishOutbound(context.Background(), OutboundMessage{
Context: NewOutboundContext("telegram", "chat-1", ""),
Content: "queued",
}); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
mb.Close()
seen := map[runtimeevents.Kind]bool{}
var drainedAttrs map[string]any
for range 3 {
evt := receiveBusRuntimeEvent(t, eventsCh)
seen[evt.Kind] = true
if evt.Kind == runtimeevents.KindBusCloseDrained {
drainedAttrs = evt.Attrs
}
}
for _, kind := range []runtimeevents.Kind{
runtimeevents.KindBusCloseStarted,
runtimeevents.KindBusCloseDrained,
runtimeevents.KindBusCloseCompleted,
} {
if !seen[kind] {
t.Fatalf("missing %s event, seen=%v", kind, seen)
}
}
if drainedAttrs["drained"] != 1 {
t.Fatalf("bus close drained attrs = %#v, want drained count", drainedAttrs)
}
}
func receiveBusRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
t.Helper()
select {
case evt, ok := <-ch:
if !ok {
t.Fatal("runtime event channel closed before expected event")
}
return evt
case <-time.After(time.Second):
t.Fatal("timed out waiting for runtime event")
return runtimeevents.Event{}
}
}
func TestPublishOutboundSubscribe(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
+88
View File
@@ -0,0 +1,88 @@
package bus
import (
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
type busPublishFailedPayload struct {
Stream string `json:"stream"`
Error string `json:"error"`
}
type busClosePayload struct {
Drained int `json:"drained,omitempty"`
}
func (mb *MessageBus) publishFailure(stream string, scope runtimeevents.Scope, err error) {
if mb == nil || err == nil {
return
}
publisher, ok := mb.eventPublisher.Load().(EventPublisher)
if !ok || publisher == nil {
return
}
publisher.PublishNonBlocking(runtimeevents.Event{
Kind: runtimeevents.KindBusPublishFailed,
Source: runtimeevents.Source{Component: "bus", Name: stream},
Scope: scope,
Severity: runtimeevents.SeverityError,
Payload: busPublishFailedPayload{
Stream: stream,
Error: err.Error(),
},
Attrs: map[string]any{
"stream": stream,
"error": err.Error(),
},
})
}
func (mb *MessageBus) publishCloseEvent(kind runtimeevents.Kind, drained int) {
if mb == nil {
return
}
publisher, ok := mb.eventPublisher.Load().(EventPublisher)
if !ok || publisher == nil {
return
}
attrs := map[string]any{}
if drained > 0 {
attrs["drained"] = drained
}
publisher.PublishNonBlocking(runtimeevents.Event{
Kind: kind,
Source: runtimeevents.Source{Component: "bus"},
Severity: runtimeevents.SeverityInfo,
Payload: busClosePayload{Drained: drained},
Attrs: attrs,
})
}
func runtimeScopeFromInboundContext(ctx InboundContext) runtimeevents.Scope {
return runtimeevents.Scope{
Channel: ctx.Channel,
Account: ctx.Account,
ChatID: ctx.ChatID,
TopicID: ctx.TopicID,
SpaceID: ctx.SpaceID,
SpaceType: ctx.SpaceType,
ChatType: ctx.ChatType,
SenderID: ctx.SenderID,
MessageID: ctx.MessageID,
}
}
func runtimeScopeFromAudioChunk(chunk AudioChunk) runtimeevents.Scope {
return runtimeevents.Scope{
Channel: chunk.Channel,
ChatID: chunk.ChatID,
}
}
func runtimeScopeFromVoiceControl(ctrl VoiceControl) runtimeevents.Scope {
return runtimeevents.Scope{
ChatID: ctrl.ChatID,
}
}
+197
View File
@@ -0,0 +1,197 @@
package channels
import (
"github.com/sipeed/picoclaw/pkg/bus"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func channelTypeForEvent(m *Manager, channelName string) string {
if m == nil || m.config == nil {
return channelName
}
if bc := m.config.Channels.Get(channelName); bc != nil && bc.Type != "" {
return bc.Type
}
return channelName
}
func (m *Manager) publishChannelEvent(
kind runtimeevents.Kind,
channelName string,
scope runtimeevents.Scope,
severity runtimeevents.Severity,
payload any,
) {
if m == nil || m.runtimeEvents == nil {
return
}
if scope.Channel == "" {
scope.Channel = channelName
}
m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{
Kind: kind,
Source: runtimeevents.Source{Component: "channel", Name: channelName},
Scope: scope,
Severity: severity,
Payload: payload,
Attrs: channelEventAttrs(payload),
})
}
func channelEventAttrs(payload any) map[string]any {
switch payload := payload.(type) {
case ChannelLifecyclePayload:
attrs := map[string]any{}
setAttrString(attrs, "type", payload.Type)
setAttrString(attrs, "error", payload.Error)
return attrs
case ChannelOutboundPayload:
attrs := map[string]any{}
if payload.Media {
attrs["media"] = payload.Media
}
if payload.ContentLen > 0 {
attrs["content_len"] = payload.ContentLen
}
if len(payload.MessageIDs) > 0 {
attrs["message_ids_count"] = len(payload.MessageIDs)
}
setAttrString(attrs, "reply_to_message_id", payload.ReplyToMessageID)
setAttrString(attrs, "error", payload.Error)
if payload.Retries > 0 {
attrs["retries"] = payload.Retries
}
return attrs
default:
return nil
}
}
func setAttrString(attrs map[string]any, key, value string) {
if value != "" {
attrs[key] = value
}
}
func (m *Manager) publishOutboundSent(
channelName string,
msg bus.OutboundMessage,
messageIDs []string,
) {
m.publishChannelEvent(
runtimeevents.KindChannelMessageOutboundSent,
channelName,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityInfo,
ChannelOutboundPayload{
ContentLen: len([]rune(msg.Content)),
MessageIDs: append([]string(nil), messageIDs...),
ReplyToMessageID: msg.ReplyToMessageID,
},
)
}
func (m *Manager) publishOutboundQueued(
channelName string,
msg bus.OutboundMessage,
) {
m.publishChannelEvent(
runtimeevents.KindChannelMessageOutboundQueued,
channelName,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityInfo,
ChannelOutboundPayload{
ContentLen: len([]rune(msg.Content)),
ReplyToMessageID: msg.ReplyToMessageID,
},
)
}
func (m *Manager) publishOutboundFailed(
channelName string,
msg bus.OutboundMessage,
err error,
media bool,
) {
payload := ChannelOutboundPayload{
Media: media,
ContentLen: len([]rune(msg.Content)),
ReplyToMessageID: msg.ReplyToMessageID,
Retries: maxRetries,
}
if err != nil {
payload.Error = err.Error()
}
m.publishChannelEvent(
runtimeevents.KindChannelMessageOutboundFailed,
channelName,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityError,
payload,
)
}
func (m *Manager) publishOutboundMediaSent(
channelName string,
msg bus.OutboundMediaMessage,
messageIDs []string,
) {
m.publishChannelEvent(
runtimeevents.KindChannelMessageOutboundSent,
channelName,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityInfo,
ChannelOutboundPayload{
Media: true,
MessageIDs: append([]string(nil), messageIDs...),
},
)
}
func (m *Manager) publishOutboundMediaQueued(
channelName string,
msg bus.OutboundMediaMessage,
) {
m.publishChannelEvent(
runtimeevents.KindChannelMessageOutboundQueued,
channelName,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityInfo,
ChannelOutboundPayload{Media: true},
)
}
func (m *Manager) publishOutboundMediaFailed(
channelName string,
msg bus.OutboundMediaMessage,
err error,
) {
payload := ChannelOutboundPayload{
Media: true,
Retries: maxRetries,
}
if err != nil {
payload.Error = err.Error()
}
m.publishChannelEvent(
runtimeevents.KindChannelMessageOutboundFailed,
channelName,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityError,
payload,
)
}
func scopeFromOutboundContext(ctx bus.InboundContext) runtimeevents.Scope {
return runtimeevents.Scope{
Channel: ctx.Channel,
Account: ctx.Account,
ChatID: ctx.ChatID,
TopicID: ctx.TopicID,
SpaceID: ctx.SpaceID,
SpaceType: ctx.SpaceType,
ChatType: ctx.ChatType,
SenderID: ctx.SenderID,
MessageID: ctx.MessageID,
}
}
+56
View File
@@ -64,6 +64,62 @@ func extractJSONStringField(content, field string) string {
// Format: {"image_key": "img_xxx"}
func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") }
// extractPostImageKeys extracts all image_key values from a Feishu post (rich text)
// message. Post messages have nested arrays of elements where images appear as
// {"tag":"img","image_key":"img_xxx"}.
func extractPostImageKeys(rawContent string) []string {
if rawContent == "" {
return nil
}
var post map[string]json.RawMessage
if err := json.Unmarshal([]byte(rawContent), &post); err != nil {
return nil
}
var keys []string
seen := make(map[string]struct{})
collectFromRows := func(contentRaw json.RawMessage) {
var rows [][]map[string]any
if err := json.Unmarshal(contentRaw, &rows); err != nil {
return
}
for _, row := range rows {
for _, elem := range row {
if tag, _ := elem["tag"].(string); tag == "img" {
if ik, _ := elem["image_key"].(string); ik != "" {
if _, dup := seen[ik]; !dup {
seen[ik] = struct{}{}
keys = append(keys, ik)
}
}
}
}
}
}
// Flat format: {"title":"...", "content":[[...]]}
if contentRaw, ok := post["content"]; ok {
collectFromRows(contentRaw)
}
// Localized format: {"zh_cn": {"title":"...", "content":[[...]]}, ...}
for _, raw := range post {
var locale map[string]json.RawMessage
if err := json.Unmarshal(raw, &locale); err != nil {
continue
}
contentRaw, ok := locale["content"]
if !ok {
continue
}
collectFromRows(contentRaw)
}
return keys
}
// extractFileKey extracts the file_key from a Feishu file/audio message content JSON.
// Format: {"file_key": "file_xxx", "file_name": "...", ...}
func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") }
+94
View File
@@ -291,6 +291,100 @@ func TestStripMentionPlaceholders(t *testing.T) {
}
}
func TestExtractPostImageKeys(t *testing.T) {
tests := []struct {
name string
content string
want []string
}{
{
name: "empty content",
content: "",
want: nil,
},
{
name: "invalid JSON",
content: "not json",
want: nil,
},
{
name: "post with no images",
content: `{"zh_cn":{"title":"Title","content":[[{"tag":"text","text":"hello"}]]}}`,
want: nil,
},
{
name: "post with one image",
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_v3_001"}]]}}`,
want: []string{"img_v3_001"},
},
{
name: "post with multiple images",
content: `{"zh_cn":{"title":"","content":[[{"tag":"text","text":"see"},{"tag":"img","image_key":"img_001"}],[{"tag":"img","image_key":"img_002"}]]}}`,
want: []string{"img_001", "img_002"},
},
{
name: "post with text and image mixed in row",
content: `{"zh_cn":{"title":"","content":[[{"tag":"text","text":"hi"},{"tag":"img","image_key":"img_mix"}]]}}`,
want: []string{"img_mix"},
},
{
name: "en_us locale",
content: `{"en_us":{"title":"","content":[[{"tag":"img","image_key":"img_en"}]]}}`,
want: []string{"img_en"},
},
{
name: "multiple locales with distinct images",
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_zh"}]]},"en_us":{"title":"","content":[[{"tag":"img","image_key":"img_en"}]]}}`,
want: []string{"img_zh", "img_en"},
},
{
name: "duplicate image_key across locales is deduplicated",
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_same"}]]},"en_us":{"title":"","content":[[{"tag":"img","image_key":"img_same"}]]}}`,
want: []string{"img_same"},
},
{
name: "image with empty image_key",
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":""}]]}}`,
want: nil,
},
{
name: "flat format without locale wrapper",
content: `{"title":"","content":[[{"tag":"img","image_key":"img_v3_flat","width":1826,"height":338}],[{"tag":"text","text":" check this image","style":[]}]]}`,
want: []string{"img_v3_flat"},
},
{
name: "flat format multiple images",
content: `{"title":"","content":[[{"tag":"img","image_key":"img_flat_1"}],[{"tag":"img","image_key":"img_flat_2"},{"tag":"text","text":"desc"}]]}`,
want: []string{"img_flat_1", "img_flat_2"},
},
{
name: "flat format no images",
content: `{"title":"Test","content":[[{"tag":"text","text":"just text"}]]}`,
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractPostImageKeys(tt.content)
if len(got) != len(tt.want) {
t.Errorf("extractPostImageKeys() = %v, want %v", got, tt.want)
return
}
// Use set comparison to avoid map iteration order dependency
gotSet := make(map[string]bool, len(got))
for _, v := range got {
gotSet[v] = true
}
for _, v := range tt.want {
if !gotSet[v] {
t.Errorf("extractPostImageKeys() missing expected key %q; got %v", v, got)
}
}
})
}
}
func TestExtractCardImageKeys(t *testing.T) {
tests := []struct {
name string
+100 -24
View File
@@ -803,6 +803,14 @@ func (c *FeishuChannel) downloadInboundMedia(
refs = append(refs, ref)
}
case larkim.MsgTypePost:
for _, imageKey := range extractPostImageKeys(rawContent) {
ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
if ref != "" {
refs = append(refs, ref)
}
}
case larkim.MsgTypeInteractive:
// Extract and download images embedded in interactive cards
feishuKeys, _ := extractCardImageKeys(rawContent)
@@ -842,12 +850,41 @@ func (c *FeishuChannel) downloadInboundMedia(
// downloadResource downloads a message resource (image/file) from Feishu,
// writes it to the project media directory, and stores the reference in MediaStore.
// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension.
//
// For image resources, if the primary MessageResource.Get API fails (which
// requires im:message or im:message:readonly scope), a fallback to the
// Image.Get API (which requires im:resource scope) is attempted. This ensures
// image downloads succeed regardless of which permission the user has granted.
func (c *FeishuChannel) downloadResource(
ctx context.Context,
messageID, fileKey, resourceType, fallbackExt string,
store media.MediaStore,
scope string,
) string {
file, filename := c.fetchResourceData(ctx, messageID, fileKey, resourceType)
if file == nil {
return ""
}
if closer, ok := file.(io.Closer); ok {
defer closer.Close()
}
if filename == "" {
filename = fileKey
}
if filepath.Ext(filename) == "" && fallbackExt != "" {
filename += fallbackExt
}
return c.storeResourceFile(ctx, messageID, fileKey, filename, file, store, scope)
}
// fetchResourceData tries to download a resource from Feishu, first via
// MessageResource.Get, then falling back to Image.Get for image resources.
func (c *FeishuChannel) fetchResourceData(
ctx context.Context,
messageID, fileKey, resourceType string,
) (io.Reader, string) {
req := larkim.NewGetMessageResourceReqBuilder().
MessageId(messageID).
FileKey(fileKey).
@@ -855,41 +892,80 @@ func (c *FeishuChannel) downloadResource(
Build()
resp, err := c.client.Im.V1.MessageResource.Get(ctx, req)
if err == nil && resp.Success() && resp.File != nil {
return resp.File, resp.FileName
}
if err != nil {
logger.ErrorCF("feishu", "Failed to download resource", map[string]any{
logger.WarnCF("feishu", "MessageResource.Get failed", map[string]any{
"message_id": messageID,
"file_key": fileKey,
"error": err.Error(),
})
return ""
} else if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
logger.WarnCF("feishu", "MessageResource.Get api error", map[string]any{
"message_id": messageID,
"file_key": fileKey,
"code": resp.Code,
"msg": resp.Msg,
})
} else {
logger.WarnCF("feishu", "MessageResource.Get returned empty file body", map[string]any{
"message_id": messageID,
"file_key": fileKey,
})
}
if resourceType != "image" {
return nil, ""
}
return c.fetchImageDirect(ctx, fileKey)
}
// fetchImageDirect downloads an image using the Image.Get API
// (/open-apis/im/v1/images/:image_key), which requires the im:resource scope.
func (c *FeishuChannel) fetchImageDirect(ctx context.Context, imageKey string) (io.Reader, string) {
req := larkim.NewGetImageReqBuilder().
ImageKey(imageKey).
Build()
resp, err := c.client.Im.V1.Image.Get(ctx, req)
if err != nil {
logger.ErrorCF("feishu", "Image.Get fallback failed", map[string]any{
"image_key": imageKey,
"error": err.Error(),
})
return nil, ""
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
logger.ErrorCF("feishu", "Resource download api error", map[string]any{
"code": resp.Code,
"msg": resp.Msg,
logger.ErrorCF("feishu", "Image.Get fallback api error", map[string]any{
"image_key": imageKey,
"code": resp.Code,
"msg": resp.Msg,
})
return ""
return nil, ""
}
if resp.File == nil {
return ""
}
// Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body).
if closer, ok := resp.File.(io.Closer); ok {
defer closer.Close()
return nil, ""
}
filename := resp.FileName
if filename == "" {
filename = fileKey
}
// If filename still has no extension, append the fallback (like Telegram's ext parameter).
if filepath.Ext(filename) == "" && fallbackExt != "" {
filename += fallbackExt
}
logger.DebugCF("feishu", "Image downloaded via Image.Get fallback", map[string]any{
"image_key": imageKey,
})
return resp.File, resp.FileName
}
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
// storeResourceFile writes downloaded resource data to disk and registers it in the MediaStore.
func (c *FeishuChannel) storeResourceFile(
ctx context.Context,
messageID, fileKey, filename string,
file io.Reader,
store media.MediaStore,
scope string,
) string {
mediaDir := media.TempDir()
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
@@ -908,7 +984,7 @@ func (c *FeishuChannel) downloadResource(
return ""
}
if _, copyErr := io.Copy(out, resp.File); copyErr != nil {
if _, copyErr := io.Copy(out, file); copyErr != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{
@@ -943,8 +1019,8 @@ func appendMediaTags(content, messageType string, mediaRefs []string) string {
return content
}
// Don't append tags to JSON content (interactive cards) - would produce invalid JSON
if messageType == larkim.MsgTypeInteractive {
// Don't append tags to JSON content - would produce invalid JSON
if messageType == larkim.MsgTypeInteractive || messageType == larkim.MsgTypePost {
return content
}
+7
View File
@@ -180,6 +180,13 @@ func TestAppendMediaTags(t *testing.T) {
mediaRefs: []string{"ref1"},
want: `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`,
},
{
name: "post message with images returns content unchanged",
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_001"}]]}}`,
messageType: "post",
mediaRefs: []string{"ref1"},
want: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_001"}]]}}`,
},
}
for _, tt := range tests {
+27 -10
View File
@@ -61,7 +61,10 @@ func NewLINEChannel(
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
}
client, err := messaging_api.NewMessagingApiAPI(cfg.ChannelAccessToken.String())
client, err := messaging_api.NewMessagingApiAPI(
cfg.ChannelAccessToken.String(),
messaging_api.WithHTTPClient(&http.Client{Timeout: 30 * time.Second}),
)
if err != nil {
return nil, fmt.Errorf("failed to create LINE messaging client: %w", err)
}
@@ -456,7 +459,7 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok {
tokenEntry := entry.(replyTokenEntry)
if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge {
_, err := c.client.WithContext(ctx).ReplyMessage(&messaging_api.ReplyMessageRequest{
_, _, err := c.client.WithContext(ctx).ReplyMessageWithHttpInfo(&messaging_api.ReplyMessageRequest{
ReplyToken: tokenEntry.token,
Messages: []messaging_api.MessageInterface{&textMsg},
})
@@ -467,16 +470,18 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
})
return nil, nil
}
logger.DebugC("line", "Reply API failed, falling back to Push API")
logger.DebugCF("line", "Reply API failed, falling back to Push API", map[string]any{
"error": err.Error(),
})
}
}
// Fall back to Push API
_, err := c.client.WithContext(ctx).PushMessage(&messaging_api.PushMessageRequest{
resp, _, err := c.client.WithContext(ctx).PushMessageWithHttpInfo(&messaging_api.PushMessageRequest{
To: msg.ChatID,
Messages: []messaging_api.MessageInterface{&textMsg},
}, "")
return nil, err
return nil, classifySDKError(resp, err)
}
// SendMedia implements the channels.MediaSender interface.
@@ -502,11 +507,12 @@ func (c *LINEChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessag
}
textMsg := messaging_api.TextMessage{Text: caption}
if _, err := c.client.WithContext(ctx).PushMessage(&messaging_api.PushMessageRequest{
resp, _, err := c.client.WithContext(ctx).PushMessageWithHttpInfo(&messaging_api.PushMessageRequest{
To: msg.ChatID,
Messages: []messaging_api.MessageInterface{&textMsg},
}, ""); err != nil {
return nil, err
}, "")
if err != nil {
return nil, classifySDKError(resp, err)
}
}
@@ -558,13 +564,24 @@ func (c *LINEChannel) StartTyping(ctx context.Context, chatID string) (func(), e
return stop, nil
}
// classifySDKError maps an SDK HTTP response to the project's sentinel errors.
func classifySDKError(resp *http.Response, err error) error {
if err == nil {
return nil
}
if resp != nil {
return channels.ClassifySendError(resp.StatusCode, err)
}
return channels.ClassifyNetError(err)
}
// sendLoading sends a loading animation indicator to the chat.
func (c *LINEChannel) sendLoading(ctx context.Context, chatID string) error {
_, err := c.client.WithContext(ctx).ShowLoadingAnimation(&messaging_api.ShowLoadingAnimationRequest{
resp, _, err := c.client.WithContext(ctx).ShowLoadingAnimationWithHttpInfo(&messaging_api.ShowLoadingAnimationRequest{
ChatId: chatID,
LoadingSeconds: 60,
})
return err
return classifySDKError(resp, err)
}
// downloadContent downloads media content from the LINE content API.
+140 -1
View File
@@ -23,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
@@ -84,6 +85,7 @@ type Manager struct {
channels map[string]Channel
workers map[string]*channelWorker
bus *bus.MessageBus
runtimeEvents runtimeevents.Bus
config *config.Config
mediaStore media.MediaStore
dispatchTask *asyncTask
@@ -98,6 +100,32 @@ type Manager struct {
channelHashes map[string]string // channel name → config hash
}
// ManagerOption configures a channel Manager.
type ManagerOption func(*Manager)
// WithRuntimeEvents injects the runtime event bus used for channel observations.
func WithRuntimeEvents(eventBus runtimeevents.Bus) ManagerOption {
return func(m *Manager) {
m.runtimeEvents = eventBus
}
}
// ChannelLifecyclePayload describes channel lifecycle runtime events.
type ChannelLifecyclePayload struct {
Type string `json:"type,omitempty"`
Error string `json:"error,omitempty"`
}
// ChannelOutboundPayload describes channel outbound message runtime events.
type ChannelOutboundPayload struct {
Media bool `json:"media,omitempty"`
ContentLen int `json:"content_len,omitempty"`
MessageIDs []string `json:"message_ids,omitempty"`
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
Error string `json:"error,omitempty"`
Retries int `json:"retries,omitempty"`
}
type toolFeedbackMessageTracker interface {
RecordToolFeedbackMessage(chatID, messageID, content string)
ClearToolFeedbackMessage(chatID string)
@@ -192,6 +220,21 @@ func clearTrackedToolFeedbackMessage(
}
}
// DismissToolFeedback clears any tracked tool feedback animation for the
// given channel/chat. This is called when a turn ends without a final
// response (e.g., ResponseHandled tools) to stop orphaned animation goroutines.
// outboundCtx carries topic/thread info for channels that use scoped tracker
// keys (e.g., Telegram forum topics); may be nil for non-topic channels.
func (m *Manager) DismissToolFeedback(
ctx context.Context, channelName, chatID string, outboundCtx *bus.InboundContext,
) {
ch, ok := m.GetChannel(channelName)
if !ok {
return
}
dismissTrackedToolFeedbackMessage(ctx, ch, chatID, outboundCtx)
}
func prepareToolFeedbackMessageContent(ch Channel, content string) string {
prepared := strings.TrimSpace(content)
if prepared == "" {
@@ -409,7 +452,12 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
}
}
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
func NewManager(
cfg *config.Config,
messageBus *bus.MessageBus,
store media.MediaStore,
opts ...ManagerOption,
) (*Manager, error) {
m := &Manager{
channels: make(map[string]Channel),
workers: make(map[string]*channelWorker),
@@ -418,6 +466,11 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi
mediaStore: store,
channelHashes: make(map[string]string),
}
for _, opt := range opts {
if opt != nil {
opt(m)
}
}
// Register as streaming delegate so the agent loop can obtain streamers
messageBus.SetStreamDelegate(m)
@@ -542,6 +595,13 @@ func (m *Manager) initChannel(typeName, channelName string) {
setter.SetOwner(ch)
}
m.channels[channelName] = ch
m.publishChannelEvent(
runtimeevents.KindChannelLifecycleInitialized,
channelName,
runtimeevents.Scope{Channel: channelName},
runtimeevents.SeverityInfo,
ChannelLifecyclePayload{Type: typeName},
)
logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
"channel": channelName,
"type": typeName,
@@ -687,6 +747,13 @@ func (m *Manager) registerHTTPHandlersLocked() {
func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
if wh, ok := ch.(WebhookHandler); ok {
m.mux.Handle(wh.WebhookPath(), wh)
m.publishChannelEvent(
runtimeevents.KindChannelWebhookRegistered,
name,
runtimeevents.Scope{Channel: name},
runtimeevents.SeverityInfo,
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
)
logger.InfoCF("channels", "Webhook handler registered", map[string]any{
"channel": name,
"path": wh.WebhookPath(),
@@ -706,6 +773,13 @@ func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
func (m *Manager) unregisterChannelHTTPHandler(name string, ch Channel) {
if wh, ok := ch.(WebhookHandler); ok {
m.mux.Unhandle(wh.WebhookPath())
m.publishChannelEvent(
runtimeevents.KindChannelWebhookUnregistered,
name,
runtimeevents.Scope{Channel: name},
runtimeevents.SeverityInfo,
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
)
logger.InfoCF("channels", "Webhook handler unregistered", map[string]any{
"channel": name,
"path": wh.WebhookPath(),
@@ -744,6 +818,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
"channel": name,
"error": err.Error(),
})
m.publishChannelEvent(
runtimeevents.KindChannelLifecycleStartFailed,
name,
runtimeevents.Scope{Channel: name},
runtimeevents.SeverityError,
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name), Error: err.Error()},
)
failedStarts = append(failedStarts, fmt.Errorf("channel %s: %w", name, err))
failedNames = append(failedNames, name)
continue
@@ -759,6 +840,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
m.workers[name] = w
go m.runWorker(dispatchCtx, name, w)
go m.runMediaWorker(dispatchCtx, name, w)
m.publishChannelEvent(
runtimeevents.KindChannelLifecycleStarted,
name,
runtimeevents.Scope{Channel: name},
runtimeevents.SeverityInfo,
ChannelLifecyclePayload{Type: channelType},
)
}
if len(m.channels) > 0 && len(m.workers) == 0 {
@@ -895,7 +983,15 @@ func (m *Manager) StopAll(ctx context.Context) error {
"channel": name,
"error": err.Error(),
})
continue
}
m.publishChannelEvent(
runtimeevents.KindChannelLifecycleStopped,
name,
runtimeevents.Scope{Channel: name},
runtimeevents.SeverityInfo,
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
)
}
logger.InfoC("channels", "All channels stopped")
@@ -1005,11 +1101,23 @@ func (m *Manager) sendWithRetry(
// Rate limit: wait for token
if err := w.limiter.Wait(ctx); err != nil {
// ctx canceled, shutting down
m.publishChannelEvent(
runtimeevents.KindChannelRateLimited,
name,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityWarn,
ChannelOutboundPayload{
ContentLen: len([]rune(msg.Content)),
ReplyToMessageID: msg.ReplyToMessageID,
Error: err.Error(),
},
)
return nil, false
}
// Pre-send: stop typing and try to edit placeholder
if msgIDs, handled := m.preSend(ctx, name, msg, w.ch); handled {
m.publishOutboundSent(name, msg, msgIDs)
return msgIDs, true
}
@@ -1018,6 +1126,7 @@ func (m *Manager) sendWithRetry(
for attempt := 0; attempt <= maxRetries; attempt++ {
msgIDs, lastErr = w.ch.Send(ctx, msg)
if lastErr == nil {
m.publishOutboundSent(name, msg, msgIDs)
return msgIDs, true
}
@@ -1057,6 +1166,7 @@ func (m *Manager) sendWithRetry(
"error": lastErr.Error(),
"retries": maxRetries,
})
m.publishOutboundFailed(name, msg, lastErr, false)
return nil, false
}
@@ -1119,6 +1229,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
case w.queue <- msg:
m.publishOutboundQueued(outboundMessageChannel(msg), msg)
return true
case <-ctx.Done():
return false
@@ -1139,6 +1250,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
case w.mediaQueue <- msg:
m.publishOutboundMediaQueued(outboundMediaChannel(msg), msg)
return true
case <-ctx.Done():
return false
@@ -1188,6 +1300,16 @@ func (m *Manager) sendMediaWithRetry(
// Rate limit: wait for token
if err := w.limiter.Wait(ctx); err != nil {
m.publishChannelEvent(
runtimeevents.KindChannelRateLimited,
name,
scopeFromOutboundContext(msg.Context),
runtimeevents.SeverityWarn,
ChannelOutboundPayload{
Media: true,
Error: err.Error(),
},
)
return nil, err
}
@@ -1199,6 +1321,7 @@ func (m *Manager) sendMediaWithRetry(
for attempt := 0; attempt <= maxRetries; attempt++ {
msgIDs, lastErr = ms.SendMedia(ctx, msg)
if lastErr == nil {
m.publishOutboundMediaSent(name, msg, msgIDs)
return msgIDs, nil
}
@@ -1238,6 +1361,7 @@ func (m *Manager) sendMediaWithRetry(
"error": lastErr.Error(),
"retries": maxRetries,
})
m.publishOutboundMediaFailed(name, msg, lastErr)
return nil, lastErr
}
@@ -1375,6 +1499,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
"channel": name,
"error": err.Error(),
})
m.publishChannelEvent(
runtimeevents.KindChannelLifecycleStartFailed,
name,
runtimeevents.Scope{Channel: name},
runtimeevents.SeverityError,
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name), Error: err.Error()},
)
continue
}
// Lazily create worker only after channel starts successfully
@@ -1388,6 +1519,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
m.workers[name] = w
go m.runWorker(dispatchCtx, name, w)
go m.runMediaWorker(dispatchCtx, name, w)
m.publishChannelEvent(
runtimeevents.KindChannelLifecycleStarted,
name,
runtimeevents.Scope{Channel: name},
runtimeevents.SeverityInfo,
ChannelLifecyclePayload{Type: channelType},
)
deferFuncs = append(deferFuncs, func() {
m.RegisterChannel(name, channel)
})
@@ -1510,6 +1648,7 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
if wExists && w != nil {
select {
case w.queue <- msg:
m.publishOutboundQueued(channelName, msg)
return nil
case <-ctx.Done():
return ctx.Err()
+130
View File
@@ -14,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -242,6 +243,57 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
}
}
func TestStartAllPublishesLifecycleRuntimeEvents(t *testing.T) {
eventBus := runtimeevents.NewBus()
defer func() {
if err := eventBus.Close(); err != nil {
t.Errorf("event bus close failed: %v", err)
}
}()
_, eventsCh, err := eventBus.Channel().SubscribeChan(
t.Context(),
runtimeevents.SubscribeOptions{Name: "channel-lifecycle", Buffer: 4},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
m := newTestManager()
m.runtimeEvents = eventBus
m.config = &config.Config{Channels: config.ChannelsConfig{}}
m.channels["good"] = &mockChannel{}
m.channels["bad"] = &mockChannel{
startFn: func(_ context.Context) error { return errors.New("bad start") },
}
if err := m.StartAll(t.Context()); err != nil {
t.Fatalf("StartAll() error = %v", err)
}
t.Cleanup(func() {
stopCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := m.StopAll(stopCtx); err != nil {
t.Errorf("StopAll() error = %v", err)
}
})
events := []runtimeevents.Event{
receiveChannelRuntimeEvent(t, eventsCh),
receiveChannelRuntimeEvent(t, eventsCh),
}
seen := map[runtimeevents.Kind]runtimeevents.Event{}
for _, evt := range events {
seen[evt.Kind] = evt
}
if evt, ok := seen[runtimeevents.KindChannelLifecycleStarted]; !ok || evt.Scope.Channel != "good" {
t.Fatalf("missing started event for good channel: %+v", events)
}
if evt, ok := seen[runtimeevents.KindChannelLifecycleStartFailed]; !ok || evt.Scope.Channel != "bad" {
t.Fatalf("missing failed event for bad channel: %+v", events)
}
}
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
@@ -256,6 +308,21 @@ func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMes
return bus.NormalizeOutboundMediaMessage(msg)
}
func receiveChannelRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
t.Helper()
select {
case evt, ok := <-ch:
if !ok {
t.Fatal("runtime event channel closed before expected event")
}
return evt
case <-time.After(time.Second):
t.Fatal("timed out waiting for runtime event")
return runtimeevents.Event{}
}
}
func TestSendWithRetry_Success(t *testing.T) {
m := newTestManager()
var callCount int
@@ -280,6 +347,69 @@ func TestSendWithRetry_Success(t *testing.T) {
}
}
func TestSendWithRetryPublishesOutboundRuntimeEvents(t *testing.T) {
eventBus := runtimeevents.NewBus()
defer func() {
if err := eventBus.Close(); err != nil {
t.Errorf("event bus close failed: %v", err)
}
}()
_, eventsCh, err := eventBus.Channel().OfKind(
runtimeevents.KindChannelMessageOutboundSent,
runtimeevents.KindChannelMessageOutboundFailed,
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "channel-outbound", Buffer: 2})
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
m := newTestManager()
m.runtimeEvents = eventBus
successWorker := &channelWorker{
ch: &mockChannel{},
limiter: rate.NewLimiter(rate.Inf, 1),
}
m.sendWithRetry(
context.Background(),
"test",
successWorker,
testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat-1", Content: "hello"}),
)
sent := receiveChannelRuntimeEvent(t, eventsCh)
if sent.Kind != runtimeevents.KindChannelMessageOutboundSent || sent.Scope.ChatID != "chat-1" {
t.Fatalf("sent event = %+v", sent)
}
if sent.Attrs["content_len"] != 5 {
t.Fatalf("sent attrs = %#v, want content_len", sent.Attrs)
}
failWorker := &channelWorker{
ch: &mockChannel{
sendFn: func(context.Context, bus.OutboundMessage) error {
return fmt.Errorf("send failed: %w", ErrSendFailed)
},
},
limiter: rate.NewLimiter(rate.Inf, 1),
}
m.sendWithRetry(
context.Background(),
"test",
failWorker,
testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat-2", Content: "hello"}),
)
failed := receiveChannelRuntimeEvent(t, eventsCh)
if failed.Kind != runtimeevents.KindChannelMessageOutboundFailed || failed.Scope.ChatID != "chat-2" {
t.Fatalf("failed event = %+v", failed)
}
if failed.Severity != runtimeevents.SeverityError {
t.Fatalf("failed severity = %q", failed.Severity)
}
if failed.Attrs["error"] == "" || failed.Attrs["retries"] != maxRetries {
t.Fatalf("failed attrs = %#v, want error and retries", failed.Attrs)
}
}
func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
m := newTestManager()
var callCount int
+48 -5
View File
@@ -333,18 +333,33 @@ func TestIsThoughtPayload(t *testing.T) {
want bool
}{
{
name: "explicit thought bool",
name: "explicit thought kind",
payload: map[string]any{PayloadKeyKind: MessageKindThought},
want: true,
},
{
name: "thought kind ignores case and whitespace",
payload: map[string]any{PayloadKeyKind: " ThOuGhT "},
want: true,
},
{
name: "legacy thought bool remains supported for inbound compatibility",
payload: map[string]any{PayloadKeyThought: true},
want: true,
},
{
name: "thought false",
name: "legacy thought false",
payload: map[string]any{PayloadKeyThought: false},
want: false,
},
{
name: "thought string ignored",
payload: map[string]any{PayloadKeyThought: "true"},
name: "tool calls kind",
payload: map[string]any{PayloadKeyKind: MessageKindToolCalls},
want: false,
},
{
name: "non-string kind ignored",
payload: map[string]any{PayloadKeyKind: true},
want: false,
},
{
@@ -380,7 +395,7 @@ func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
Type: TypeMessageCreate,
Payload: map[string]any{
PayloadKeyContent: "internal reasoning",
PayloadKeyThought: true,
PayloadKeyKind: MessageKindThought,
},
})
@@ -390,3 +405,31 @@ func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
case <-time.After(150 * time.Millisecond):
}
}
func TestPicoClientChannel_HandleServerMessage_IgnoresLegacyThoughtBool(t *testing.T) {
mb := bus.NewMessageBus()
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
URL: "ws://localhost:8080/ws",
}, mb)
if err != nil {
t.Fatalf("NewPicoClientChannel() error = %v", err)
}
ch.ctx = context.Background()
pc := &picoConn{sessionID: "sess-thought-legacy"}
ch.handleServerMessage(pc, PicoMessage{
Type: TypeMessageCreate,
Payload: map[string]any{
PayloadKeyContent: "legacy internal reasoning",
PayloadKeyThought: true,
},
})
select {
case msg := <-mb.InboundChan():
t.Fatalf("expected no inbound publish for legacy thought payload, got %+v", msg)
case <-time.After(150 * time.Millisecond):
}
}
+10 -3
View File
@@ -323,10 +323,18 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
payload := map[string]any{
PayloadKeyContent: content,
PayloadKeyThought: isThought,
"message_id": msgID,
}
if isToolCalls {
switch {
case isThought:
payload[PayloadKeyKind] = MessageKindThought
// This field is kept solely for compatibility with legacy pico clients that
// do not yet support the newer "kind" field.
// DO NOT use it for any purpose other than legacy client compatibility.
payload[PayloadKeyThought] = true
case isToolCalls:
payload[PayloadKeyKind] = MessageKindToolCalls
if toolCalls, ok := picoToolCallsPayload(msg); ok {
payload[PayloadKeyToolCalls] = toolCalls
@@ -457,7 +465,6 @@ func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (strin
msgID := uuid.New().String()
outMsg := newMessage(TypeMessageCreate, map[string]any{
PayloadKeyContent: text,
PayloadKeyThought: false,
"message_id": msgID,
})
+43 -2
View File
@@ -131,8 +131,8 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
if got := payload[PayloadKeyContent]; got != "thinking trace" {
t.Fatalf("thought content = %#v, want %q", got, "thinking trace")
}
if got := payload[PayloadKeyThought]; got != true {
t.Fatalf("thought flag = %#v, want true", got)
if got := payload[PayloadKeyKind]; got != MessageKindThought {
t.Fatalf("thought kind = %#v, want %q", got, MessageKindThought)
}
if got := payload["message_id"]; got == "msg-progress" || got == nil || got == "" {
t.Fatalf("thought message_id = %#v, want new non-progress id", got)
@@ -193,6 +193,47 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
}
}
func TestSendPlaceholder_EmitsNormalMessageWithoutKind(t *testing.T) {
ch := newTestPicoChannel(t)
ch.bc.Placeholder.Enabled = true
if err := ch.Start(context.Background()); err != nil {
t.Fatalf("Start() error = %v", err)
}
defer ch.Stop(context.Background())
clientConn, received, cleanup := newTestPicoWebSocket(t)
defer cleanup()
ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"})
msgID, err := ch.SendPlaceholder(context.Background(), "pico:sess-1")
if err != nil {
t.Fatalf("SendPlaceholder() error = %v", err)
}
if msgID == "" {
t.Fatal("expected placeholder message id")
}
select {
case msg := <-received:
if msg.Type != TypeMessageCreate {
t.Fatalf("placeholder message type = %q, want %q", msg.Type, TypeMessageCreate)
}
payload := msg.Payload
if got := payload["message_id"]; got != msgID {
t.Fatalf("placeholder message_id = %#v, want %q", got, msgID)
}
if got := payload[PayloadKeyContent]; got != "Thinking..." {
t.Fatalf("placeholder content = %#v, want %q", got, "Thinking...")
}
if got, ok := payload[PayloadKeyKind]; ok {
t.Fatalf("placeholder kind = %#v, want absent", got)
}
case <-time.After(time.Second):
t.Fatal("expected placeholder message to be delivered")
}
}
func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) {
ch := newTestPicoChannel(t)
+11 -1
View File
@@ -1,6 +1,9 @@
package pico
import "time"
import (
"strings"
"time"
)
// Protocol message types.
const (
@@ -47,6 +50,13 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
}
func isThoughtPayload(payload map[string]any) bool {
kind, _ := payload[PayloadKeyKind].(string)
if strings.EqualFold(strings.TrimSpace(kind), MessageKindThought) {
return true
}
// Keep pico_client inbound-compatible with legacy servers that still send
// the pre-kind boolean thought marker.
thought, _ := payload[PayloadKeyThought].(bool)
return thought
}
+1
View File
@@ -8,6 +8,7 @@ func BuiltinDefinitions() []Definition {
return []Definition{
startCommand(),
helpCommand(),
stopCommand(),
showCommand(),
listCommand(),
useCommand(),
+56
View File
@@ -42,6 +42,9 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
if !strings.Contains(reply, "/list [models|channels|agents|skills|mcp]") {
t.Fatalf("/help reply missing /list usage, got %q", reply)
}
if !strings.Contains(reply, "/stop") {
t.Fatalf("/help reply missing /stop usage, got %q", reply)
}
if !strings.Contains(reply, "/use <skill> <message>") {
if !strings.Contains(reply, "/use <skill> [message]") {
t.Fatalf("/help reply missing /use usage, got %q", reply)
@@ -49,6 +52,59 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
}
}
func TestBuiltinStop_UsesRuntimeStopper(t *testing.T) {
rt := &Runtime{
StopActiveTurn: func() (StopResult, error) {
return StopResult{
Stopped: true,
TaskName: "sync the long running job",
}, nil
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/stop",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Task stopped. \"sync the long running job\" was canceled." {
t.Fatalf("/stop reply=%q", reply)
}
}
func TestBuiltinStop_NoActiveTask(t *testing.T) {
rt := &Runtime{
StopActiveTurn: func() (StopResult, error) {
return StopResult{}, nil
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/stop",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "No active task to stop." {
t.Fatalf("/stop reply=%q, want no-active message", reply)
}
}
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), nil)
+52
View File
@@ -0,0 +1,52 @@
package commands
import (
"context"
"fmt"
"strings"
)
func stopCommand() Definition {
return Definition{
Name: "stop",
Description: "Stop the current task",
Usage: "/stop",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.StopActiveTurn == nil {
return req.Reply(unavailableMsg)
}
result, err := rt.StopActiveTurn()
if err != nil {
return req.Reply("Failed to stop task: " + err.Error())
}
return req.Reply(FormatStopReply(result))
},
}
}
// FormatStopReply renders a user-facing reply for a stop request.
func FormatStopReply(result StopResult) string {
if !result.Stopped {
return "No active task to stop."
}
taskName := compactStopTaskName(result.TaskName)
if taskName == "" {
return "Task stopped. Current task was canceled."
}
return fmt.Sprintf("Task stopped. %q was canceled.", taskName)
}
func compactStopTaskName(taskName string) string {
taskName = strings.Join(strings.Fields(strings.TrimSpace(taskName)), " ")
if taskName == "" {
return ""
}
if len(taskName) > 80 {
return taskName[:77] + "..."
}
return taskName
}
+7
View File
@@ -36,6 +36,12 @@ type ContextStats struct {
MessageCount int
}
// StopResult describes the outcome of a stop request for the current session.
type StopResult struct {
Stopped bool
TaskName string
}
// Runtime provides runtime dependencies to command handlers. It is constructed
// per-request by the agent loop so that per-request state (like session scope)
// can coexist with long-lived callbacks (like GetModelInfo).
@@ -55,4 +61,5 @@ type Runtime struct {
SwitchChannel func(value string) error
ClearHistory func() error
ReloadConfig func() error
StopActiveTurn func() (StopResult, error)
}
+52 -39
View File
@@ -17,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
)
// rrCounter is a global counter for round-robin load balancing across models.
@@ -39,6 +40,7 @@ type Config struct {
Channels ChannelsConfig `json:"channel_list" yaml:"channel_list"`
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
Gateway GatewayConfig `json:"gateway" yaml:"-"`
Events EventsConfig `json:"events,omitempty" yaml:"-"`
Hooks HooksConfig `json:"hooks,omitempty" yaml:"-"`
Tools ToolsConfig `json:"tools" yaml:",inline"`
Heartbeat HeartbeatConfig `json:"heartbeat" yaml:"-"`
@@ -276,6 +278,8 @@ type AgentDefaults struct {
SplitOnMarker bool `json:"split_on_marker" env:"PICOCLAW_AGENTS_DEFAULTS_SPLIT_ON_MARKER"` // split messages on <|[SPLIT]|> marker
ContextManager string `json:"context_manager,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_MANAGER"`
ContextManagerConfig json.RawMessage `json:"context_manager_config,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_MANAGER_CONFIG"`
MaxLLMRetries int `json:"max_llm_retries,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_LLM_RETRIES"`
LLMRetryBackoffSecs int `json:"llm_retry_backoff_secs,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_LLM_RETRY_BACKOFF_SECS"`
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
@@ -553,12 +557,13 @@ type ModelConfig struct {
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
// Optional optimizations
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
ToolSchemaTransform string `json:"tool_schema_transform,omitempty"` // Optional tool schema compatibility transform (e.g. "simple")
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
@@ -595,6 +600,9 @@ func (c *ModelConfig) Validate() error {
if c.Model == "" {
return fmt.Errorf("model is required")
}
if _, err := providercommon.NormalizeToolSchemaTransform(c.ToolSchemaTransform); err != nil {
return err
}
return nil
}
@@ -823,6 +831,7 @@ type ToolsConfig struct {
ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"`
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"`
Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
@@ -1465,23 +1474,24 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
// Create a copy for the additional key
additionalEntry := &ModelConfig{
ModelName: expandedName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
APIKeys: SimpleSecureStrings(keys[i]),
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
isVirtual: true,
ModelName: expandedName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
APIKeys: SimpleSecureStrings(keys[i]),
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ToolSchemaTransform: m.ToolSchemaTransform,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
isVirtual: true,
}
expanded = append(expanded, additionalEntry)
fallbackNames = append(fallbackNames, expandedName)
@@ -1489,22 +1499,23 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
// Create the primary entry with first key and fallbacks
primaryEntry := &ModelConfig{
ModelName: originalName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
APIKeys: SimpleSecureStrings(keys[0]),
ModelName: originalName,
Provider: m.Provider,
Model: m.Model,
APIBase: m.APIBase,
Proxy: m.Proxy,
AuthMethod: m.AuthMethod,
ConnectMode: m.ConnectMode,
Workspace: m.Workspace,
RPM: m.RPM,
MaxTokensField: m.MaxTokensField,
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ToolSchemaTransform: m.ToolSchemaTransform,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
UserAgent: m.UserAgent,
APIKeys: SimpleSecureStrings(keys[0]),
}
// Prepend new fallbacks to existing ones
@@ -1548,6 +1559,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
return t.Message.Enabled
case "read_file":
return t.ReadFile.Enabled
case "serial":
return t.Serial.Enabled
case "spawn":
return t.Spawn.Enabled
case "spawn_status":
+30
View File
@@ -1998,6 +1998,36 @@ func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
}
}
func TestModelConfig_ToolSchemaTransformRoundTrip(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
cfg := &Config{
Version: CurrentVersion,
ModelList: []*ModelConfig{
{
ModelName: "test-model",
Model: "openai/test",
APIKeys: SimpleSecureStrings("sk-test"),
ToolSchemaTransform: "simple",
},
},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig error: %v", err)
}
loaded, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig error: %v", err)
}
if got := loaded.ModelList[0].ToolSchemaTransform; got != "simple" {
t.Fatalf("ToolSchemaTransform = %q, want %q", got, "simple")
}
}
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
cfg := DefaultConfig()
+9 -1
View File
@@ -39,7 +39,9 @@ func DefaultConfig() *Config {
MaxArgsLength: 300,
SeparateMessages: false,
},
SplitOnMarker: false,
SplitOnMarker: false,
MaxLLMRetries: 2,
LLMRetryBackoffSecs: 2,
},
},
Session: SessionConfig{
@@ -294,6 +296,9 @@ func DefaultConfig() *Config {
HotReload: false,
LogLevel: DefaultGatewayLogLevel,
},
Events: EventsConfig{
Logging: defaultEventLoggingConfig(),
},
Tools: ToolsConfig{
FilterSensitiveData: true,
FilterMinLength: 8,
@@ -435,6 +440,9 @@ func DefaultConfig() *Config {
Mode: ReadFileModeBytes,
MaxReadFileSize: 64 * 1024, // 64KB
},
Serial: ToolConfig{
Enabled: false, // Hardware tool - requires host serial ports
},
Spawn: ToolConfig{
Enabled: true,
},
+48
View File
@@ -0,0 +1,48 @@
package config
// EventsConfig groups runtime event configuration.
type EventsConfig struct {
Logging EventLoggingConfig `json:"logging,omitempty" envPrefix:"PICOCLAW_EVENTS_LOGGING_"`
}
// EventLoggingConfig controls centralized runtime event logging.
type EventLoggingConfig struct {
// Enabled controls whether runtime events are printed by the built-in logger.
Enabled bool `json:"enabled" env:"ENABLED"`
// Include contains exact event kinds or glob patterns such as "agent.*" or "*".
Include []string `json:"include,omitempty" env:"INCLUDE"`
// Exclude contains exact event kinds or glob patterns to suppress after Include matches.
Exclude []string `json:"exclude,omitempty" env:"EXCLUDE"`
// MinSeverity filters out events below the configured severity: debug, info, warn, or error.
MinSeverity string `json:"min_severity,omitempty" env:"MIN_SEVERITY"`
// IncludePayload adds the raw payload to logs. Leave disabled unless detailed diagnostics are needed.
IncludePayload bool `json:"include_payload,omitempty" env:"INCLUDE_PAYLOAD"`
}
// DefaultEventLoggingInclude keeps the pre-existing behavior where agent events
// are printed, while non-agent runtime events are published for subscribers only.
var DefaultEventLoggingInclude = []string{"agent.*"}
// EffectiveEventLoggingConfig returns a logging config with stable defaults.
func EffectiveEventLoggingConfig(cfg *Config) EventLoggingConfig {
if cfg == nil {
return defaultEventLoggingConfig()
}
out := cfg.Events.Logging
if out.MinSeverity == "" {
out.MinSeverity = "info"
}
if len(out.Include) == 0 {
out.Include = append([]string(nil), DefaultEventLoggingInclude...)
}
return out
}
func defaultEventLoggingConfig() EventLoggingConfig {
return EventLoggingConfig{
Enabled: true,
Include: append([]string(nil), DefaultEventLoggingInclude...),
MinSeverity: "info",
}
}
+103
View File
@@ -0,0 +1,103 @@
package config
import (
"os"
"path/filepath"
"reflect"
"testing"
)
func TestDefaultEventLoggingConfig(t *testing.T) {
cfg := DefaultConfig()
logCfg := EffectiveEventLoggingConfig(cfg)
if !logCfg.Enabled {
t.Fatal("default event logging should be enabled")
}
if !reflect.DeepEqual(logCfg.Include, []string{"agent.*"}) {
t.Fatalf("default include = %#v, want agent.*", logCfg.Include)
}
if logCfg.MinSeverity != "info" {
t.Fatalf("default min severity = %q, want info", logCfg.MinSeverity)
}
if logCfg.IncludePayload {
t.Fatal("default event logging should not include raw payloads")
}
}
func TestLoadConfigEventLoggingOverrides(t *testing.T) {
path := filepath.Join(t.TempDir(), "config.json")
data := []byte(`{
"version": 3,
"events": {
"logging": {
"enabled": false,
"include": ["gateway.*"],
"exclude": ["gateway.ready"],
"min_severity": "warn",
"include_payload": true
}
}
}`)
if err := os.WriteFile(path, data, 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := LoadConfig(path)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
logCfg := EffectiveEventLoggingConfig(cfg)
if logCfg.Enabled {
t.Fatal("loaded event logging enabled = true, want false")
}
if !reflect.DeepEqual(logCfg.Include, []string{"gateway.*"}) {
t.Fatalf("loaded include = %#v, want gateway.*", logCfg.Include)
}
if !reflect.DeepEqual(logCfg.Exclude, []string{"gateway.ready"}) {
t.Fatalf("loaded exclude = %#v, want gateway.ready", logCfg.Exclude)
}
if logCfg.MinSeverity != "warn" {
t.Fatalf("loaded min severity = %q, want warn", logCfg.MinSeverity)
}
if !logCfg.IncludePayload {
t.Fatal("loaded include_payload = false, want true")
}
}
func TestLoadConfigEventLoggingEnvOverrides(t *testing.T) {
t.Setenv("PICOCLAW_EVENTS_LOGGING_ENABLED", "false")
t.Setenv("PICOCLAW_EVENTS_LOGGING_INCLUDE", "gateway.*,channel.lifecycle.*")
t.Setenv("PICOCLAW_EVENTS_LOGGING_EXCLUDE", "gateway.ready")
t.Setenv("PICOCLAW_EVENTS_LOGGING_MIN_SEVERITY", "error")
t.Setenv("PICOCLAW_EVENTS_LOGGING_INCLUDE_PAYLOAD", "true")
path := filepath.Join(t.TempDir(), "config.json")
data := []byte(`{"version": 3}`)
if err := os.WriteFile(path, data, 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := LoadConfig(path)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
logCfg := EffectiveEventLoggingConfig(cfg)
if logCfg.Enabled {
t.Fatal("env enabled override = true, want false")
}
if !reflect.DeepEqual(logCfg.Include, []string{"gateway.*", "channel.lifecycle.*"}) {
t.Fatalf("env include = %#v, want gateway/channel lifecycle", logCfg.Include)
}
if !reflect.DeepEqual(logCfg.Exclude, []string{"gateway.ready"}) {
t.Fatalf("env exclude = %#v, want gateway.ready", logCfg.Exclude)
}
if logCfg.MinSeverity != "error" {
t.Fatalf("env min severity = %q, want error", logCfg.MinSeverity)
}
if !logCfg.IncludePayload {
t.Fatal("env include_payload = false, want true")
}
}
+18
View File
@@ -158,6 +158,15 @@ func TestModelConfig_Validate(t *testing.T) {
},
wantErr: false,
},
{
name: "valid tool schema transform",
config: ModelConfig{
ModelName: "test",
Model: "openai/gpt-4o",
ToolSchemaTransform: "simple",
},
wantErr: false,
},
{
name: "missing model_name",
config: ModelConfig{
@@ -177,6 +186,15 @@ func TestModelConfig_Validate(t *testing.T) {
config: ModelConfig{},
wantErr: true,
},
{
name: "invalid tool schema transform",
config: ModelConfig{
ModelName: "test",
Model: "openai/gpt-4o",
ToolSchemaTransform: "invalid",
},
wantErr: true,
},
}
for _, tt := range tests {
+16 -9
View File
@@ -187,15 +187,16 @@ func TestExpandMultiKeyModels_Deduplication(t *testing.T) {
func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
modelCfg := &ModelConfig{
ModelName: "gpt-4",
Provider: "openrouter",
Model: "openai/gpt-4o",
APIBase: "https://api.example.com",
Proxy: "http://proxy:8080",
RPM: 60,
MaxTokensField: "max_completion_tokens",
RequestTimeout: 30,
ThinkingLevel: "high",
ModelName: "gpt-4",
Provider: "openrouter",
Model: "openai/gpt-4o",
APIBase: "https://api.example.com",
Proxy: "http://proxy:8080",
RPM: 60,
MaxTokensField: "max_completion_tokens",
RequestTimeout: 30,
ThinkingLevel: "high",
ToolSchemaTransform: "simple",
}
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
models := []*ModelConfig{modelCfg}
@@ -225,6 +226,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
if primary.ThinkingLevel != "high" {
t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel)
}
if primary.ToolSchemaTransform != "simple" {
t.Errorf("expected tool_schema_transform preserved, got %q", primary.ToolSchemaTransform)
}
// Check additional entry also preserves fields
additional := result[0]
@@ -237,6 +241,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
if additional.RPM != 60 {
t.Errorf("expected additional rpm preserved, got %d", additional.RPM)
}
if additional.ToolSchemaTransform != "simple" {
t.Errorf("expected additional tool_schema_transform preserved, got %q", additional.ToolSchemaTransform)
}
}
func TestExpandMultiKeyModels_IsVirtualFlag(t *testing.T) {
+243
View File
@@ -0,0 +1,243 @@
package events
import (
"context"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"
)
var globalEventSeq atomic.Uint64
// Bus publishes runtime events and creates filtered channels.
type Bus interface {
Publish(ctx context.Context, evt Event) PublishResult
PublishNonBlocking(evt Event) PublishResult
Channel() EventChannel
Close() error
Stats() Stats
}
// PublishResult reports per-publish delivery outcomes.
type PublishResult struct {
Matched int
Delivered int
Dropped int
Blocked int
Closed bool
}
// EventBus is an in-process runtime event broadcaster.
type EventBus struct {
mu sync.RWMutex
subs map[uint64]*eventSubscription
orderedSubs []*eventSubscription
closed bool
nextSubID atomic.Uint64
published atomic.Uint64
matched atomic.Uint64
delivered atomic.Uint64
dropped atomic.Uint64
blocked atomic.Uint64
}
var _ Bus = (*EventBus)(nil)
// NewBus creates an in-process runtime event bus.
func NewBus() *EventBus {
return &EventBus{
subs: make(map[uint64]*eventSubscription),
}
}
// Publish broadcasts evt to subscriptions whose filters match it.
func (b *EventBus) Publish(ctx context.Context, evt Event) PublishResult {
return b.publish(ctx, evt, false)
}
// PublishNonBlocking broadcasts evt without waiting for subscriber queue capacity.
func (b *EventBus) PublishNonBlocking(evt Event) PublishResult {
return b.publish(context.Background(), evt, true)
}
func (b *EventBus) publish(ctx context.Context, evt Event, nonBlocking bool) PublishResult {
if b == nil {
return PublishResult{Closed: true}
}
if ctx == nil {
ctx = context.Background()
}
if evt.Time.IsZero() {
evt.Time = time.Now()
}
if evt.ID == "" {
evt.ID = nextEventID()
}
subs, closed := b.snapshotSubscribers()
if closed {
return PublishResult{Closed: true}
}
b.published.Add(1)
result := PublishResult{}
for _, sub := range subs {
if !matchesFilters(sub.filters, evt) {
continue
}
result.Matched++
b.matched.Add(1)
delivery := sub.enqueue(ctx, evt, nonBlocking)
if delivery.closed {
continue
}
result.Delivered += delivery.delivered
result.Dropped += delivery.dropped
result.Blocked += delivery.blocked
b.delivered.Add(uint64(delivery.delivered))
b.dropped.Add(uint64(delivery.dropped))
b.blocked.Add(uint64(delivery.blocked))
}
return result
}
// Channel returns the root event channel for this bus.
func (b *EventBus) Channel() EventChannel {
return eventChannel{bus: b}
}
// Close closes the bus and all active subscriptions.
func (b *EventBus) Close() error {
if b == nil {
return nil
}
b.mu.Lock()
if b.closed {
b.mu.Unlock()
return nil
}
b.closed = true
subs := b.orderedSubs
b.subs = nil
b.orderedSubs = nil
b.mu.Unlock()
for _, sub := range subs {
sub.closeInput()
}
return nil
}
// Stats returns a snapshot of bus and subscription counters.
func (b *EventBus) Stats() Stats {
if b == nil {
return Stats{Closed: true}
}
b.mu.RLock()
closed := b.closed
subs := b.orderedSubs
b.mu.RUnlock()
stats := Stats{
Published: b.published.Load(),
Matched: b.matched.Load(),
Delivered: b.delivered.Load(),
Dropped: b.dropped.Load(),
Blocked: b.blocked.Load(),
Closed: closed,
Subscribers: len(subs),
SubscriberStats: make([]SubscriberStats, 0, len(subs)),
}
for _, sub := range subs {
stats.SubscriberStats = append(stats.SubscriberStats, sub.Stats())
}
return stats
}
func (b *EventBus) subscribe(
ctx context.Context,
filters []Filter,
opts SubscribeOptions,
handler Handler,
once bool,
) (Subscription, error) {
if b == nil {
return nil, ErrBusClosed
}
id := b.nextSubID.Add(1)
sub := newSubscription(b, id, filters, opts, handler, once)
b.mu.Lock()
if b.closed {
b.mu.Unlock()
sub.closeInput()
return nil, ErrBusClosed
}
b.subs[id] = sub
b.rebuildOrderedSubscribersLocked()
b.mu.Unlock()
if handler != nil {
go sub.run(ctx)
}
sub.watchContext(ctx)
return sub, nil
}
func (b *EventBus) unsubscribe(id uint64) {
b.mu.Lock()
sub, ok := b.subs[id]
if ok {
delete(b.subs, id)
b.rebuildOrderedSubscribersLocked()
}
b.mu.Unlock()
if ok {
sub.closeInput()
}
}
func (b *EventBus) snapshotSubscribers() ([]*eventSubscription, bool) {
b.mu.RLock()
defer b.mu.RUnlock()
if b.closed {
return nil, true
}
return b.orderedSubs, false
}
func (b *EventBus) rebuildOrderedSubscribersLocked() {
subs := make([]*eventSubscription, 0, len(b.subs))
for _, sub := range b.subs {
subs = append(subs, sub)
}
sortSubscriptions(subs)
b.orderedSubs = subs
}
func sortSubscriptions(subs []*eventSubscription) {
sort.Slice(subs, func(i, j int) bool {
if subs[i].opts.Priority == subs[j].opts.Priority {
return subs[i].id < subs[j].id
}
return subs[i].opts.Priority > subs[j].opts.Priority
})
}
func nextEventID() string {
id := globalEventSeq.Add(1)
return "evt-" + strconv.FormatUint(id, 10)
}
+75
View File
@@ -0,0 +1,75 @@
package events
import "context"
// EventChannel is a filtered view over an EventBus.
type EventChannel interface {
Filter(filter Filter) EventChannel
OfKind(kinds ...Kind) EventChannel
KindPrefix(prefix string) EventChannel
Source(component string, names ...string) EventChannel
Scope(scope ScopeFilter) EventChannel
Subscribe(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error)
SubscribeChan(ctx context.Context, opts SubscribeOptions) (Subscription, <-chan Event, error)
SubscribeOnce(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error)
}
type eventChannel struct {
bus *EventBus
filters []Filter
}
// Filter returns a new EventChannel with filter appended.
func (c eventChannel) Filter(filter Filter) EventChannel {
filters := append([]Filter(nil), c.filters...)
if filter != nil {
filters = append(filters, filter)
}
return eventChannel{bus: c.bus, filters: filters}
}
// OfKind returns a new EventChannel matching any of kinds.
func (c eventChannel) OfKind(kinds ...Kind) EventChannel {
return c.Filter(MatchKind(kinds...))
}
// KindPrefix returns a new EventChannel matching events with the kind prefix.
func (c eventChannel) KindPrefix(prefix string) EventChannel {
return c.Filter(MatchKindPrefix(prefix))
}
// Source returns a new EventChannel matching source component and optional names.
func (c eventChannel) Source(component string, names ...string) EventChannel {
return c.Filter(MatchSource(component, names...))
}
// Scope returns a new EventChannel matching non-empty scope fields.
func (c eventChannel) Scope(scope ScopeFilter) EventChannel {
return c.Filter(MatchScope(scope))
}
// Subscribe registers handler for events matching this channel.
func (c eventChannel) Subscribe(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error) {
if handler == nil {
return nil, ErrNilHandler
}
return c.bus.subscribe(ctx, c.filters, opts, handler, false)
}
// SubscribeChan registers a channel subscription for events matching this channel.
func (c eventChannel) SubscribeChan(ctx context.Context, opts SubscribeOptions) (Subscription, <-chan Event, error) {
sub, err := c.bus.subscribe(ctx, c.filters, opts, nil, false)
if err != nil {
return nil, nil, err
}
return sub, sub.(*eventSubscription).ch, nil
}
// SubscribeOnce registers handler and closes the subscription after the first event.
func (c eventChannel) SubscribeOnce(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error) {
if handler == nil {
return nil, ErrNilHandler
}
return c.bus.subscribe(ctx, c.filters, opts, handler, true)
}
+3
View File
@@ -0,0 +1,3 @@
// Package events provides the process-local runtime event bus used to observe
// PicoClaw components without coupling them to agent-specific event envelopes.
package events
+254
View File
@@ -0,0 +1,254 @@
package events
import (
"context"
"testing"
"time"
)
func TestPublishDeliversToMatchingSubscriber(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
_, ch, err := bus.Channel().OfKind(KindAgentTurnStart).SubscribeChan(
context.Background(),
SubscribeOptions{Name: "turn-starts", Buffer: 1},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
unmatched := bus.Publish(context.Background(), Event{Kind: KindAgentTurnEnd})
if unmatched.Matched != 0 || unmatched.Delivered != 0 {
t.Fatalf("unmatched Publish = %+v, want no delivery", unmatched)
}
result := bus.Publish(context.Background(), Event{Kind: KindAgentTurnStart})
if result.Matched != 1 || result.Delivered != 1 || result.Dropped != 0 {
t.Fatalf("Publish = %+v, want one delivered event", result)
}
evt := receiveEvent(t, ch)
if evt.Kind != KindAgentTurnStart {
t.Fatalf("event kind = %q, want %q", evt.Kind, KindAgentTurnStart)
}
if evt.ID == "" {
t.Fatal("event ID is empty")
}
if evt.Time.IsZero() {
t.Fatal("event Time is zero")
}
}
func TestDropNewestIncrementsStats(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
sub, _, err := bus.Channel().SubscribeChan(
context.Background(),
SubscribeOptions{Name: "drop-newest", Buffer: 1, Backpressure: DropNewest},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
first := bus.Publish(context.Background(), Event{Kind: KindAgentTurnStart})
if first.Delivered != 1 || first.Dropped != 0 {
t.Fatalf("first Publish = %+v, want one delivered event", first)
}
second := bus.Publish(context.Background(), Event{Kind: KindAgentTurnEnd})
if second.Delivered != 0 || second.Dropped != 1 {
t.Fatalf("second Publish = %+v, want one dropped event", second)
}
if got := sub.Stats().Dropped; got != 1 {
t.Fatalf("subscription dropped = %d, want 1", got)
}
if got := bus.Stats().Dropped; got != 1 {
t.Fatalf("bus dropped = %d, want 1", got)
}
}
func TestDropOldestKeepsNewestEvent(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
sub, ch, err := bus.Channel().SubscribeChan(
context.Background(),
SubscribeOptions{Name: "drop-oldest", Buffer: 1, Backpressure: DropOldest},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
bus.Publish(context.Background(), Event{Kind: Kind("test.old"), Payload: "old"})
result := bus.Publish(context.Background(), Event{Kind: Kind("test.new"), Payload: "new"})
if result.Delivered != 1 || result.Dropped != 1 {
t.Fatalf("Publish = %+v, want replacement delivery", result)
}
evt := receiveEvent(t, ch)
if evt.Payload != "new" {
t.Fatalf("payload = %v, want new", evt.Payload)
}
if got := sub.Stats().Dropped; got != 1 {
t.Fatalf("subscription dropped = %d, want 1", got)
}
}
func TestBlockRespectsContext(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
_, _, err := bus.Channel().SubscribeChan(
context.Background(),
SubscribeOptions{Name: "block", Buffer: 1, Backpressure: Block},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
first := bus.Publish(context.Background(), Event{Kind: Kind("test.first")})
if first.Delivered != 1 {
t.Fatalf("first Publish = %+v, want one delivered event", first)
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
second := bus.Publish(ctx, Event{Kind: Kind("test.second")})
if second.Blocked != 1 || second.Dropped != 1 || second.Delivered != 0 {
t.Fatalf("second Publish = %+v, want one blocked drop", second)
}
}
func TestPublishNonBlockingDropsForFullBlockSubscriber(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
sub, _, err := bus.Channel().SubscribeChan(
context.Background(),
SubscribeOptions{Name: "block", Buffer: 1, Backpressure: Block},
)
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
first := bus.PublishNonBlocking(Event{Kind: Kind("test.first")})
if first.Delivered != 1 {
t.Fatalf("first PublishNonBlocking = %+v, want one delivered event", first)
}
resultCh := make(chan PublishResult, 1)
go func() {
resultCh <- bus.PublishNonBlocking(Event{Kind: Kind("test.second")})
}()
select {
case second := <-resultCh:
if second.Matched != 1 || second.Delivered != 0 || second.Dropped != 1 || second.Blocked != 0 {
t.Fatalf("second PublishNonBlocking = %+v, want non-blocking drop", second)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("PublishNonBlocking blocked on full Block subscriber")
}
if got := sub.Stats().Dropped; got != 1 {
t.Fatalf("subscription dropped = %d, want 1", got)
}
}
func TestStatsSubscribersKeepPriorityOrder(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
low, _, err := bus.Channel().SubscribeChan(
context.Background(),
SubscribeOptions{Name: "low", Priority: -1},
)
if err != nil {
t.Fatalf("SubscribeChan low failed: %v", err)
}
high, _, err := bus.Channel().SubscribeChan(
context.Background(),
SubscribeOptions{Name: "high", Priority: 10},
)
if err != nil {
t.Fatalf("SubscribeChan high failed: %v", err)
}
peer, _, err := bus.Channel().SubscribeChan(
context.Background(),
SubscribeOptions{Name: "peer", Priority: 10},
)
if err != nil {
t.Fatalf("SubscribeChan peer failed: %v", err)
}
stats := bus.Stats()
got := []string{
stats.SubscriberStats[0].Name,
stats.SubscriberStats[1].Name,
stats.SubscriberStats[2].Name,
}
want := []string{"high", "peer", "low"}
if got[0] != want[0] || got[1] != want[1] || got[2] != want[2] {
t.Fatalf("subscriber order = %v, want %v", got, want)
}
if err := high.Close(); err != nil {
t.Fatalf("Close high failed: %v", err)
}
stats = bus.Stats()
got = []string{
stats.SubscriberStats[0].Name,
stats.SubscriberStats[1].Name,
}
want = []string{"peer", "low"}
if got[0] != want[0] || got[1] != want[1] {
t.Fatalf("subscriber order after unsubscribe = %v, want %v", got, want)
}
if err := peer.Close(); err != nil {
t.Fatalf("Close peer failed: %v", err)
}
if err := low.Close(); err != nil {
t.Fatalf("Close low failed: %v", err)
}
}
func receiveEvent(t *testing.T, ch <-chan Event) Event {
t.Helper()
select {
case evt, ok := <-ch:
if !ok {
t.Fatal("event channel closed before receive")
}
return evt
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
return Event{}
}
}
func closeBus(t *testing.T, bus *EventBus) {
t.Helper()
if err := bus.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
}
+131
View File
@@ -0,0 +1,131 @@
package events
import "strings"
// Filter decides whether an event should pass through an EventChannel.
type Filter func(Event) bool
// ScopeFilter matches selected non-empty fields against Event.Scope.
type ScopeFilter struct {
AgentID string
SessionKey string
TurnID string
Channel string
ChatID string
MessageID string
}
// MatchKind matches events whose kind is in kinds. Empty kinds match all events.
func MatchKind(kinds ...Kind) Filter {
if len(kinds) == 0 {
return matchAll
}
allowed := make(map[Kind]struct{}, len(kinds))
for _, kind := range kinds {
allowed[kind] = struct{}{}
}
return func(evt Event) bool {
_, ok := allowed[evt.Kind]
return ok
}
}
// MatchKindPrefix matches events whose kind starts with prefix.
func MatchKindPrefix(prefix string) Filter {
if prefix == "" {
return matchAll
}
return func(evt Event) bool {
return strings.HasPrefix(evt.Kind.String(), prefix)
}
}
// MatchSource matches events emitted by component and, optionally, one of names.
func MatchSource(component string, names ...string) Filter {
if component == "" && len(names) == 0 {
return matchAll
}
allowedNames := make(map[string]struct{}, len(names))
for _, name := range names {
allowedNames[name] = struct{}{}
}
return func(evt Event) bool {
if component != "" && evt.Source.Component != component {
return false
}
if len(allowedNames) == 0 {
return true
}
_, ok := allowedNames[evt.Source.Name]
return ok
}
}
// MatchScope matches events whose Scope contains all non-empty filter fields.
func MatchScope(scope ScopeFilter) Filter {
if scope == (ScopeFilter{}) {
return matchAll
}
return func(evt Event) bool {
return matchesString(scope.AgentID, evt.Scope.AgentID) &&
matchesString(scope.SessionKey, evt.Scope.SessionKey) &&
matchesString(scope.TurnID, evt.Scope.TurnID) &&
matchesString(scope.Channel, evt.Scope.Channel) &&
matchesString(scope.ChatID, evt.Scope.ChatID) &&
matchesString(scope.MessageID, evt.Scope.MessageID)
}
}
// And combines filters and short-circuits on the first non-match.
func And(filters ...Filter) Filter {
if len(filters) == 0 {
return matchAll
}
return func(evt Event) bool {
for _, filter := range filters {
if filter != nil && !filter(evt) {
return false
}
}
return true
}
}
// Or combines filters and short-circuits on the first match.
func Or(filters ...Filter) Filter {
if len(filters) == 0 {
return matchAll
}
return func(evt Event) bool {
for _, filter := range filters {
if filter == nil || filter(evt) {
return true
}
}
return false
}
}
func matchAll(Event) bool {
return true
}
func matchesString(want, got string) bool {
return want == "" || want == got
}
func matchesFilters(filters []Filter, evt Event) bool {
for _, filter := range filters {
if filter != nil && !filter(evt) {
return false
}
}
return true
}
+96
View File
@@ -0,0 +1,96 @@
package events
import "testing"
func TestFilterKindPrefix(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prefix string
event Event
want bool
}{
{
name: "matches agent prefix",
prefix: "agent.",
event: Event{Kind: KindAgentTurnStart},
want: true,
},
{
name: "rejects different prefix",
prefix: "channel.",
event: Event{Kind: KindAgentTurnStart},
want: false,
},
{
name: "empty prefix matches all",
prefix: "",
event: Event{Kind: KindAgentTurnStart},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := MatchKindPrefix(tt.prefix)(tt.event); got != tt.want {
t.Fatalf("MatchKindPrefix(%q) = %v, want %v", tt.prefix, got, tt.want)
}
})
}
}
func TestFilterScope(t *testing.T) {
t.Parallel()
evt := Event{
Scope: Scope{
AgentID: "agent-a",
SessionKey: "session-1",
TurnID: "turn-1",
Channel: "telegram",
ChatID: "chat-1",
MessageID: "msg-1",
},
}
tests := []struct {
name string
scope ScopeFilter
want bool
}{
{
name: "empty filter matches",
scope: ScopeFilter{},
want: true,
},
{
name: "matches selected fields",
scope: ScopeFilter{
AgentID: "agent-a",
ChatID: "chat-1",
},
want: true,
},
{
name: "rejects mismatched field",
scope: ScopeFilter{
AgentID: "agent-a",
MessageID: "msg-2",
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := MatchScope(tt.scope)(evt); got != tt.want {
t.Fatalf("MatchScope(%+v) = %v, want %v", tt.scope, got, tt.want)
}
})
}
}
+156
View File
@@ -0,0 +1,156 @@
package events
const (
// KindAgentTurnStart is emitted when an agent turn starts.
KindAgentTurnStart Kind = "agent.turn.start"
// KindAgentTurnEnd is emitted when an agent turn ends.
KindAgentTurnEnd Kind = "agent.turn.end"
// KindAgentLLMRequest is emitted before an LLM request.
KindAgentLLMRequest Kind = "agent.llm.request"
// KindAgentLLMDelta is emitted for streaming LLM deltas.
KindAgentLLMDelta Kind = "agent.llm.delta"
// KindAgentLLMResponse is emitted after an LLM response.
KindAgentLLMResponse Kind = "agent.llm.response"
// KindAgentLLMRetry is emitted before retrying an LLM request.
KindAgentLLMRetry Kind = "agent.llm.retry"
// KindAgentContextCompress is emitted when agent context is compressed.
KindAgentContextCompress Kind = "agent.context.compress"
// KindAgentSessionSummarize is emitted when session summarization completes.
KindAgentSessionSummarize Kind = "agent.session.summarize"
// KindAgentToolExecStart is emitted before a tool executes.
KindAgentToolExecStart Kind = "agent.tool.exec_start"
// KindAgentToolExecEnd is emitted after a tool finishes.
KindAgentToolExecEnd Kind = "agent.tool.exec_end"
// KindAgentToolExecSkipped is emitted when a tool call is skipped.
KindAgentToolExecSkipped Kind = "agent.tool.exec_skipped"
// KindAgentSteeringInjected is emitted when steering is injected into context.
KindAgentSteeringInjected Kind = "agent.steering.injected"
// KindAgentFollowUpQueued is emitted when async follow-up input is queued.
KindAgentFollowUpQueued Kind = "agent.follow_up.queued"
// KindAgentInterruptReceived is emitted when a turn interrupt is accepted.
KindAgentInterruptReceived Kind = "agent.interrupt.received"
// KindAgentSubTurnSpawn is emitted when a sub-turn is spawned.
KindAgentSubTurnSpawn Kind = "agent.subturn.spawn"
// KindAgentSubTurnEnd is emitted when a sub-turn ends.
KindAgentSubTurnEnd Kind = "agent.subturn.end"
// KindAgentSubTurnResultDelivered is emitted when a sub-turn result is delivered.
KindAgentSubTurnResultDelivered Kind = "agent.subturn.result_delivered"
// KindAgentSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
KindAgentSubTurnOrphan Kind = "agent.subturn.orphan"
// KindAgentError is emitted when agent execution reports an error.
KindAgentError Kind = "agent.error"
// KindChannelLifecycleStarted is emitted when a channel starts.
KindChannelLifecycleStarted Kind = "channel.lifecycle.started"
// KindChannelLifecycleInitialized is emitted when a channel is initialized.
KindChannelLifecycleInitialized Kind = "channel.lifecycle.initialized"
// KindChannelLifecycleStartFailed is emitted when a channel fails to start.
KindChannelLifecycleStartFailed Kind = "channel.lifecycle.start_failed"
// KindChannelLifecycleStopped is emitted when a channel stops.
KindChannelLifecycleStopped Kind = "channel.lifecycle.stopped"
// KindChannelWebhookRegistered is emitted when a channel webhook is registered.
KindChannelWebhookRegistered Kind = "channel.webhook.registered"
// KindChannelWebhookUnregistered is emitted when a channel webhook is unregistered.
KindChannelWebhookUnregistered Kind = "channel.webhook.unregistered"
// KindChannelMessageOutboundQueued is emitted when an outbound message is queued.
KindChannelMessageOutboundQueued Kind = "channel.message.outbound_queued"
// KindChannelMessageOutboundSent is emitted when an outbound channel message is sent.
KindChannelMessageOutboundSent Kind = "channel.message.outbound_sent"
// KindChannelMessageOutboundFailed is emitted when an outbound channel message fails.
KindChannelMessageOutboundFailed Kind = "channel.message.outbound_failed"
// KindChannelRateLimited is emitted when channel rate limiting blocks delivery.
KindChannelRateLimited Kind = "channel.rate_limited"
// KindBusPublishFailed is emitted when message bus publish fails.
KindBusPublishFailed Kind = "bus.publish.failed"
// KindBusCloseStarted is emitted when message bus close starts.
KindBusCloseStarted Kind = "bus.close.started"
// KindBusCloseCompleted is emitted when message bus close completes.
KindBusCloseCompleted Kind = "bus.close.completed"
// KindBusCloseDrained is emitted when message bus close drains buffered messages.
KindBusCloseDrained Kind = "bus.close.drained"
// KindGatewayStart is emitted when gateway startup reaches runtime bootstrap.
KindGatewayStart Kind = "gateway.start"
// KindGatewayReady is emitted when gateway services are started and ready.
KindGatewayReady Kind = "gateway.ready"
// KindGatewayShutdown is emitted when gateway shutdown starts.
KindGatewayShutdown Kind = "gateway.shutdown"
// KindGatewayReloadStarted is emitted when gateway reload starts.
KindGatewayReloadStarted Kind = "gateway.reload.started"
// KindGatewayReloadCompleted is emitted when gateway reload completes.
KindGatewayReloadCompleted Kind = "gateway.reload.completed"
// KindGatewayReloadFailed is emitted when gateway reload fails.
KindGatewayReloadFailed Kind = "gateway.reload.failed"
// KindMCPServerConnected is emitted when an MCP server connects.
KindMCPServerConnected Kind = "mcp.server.connected"
// KindMCPServerConnecting is emitted before connecting to an MCP server.
KindMCPServerConnecting Kind = "mcp.server.connecting"
// KindMCPServerFailed is emitted when an MCP server fails.
KindMCPServerFailed Kind = "mcp.server.failed"
// KindMCPToolDiscovered is emitted when an MCP tool is discovered.
KindMCPToolDiscovered Kind = "mcp.tool.discovered"
// KindMCPToolCallStart is emitted when an MCP tool call starts.
KindMCPToolCallStart Kind = "mcp.tool.call.start"
// KindMCPToolCallEnd is emitted when an MCP tool call ends.
KindMCPToolCallEnd Kind = "mcp.tool.call.end"
)
var knownKinds = []Kind{
KindAgentTurnStart,
KindAgentTurnEnd,
KindAgentLLMRequest,
KindAgentLLMDelta,
KindAgentLLMResponse,
KindAgentLLMRetry,
KindAgentContextCompress,
KindAgentSessionSummarize,
KindAgentToolExecStart,
KindAgentToolExecEnd,
KindAgentToolExecSkipped,
KindAgentSteeringInjected,
KindAgentFollowUpQueued,
KindAgentInterruptReceived,
KindAgentSubTurnSpawn,
KindAgentSubTurnEnd,
KindAgentSubTurnResultDelivered,
KindAgentSubTurnOrphan,
KindAgentError,
KindChannelLifecycleStarted,
KindChannelLifecycleInitialized,
KindChannelLifecycleStartFailed,
KindChannelLifecycleStopped,
KindChannelWebhookRegistered,
KindChannelWebhookUnregistered,
KindChannelMessageOutboundQueued,
KindChannelMessageOutboundSent,
KindChannelMessageOutboundFailed,
KindChannelRateLimited,
KindBusPublishFailed,
KindBusCloseStarted,
KindBusCloseCompleted,
KindBusCloseDrained,
KindGatewayStart,
KindGatewayReady,
KindGatewayShutdown,
KindGatewayReloadStarted,
KindGatewayReloadCompleted,
KindGatewayReloadFailed,
KindMCPServerConnected,
KindMCPServerConnecting,
KindMCPServerFailed,
KindMCPToolDiscovered,
KindMCPToolCallStart,
KindMCPToolCallEnd,
}
// KnownKinds returns the runtime event kinds declared by this package.
func KnownKinds() []Kind {
return append([]Kind(nil), knownKinds...)
}
+26
View File
@@ -0,0 +1,26 @@
package events
// Stats reports aggregate EventBus counters.
type Stats struct {
Published uint64
Matched uint64
Delivered uint64
Dropped uint64
Blocked uint64
Closed bool
Subscribers int
SubscriberStats []SubscriberStats
}
// SubscriberStats reports counters for one subscription.
type SubscriberStats struct {
ID uint64
Name string
Received uint64
Handled uint64
Failed uint64
Dropped uint64
Panicked uint64
TimedOut uint64
}
+459
View File
@@ -0,0 +1,459 @@
package events
import (
"context"
"errors"
"log"
"sync"
"sync/atomic"
"time"
)
const defaultSubscriberBuffer = 16
var (
// ErrBusClosed is returned when subscribing to a closed event bus.
ErrBusClosed = errors.New("events: bus is closed")
// ErrNilHandler is returned when subscribing without a handler.
ErrNilHandler = errors.New("events: handler is nil")
)
// Handler processes a runtime event delivered to a subscription.
type Handler func(context.Context, Event) error
// SubscribeOptions controls how a subscription receives events.
type SubscribeOptions struct {
Name string
Buffer int
Priority int
Concurrency ConcurrencyKind
Backpressure BackpressurePolicy
// Timeout bounds how long the subscription worker waits for one handler call.
// Handlers should still honor ctx cancellation; timed-out calls keep running
// until their handler returns.
Timeout time.Duration
PanicPolicy PanicPolicy
}
// ConcurrencyKind controls how handler subscriptions process queued events.
type ConcurrencyKind string
const (
// Concurrent processes each event in its own goroutine.
Concurrent ConcurrencyKind = "concurrent"
// Locked processes events sequentially in subscription order.
Locked ConcurrencyKind = "locked"
// Keyed is reserved for keyed sequential processing and currently behaves as Locked.
Keyed ConcurrencyKind = "keyed"
)
// BackpressurePolicy controls delivery when a subscription queue is full.
type BackpressurePolicy string
const (
// DropNewest drops the event being published when the queue is full.
DropNewest BackpressurePolicy = "drop_newest"
// DropOldest drops one queued event and enqueues the event being published.
DropOldest BackpressurePolicy = "drop_oldest"
// Block waits for queue capacity until Publish's context is canceled.
Block BackpressurePolicy = "block"
)
// PanicPolicy controls handler panic behavior.
type PanicPolicy string
const (
// RecoverAndLog recovers handler panics and records them in subscription stats.
RecoverAndLog PanicPolicy = "recover_and_log"
// Crash lets handler panics propagate from the worker goroutine.
Crash PanicPolicy = "crash"
)
// Subscription represents an active event subscription.
type Subscription interface {
ID() uint64
Name() string
Close() error
Done() <-chan struct{}
Stats() SubscriberStats
}
type subscriberCounters struct {
received atomic.Uint64
handled atomic.Uint64
failed atomic.Uint64
dropped atomic.Uint64
panicked atomic.Uint64
timedOut atomic.Uint64
}
type eventSubscription struct {
bus *EventBus
id uint64
name string
opts SubscribeOptions
filters []Filter
handler Handler
once bool
ch chan Event
done chan struct{}
closing chan struct{}
closeOnce sync.Once
doneOnce sync.Once
mu sync.RWMutex
closed bool
wg sync.WaitGroup
blockWG sync.WaitGroup
counters subscriberCounters
}
type handlerResult struct {
err error
panicked bool
}
func normalizeSubscribeOptions(opts SubscribeOptions) SubscribeOptions {
if opts.Buffer <= 0 {
opts.Buffer = defaultSubscriberBuffer
}
if opts.Concurrency == "" {
opts.Concurrency = Locked
}
if opts.Backpressure == "" {
opts.Backpressure = DropNewest
}
if opts.PanicPolicy == "" {
opts.PanicPolicy = RecoverAndLog
}
return opts
}
func newSubscription(
bus *EventBus,
id uint64,
filters []Filter,
opts SubscribeOptions,
handler Handler,
once bool,
) *eventSubscription {
opts = normalizeSubscribeOptions(opts)
return &eventSubscription{
bus: bus,
id: id,
name: opts.Name,
opts: opts,
filters: append([]Filter(nil), filters...),
handler: handler,
once: once,
ch: make(chan Event, opts.Buffer),
done: make(chan struct{}),
closing: make(chan struct{}),
}
}
// ID returns the subscription identifier.
func (s *eventSubscription) ID() uint64 {
if s == nil {
return 0
}
return s.id
}
// Name returns the subscription name.
func (s *eventSubscription) Name() string {
if s == nil {
return ""
}
return s.name
}
// Close removes the subscription and closes its delivery channel.
func (s *eventSubscription) Close() error {
if s == nil || s.bus == nil {
return nil
}
s.bus.unsubscribe(s.id)
return nil
}
// Done returns a channel closed after the subscription has stopped processing.
func (s *eventSubscription) Done() <-chan struct{} {
if s == nil {
ch := make(chan struct{})
close(ch)
return ch
}
return s.done
}
// Stats returns a snapshot of the subscription counters.
func (s *eventSubscription) Stats() SubscriberStats {
if s == nil {
return SubscriberStats{}
}
return SubscriberStats{
ID: s.id,
Name: s.name,
Received: s.counters.received.Load(),
Handled: s.counters.handled.Load(),
Failed: s.counters.failed.Load(),
Dropped: s.counters.dropped.Load(),
Panicked: s.counters.panicked.Load(),
TimedOut: s.counters.timedOut.Load(),
}
}
func (s *eventSubscription) run(ctx context.Context) {
defer func() {
s.wg.Wait()
s.closeDone()
}()
for evt := range s.ch {
s.dispatch(ctx, evt)
if s.once {
_ = s.Close()
return
}
}
}
func (s *eventSubscription) dispatch(ctx context.Context, evt Event) {
switch s.opts.Concurrency {
case Concurrent:
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handle(ctx, evt)
}()
case Keyed:
// TODO: replace this with keyed executors when runtime events need
// per-scope ordering with cross-scope concurrency.
s.handle(ctx, evt)
default:
s.handle(ctx, evt)
}
}
func (s *eventSubscription) handle(ctx context.Context, evt Event) {
if ctx == nil {
ctx = context.Background()
}
if s.opts.Timeout <= 0 {
s.recordHandlerResult(ctx, s.invokeHandler(ctx, evt))
return
}
ctx, cancel := context.WithTimeout(ctx, s.opts.Timeout)
defer cancel()
done := make(chan handlerResult, 1)
go func() {
done <- s.invokeHandler(ctx, evt)
}()
select {
case result := <-done:
s.recordHandlerResult(ctx, result)
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
s.counters.timedOut.Add(1)
}
s.counters.failed.Add(1)
}
}
func (s *eventSubscription) invokeHandler(ctx context.Context, evt Event) (result handlerResult) {
if s.opts.PanicPolicy != Crash {
defer func() {
if recovered := recover(); recovered != nil {
s.counters.panicked.Add(1)
result.panicked = true
log.Printf("events: subscriber %q recovered panic: %v", s.name, recovered)
}
}()
}
result.err = s.handler(ctx, evt)
return result
}
func (s *eventSubscription) recordHandlerResult(ctx context.Context, result handlerResult) {
if result.panicked {
return
}
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
s.counters.timedOut.Add(1)
}
if result.err != nil {
s.counters.failed.Add(1)
return
}
s.counters.handled.Add(1)
}
func (s *eventSubscription) watchContext(ctx context.Context) {
if ctx == nil {
return
}
go func() {
select {
case <-ctx.Done():
_ = s.Close()
case <-s.done:
}
}()
}
func (s *eventSubscription) closeInput() {
s.closeOnce.Do(func() {
close(s.closing)
s.mu.Lock()
s.closed = true
s.mu.Unlock()
s.blockWG.Wait()
s.mu.Lock()
close(s.ch)
s.mu.Unlock()
if s.handler == nil {
s.closeDone()
}
})
}
func (s *eventSubscription) closeDone() {
s.doneOnce.Do(func() {
close(s.done)
})
}
type deliveryResult struct {
delivered int
dropped int
blocked int
closed bool
}
func (s *eventSubscription) enqueue(ctx context.Context, evt Event, nonBlocking bool) deliveryResult {
if ctx == nil {
ctx = context.Background()
}
if nonBlocking {
return s.enqueueNonBlocking(evt)
}
if s.opts.Backpressure == Block {
return s.enqueueBlocking(ctx, evt)
}
s.mu.RLock()
defer s.mu.RUnlock()
if s.closed {
return deliveryResult{closed: true}
}
s.counters.received.Add(1)
switch s.opts.Backpressure {
case DropOldest:
return s.enqueueDropOldest(evt)
default:
return s.enqueueDropNewest(evt)
}
}
func (s *eventSubscription) enqueueBlocking(ctx context.Context, evt Event) deliveryResult {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return deliveryResult{closed: true}
}
s.blockWG.Add(1)
s.counters.received.Add(1)
s.mu.Unlock()
defer s.blockWG.Done()
return s.enqueueBlock(ctx, evt)
}
func (s *eventSubscription) enqueueNonBlocking(evt Event) deliveryResult {
s.mu.RLock()
defer s.mu.RUnlock()
if s.closed {
return deliveryResult{closed: true}
}
s.counters.received.Add(1)
if s.opts.Backpressure == DropOldest {
return s.enqueueDropOldest(evt)
}
return s.enqueueDropNewest(evt)
}
func (s *eventSubscription) enqueueDropNewest(evt Event) deliveryResult {
select {
case <-s.closing:
return deliveryResult{closed: true}
default:
}
select {
case s.ch <- evt:
return deliveryResult{delivered: 1}
default:
s.counters.dropped.Add(1)
return deliveryResult{dropped: 1}
}
}
func (s *eventSubscription) enqueueDropOldest(evt Event) deliveryResult {
select {
case <-s.closing:
return deliveryResult{closed: true}
default:
}
select {
case s.ch <- evt:
return deliveryResult{delivered: 1}
default:
}
dropped := 0
select {
case <-s.ch:
s.counters.dropped.Add(1)
dropped = 1
default:
}
select {
case <-s.closing:
return deliveryResult{dropped: dropped, closed: true}
case s.ch <- evt:
return deliveryResult{delivered: 1, dropped: dropped}
default:
s.counters.dropped.Add(1)
return deliveryResult{dropped: dropped + 1}
}
}
func (s *eventSubscription) enqueueBlock(ctx context.Context, evt Event) deliveryResult {
select {
case <-s.closing:
return deliveryResult{closed: true}
case s.ch <- evt:
return deliveryResult{delivered: 1}
case <-ctx.Done():
s.counters.dropped.Add(1)
return deliveryResult{dropped: 1, blocked: 1}
}
}
+254
View File
@@ -0,0 +1,254 @@
package events
import (
"context"
"sync/atomic"
"testing"
"time"
)
func TestSubscribeOnceClosesAfterFirstEvent(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
var handled atomic.Uint64
sub, err := bus.Channel().SubscribeOnce(
context.Background(),
SubscribeOptions{Name: "once", Buffer: 2},
func(context.Context, Event) error {
handled.Add(1)
return nil
},
)
if err != nil {
t.Fatalf("SubscribeOnce failed: %v", err)
}
bus.Publish(context.Background(), Event{Kind: KindAgentTurnStart})
waitForSubscriptionDone(t, sub)
bus.Publish(context.Background(), Event{Kind: KindAgentTurnEnd})
if got := handled.Load(); got != 1 {
t.Fatalf("handled = %d, want 1", got)
}
if got := sub.Stats().Handled; got != 1 {
t.Fatalf("subscription handled = %d, want 1", got)
}
}
func TestUnsubscribeClosesChannel(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
sub, ch, err := bus.Channel().SubscribeChan(context.Background(), SubscribeOptions{Name: "chan"})
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
if err := sub.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
select {
case _, ok := <-ch:
if ok {
t.Fatal("channel is open, want closed")
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for channel close")
}
waitForSubscriptionDone(t, sub)
}
func TestBlockBackpressureCloseUnblocksPublisher(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
sub, _, err := bus.Channel().SubscribeChan(context.Background(), SubscribeOptions{
Name: "block-close",
Buffer: 1,
Backpressure: Block,
})
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
first := bus.Publish(context.Background(), Event{Kind: Kind("test.first")})
if first.Delivered != 1 {
t.Fatalf("first Publish = %+v, want one delivered event", first)
}
publishStarted := make(chan struct{})
publishReturned := make(chan PublishResult, 1)
go func() {
close(publishStarted)
publishReturned <- bus.Publish(context.Background(), Event{Kind: Kind("test.second")})
}()
<-publishStarted
waitForStat(t, func() uint64 {
return sub.Stats().Received
}, 2)
select {
case result := <-publishReturned:
t.Fatalf("blocking Publish returned before close: %+v", result)
default:
}
closeReturned := make(chan error, 1)
go func() {
closeReturned <- sub.Close()
}()
select {
case err := <-closeReturned:
if err != nil {
t.Fatalf("Close failed: %v", err)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for Close to unblock")
}
select {
case <-publishReturned:
case <-time.After(time.Second):
t.Fatal("timed out waiting for blocking Publish to return after close")
}
waitForSubscriptionDone(t, sub)
}
func TestHandlerPanicRecovered(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
sub, err := bus.Channel().Subscribe(
context.Background(),
SubscribeOptions{Name: "panic", Buffer: 1},
func(context.Context, Event) error {
panic("boom")
},
)
if err != nil {
t.Fatalf("Subscribe failed: %v", err)
}
bus.Publish(context.Background(), Event{Kind: KindAgentError})
waitForStat(t, func() uint64 {
return sub.Stats().Panicked
}, 1)
}
func TestLockedHandlerProcessesSequentially(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
var active atomic.Int64
var maxActive atomic.Int64
sub, err := bus.Channel().Subscribe(
context.Background(),
SubscribeOptions{Name: "locked", Buffer: 8, Concurrency: Locked},
func(context.Context, Event) error {
current := active.Add(1)
for {
currentMax := maxActive.Load()
if current <= currentMax || maxActive.CompareAndSwap(currentMax, current) {
break
}
}
time.Sleep(10 * time.Millisecond)
active.Add(-1)
return nil
},
)
if err != nil {
t.Fatalf("Subscribe failed: %v", err)
}
for i := 0; i < 5; i++ {
bus.Publish(context.Background(), Event{Kind: KindAgentLLMDelta})
}
waitForStat(t, func() uint64 {
return sub.Stats().Handled
}, 5)
if got := maxActive.Load(); got != 1 {
t.Fatalf("max active handlers = %d, want 1", got)
}
}
func TestHandlerTimeoutDoesNotWedgeLockedSubscription(t *testing.T) {
t.Parallel()
bus := NewBus()
defer closeBus(t, bus)
releaseFirst := make(chan struct{})
defer close(releaseFirst)
var calls atomic.Uint64
sub, err := bus.Channel().Subscribe(
context.Background(),
SubscribeOptions{Name: "timeout", Buffer: 2, Concurrency: Locked, Timeout: 20 * time.Millisecond},
func(context.Context, Event) error {
if calls.Add(1) == 1 {
<-releaseFirst
}
return nil
},
)
if err != nil {
t.Fatalf("Subscribe failed: %v", err)
}
bus.Publish(context.Background(), Event{Kind: Kind("test.first")})
waitForStat(t, func() uint64 {
return sub.Stats().TimedOut
}, 1)
bus.Publish(context.Background(), Event{Kind: Kind("test.second")})
waitForStat(t, func() uint64 {
return sub.Stats().Handled
}, 1)
if got := sub.Stats().Failed; got != 1 {
t.Fatalf("subscription failed = %d, want timeout failure", got)
}
}
func waitForSubscriptionDone(t *testing.T, sub Subscription) {
t.Helper()
select {
case <-sub.Done():
case <-time.After(time.Second):
t.Fatal("timed out waiting for subscription to stop")
}
}
func waitForStat(t *testing.T, stat func() uint64, want uint64) {
t.Helper()
deadline := time.After(time.Second)
ticker := time.NewTicker(time.Millisecond)
defer ticker.Stop()
for {
if got := stat(); got >= want {
return
}
select {
case <-ticker.C:
case <-deadline:
t.Fatalf("timed out waiting for stat >= %d", want)
}
}
}
+77
View File
@@ -0,0 +1,77 @@
package events
import "time"
// Kind identifies a runtime event category.
type Kind string
// String returns the string representation of the event kind.
func (k Kind) String() string {
return string(k)
}
// Event is the runtime event envelope shared across PicoClaw components.
type Event struct {
ID string `json:"id"`
Kind Kind `json:"kind"`
Time time.Time `json:"time"`
Source Source `json:"source"`
Scope Scope `json:"scope,omitempty"`
Correlation Correlation `json:"correlation,omitempty"`
Severity Severity `json:"severity,omitempty"`
Payload any `json:"payload,omitempty"`
Attrs map[string]any `json:"attrs,omitempty"`
}
// Source identifies the component that emitted an event.
type Source struct {
Component string `json:"component"`
Name string `json:"name,omitempty"`
}
// Scope identifies the runtime ownership of an event.
//
// Scope is intentionally limited to agent, session, turn, channel, chat,
// message, and sender identity. Tool, provider, model, and MCP details belong
// in Source, Payload, or Attrs.
type Scope struct {
RuntimeID string `json:"runtime_id,omitempty"`
AgentID string `json:"agent_id,omitempty"`
SessionKey string `json:"session_key,omitempty"`
TurnID string `json:"turn_id,omitempty"`
Channel string `json:"channel,omitempty"`
Account string `json:"account,omitempty"`
ChatID string `json:"chat_id,omitempty"`
TopicID string `json:"topic_id,omitempty"`
SpaceID string `json:"space_id,omitempty"`
SpaceType string `json:"space_type,omitempty"`
ChatType string `json:"chat_type,omitempty"`
SenderID string `json:"sender_id,omitempty"`
MessageID string `json:"message_id,omitempty"`
}
// Correlation carries cross-event tracing fields.
type Correlation struct {
TraceID string `json:"trace_id,omitempty"`
ParentTurnID string `json:"parent_turn_id,omitempty"`
RequestID string `json:"request_id,omitempty"`
ReplyToID string `json:"reply_to_id,omitempty"`
}
// Severity describes the operational severity of an event.
type Severity string
const (
// SeverityDebug is used for verbose diagnostic events.
SeverityDebug Severity = "debug"
// SeverityInfo is used for normal lifecycle and activity events.
SeverityInfo Severity = "info"
// SeverityWarn is used for recoverable abnormal events.
SeverityWarn Severity = "warn"
// SeverityError is used for failed operations and unrecoverable events.
SeverityError Severity = "error"
)
+53
View File
@@ -0,0 +1,53 @@
package gateway
import (
"time"
"github.com/sipeed/picoclaw/pkg/agent"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
type gatewayEventPayload struct {
DurationMS int64 `json:"duration_ms,omitempty"`
Error string `json:"error,omitempty"`
}
func publishGatewayEvent(
al *agent.AgentLoop,
kind runtimeevents.Kind,
startedAt time.Time,
err error,
) {
if al == nil || al.RuntimeEventBus() == nil {
return
}
severity := runtimeevents.SeverityInfo
payload := gatewayEventPayload{}
if !startedAt.IsZero() {
payload.DurationMS = time.Since(startedAt).Milliseconds()
}
if err != nil {
severity = runtimeevents.SeverityError
payload.Error = err.Error()
}
al.RuntimeEventBus().PublishNonBlocking(runtimeevents.Event{
Kind: kind,
Source: runtimeevents.Source{Component: "gateway"},
Severity: severity,
Payload: payload,
Attrs: gatewayEventAttrs(payload),
})
}
func gatewayEventAttrs(payload gatewayEventPayload) map[string]any {
attrs := map[string]any{}
if payload.DurationMS > 0 {
attrs["duration_ms"] = payload.DurationMS
}
if payload.Error != "" {
attrs["error"] = payload.Error
}
return attrs
}
+31 -4
View File
@@ -39,6 +39,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/devices"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -114,6 +115,7 @@ func (p *startupBlockedProvider) GetDefaultModel() string {
// Run starts the gateway runtime using the configuration loaded from configPath.
func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runErr error) {
startedAt := time.Now()
panicPath := filepath.Join(homePath, logPath, panicFile)
panicFunc, err := logger.InitPanic(panicPath)
if err != nil {
@@ -197,6 +199,8 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
msgBus := bus.NewMessageBus()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
msgBus.SetEventPublisher(agentLoop.RuntimeEventBus())
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayStart, startedAt, nil)
fmt.Println("\n📦 Agent Status:")
startupInfo := agentLoop.GetStartupInfo()
@@ -216,6 +220,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
if err != nil {
return err
}
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReady, startedAt, nil)
closeListeners = false
// Setup manual reload channel for /reload endpoint
@@ -262,7 +267,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
select {
case <-sigChan:
logger.Info("Shutting down...")
shutdownGateway(runningServices, agentLoop, provider, true)
shutdownGateway(runningServices, agentLoop, provider, msgBus, true)
return nil
case newCfg := <-configReloadChan:
if !runningServices.reloading.CompareAndSwap(false, true) {
@@ -312,10 +317,20 @@ func executeReload(
msgBus *bus.MessageBus,
allowEmptyStartup bool,
debug bool,
) error {
) (err error) {
startedAt := time.Now()
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReloadStarted, startedAt, nil)
defer runningServices.reloading.Store(false)
defer func() {
if err != nil {
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReloadFailed, startedAt, err)
return
}
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReloadCompleted, startedAt, nil)
}()
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
err = handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
return err
}
func createStartupProvider(
@@ -383,7 +398,12 @@ func setupAndStartServices(
fms.Start()
}
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
runningServices.ChannelManager, err = channels.NewManager(
cfg,
msgBus,
runningServices.MediaStore,
channels.WithRuntimeEvents(agentLoop.RuntimeEventBus()),
)
if err != nil {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
@@ -490,14 +510,21 @@ func shutdownGateway(
runningServices *services,
agentLoop *agent.AgentLoop,
provider providers.LLMProvider,
msgBus *bus.MessageBus,
fullShutdown bool,
) {
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayShutdown, time.Time{}, nil)
if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown {
cp.Close()
}
stopAndCleanupServices(runningServices, gracefulShutdownTimeout, false)
if fullShutdown && msgBus != nil {
msgBus.Close()
}
agentLoop.Stop()
agentLoop.Close()
+103
View File
@@ -1,14 +1,20 @@
package gateway
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func TestRun_StartupFailuresReturnErrorAndEmitStructuredLog(t *testing.T) {
@@ -106,3 +112,100 @@ func TestGatewayRunStartupFailureHelper(t *testing.T) {
fmt.Fprintln(os.Stdout, err.Error())
os.Exit(0)
}
func TestPublishGatewayEvent(t *testing.T) {
eventBus := runtimeevents.NewBus()
t.Cleanup(func() {
if err := eventBus.Close(); err != nil {
t.Fatalf("Close runtime event bus: %v", err)
}
})
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
sub, eventsCh, err := eventBus.Channel().OfKind(runtimeevents.KindGatewayStart).SubscribeChan(
ctx,
runtimeevents.SubscribeOptions{Name: "gateway-test", Buffer: 4},
)
if err != nil {
t.Fatalf("SubscribeChan() error = %v", err)
}
t.Cleanup(func() {
if err := sub.Close(); err != nil {
t.Fatalf("Close subscription: %v", err)
}
})
al := agent.NewAgentLoop(
config.DefaultConfig(),
bus.NewMessageBus(),
&startupBlockedProvider{reason: "not used"},
agent.WithRuntimeEvents(eventBus),
)
t.Cleanup(al.Close)
startedAt := time.Now().Add(-1500 * time.Millisecond)
publishGatewayEvent(al, runtimeevents.KindGatewayStart, startedAt, nil)
evt := receiveGatewayRuntimeEvent(t, eventsCh)
if evt.Kind != runtimeevents.KindGatewayStart ||
evt.Source.Component != "gateway" ||
evt.Severity != runtimeevents.SeverityInfo {
t.Fatalf("gateway event = %+v", evt)
}
payload, ok := evt.Payload.(gatewayEventPayload)
if !ok {
t.Fatalf("payload type = %T, want gatewayEventPayload", evt.Payload)
}
if payload.DurationMS <= 0 {
t.Fatalf("DurationMS = %d, want positive", payload.DurationMS)
}
if evt.Attrs["duration_ms"] == nil {
t.Fatalf("gateway event attrs missing duration_ms: %#v", evt.Attrs)
}
}
func TestShutdownGatewayClosesMessageBus(t *testing.T) {
msgBus := bus.NewMessageBus()
al := agent.NewAgentLoop(
config.DefaultConfig(),
msgBus,
&startupBlockedProvider{reason: "not used"},
)
msgBus.SetEventPublisher(al.RuntimeEventBus())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sub, eventsCh, err := al.RuntimeEventBus().Channel().OfKind(runtimeevents.KindBusCloseCompleted).SubscribeChan(
ctx,
runtimeevents.SubscribeOptions{Name: "bus-close-test", Buffer: 4},
)
if err != nil {
t.Fatalf("SubscribeChan() error = %v", err)
}
defer func() {
_ = sub.Close()
}()
shutdownGateway(&services{}, al, &startupBlockedProvider{reason: "not used"}, msgBus, true)
evt := receiveGatewayRuntimeEvent(t, eventsCh)
if evt.Kind != runtimeevents.KindBusCloseCompleted {
t.Fatalf("shutdown event kind = %q, want %q", evt.Kind, runtimeevents.KindBusCloseCompleted)
}
if err := msgBus.PublishVoiceControl(context.Background(), bus.VoiceControl{}); !errors.Is(err, bus.ErrBusClosed) {
t.Fatalf("PublishVoiceControl after shutdown error = %v, want %v", err, bus.ErrBusClosed)
}
}
func receiveGatewayRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
t.Helper()
select {
case evt := <-ch:
return evt
case <-time.After(time.Second):
t.Fatal("timed out waiting for gateway runtime event")
return runtimeevents.Event{}
}
}
+92
View File
@@ -0,0 +1,92 @@
package mcp
import (
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func (m *Manager) publishServerEvent(
kind runtimeevents.Kind,
serverName string,
cfg config.MCPServerConfig,
toolCount int,
err error,
) {
if m == nil || m.runtimeEvents == nil {
return
}
severity := runtimeevents.SeverityInfo
if err != nil {
severity = runtimeevents.SeverityError
}
payload := ServerEventPayload{
Server: serverName,
Type: mcpTransportType(cfg),
URL: cfg.URL,
Command: cfg.Command,
ToolCount: toolCount,
}
if err != nil {
payload.Error = err.Error()
}
m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{
Kind: kind,
Source: runtimeevents.Source{Component: "mcp", Name: serverName},
Severity: severity,
Payload: payload,
Attrs: mcpServerEventAttrs(payload),
})
}
func (m *Manager) publishToolDiscovered(serverName string, cfg config.MCPServerConfig, toolName string) {
if m == nil || m.runtimeEvents == nil {
return
}
payload := ServerEventPayload{
Server: serverName,
Type: mcpTransportType(cfg),
URL: cfg.URL,
Command: cfg.Command,
Tool: toolName,
}
m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{
Kind: runtimeevents.KindMCPToolDiscovered,
Source: runtimeevents.Source{Component: "mcp", Name: serverName},
Severity: runtimeevents.SeverityInfo,
Payload: payload,
Attrs: mcpServerEventAttrs(payload),
})
}
func mcpServerEventAttrs(payload ServerEventPayload) map[string]any {
attrs := map[string]any{}
setMCPAttrString(attrs, "server", payload.Server)
setMCPAttrString(attrs, "type", payload.Type)
setMCPAttrString(attrs, "tool", payload.Tool)
if payload.ToolCount > 0 {
attrs["tool_count"] = payload.ToolCount
}
setMCPAttrString(attrs, "error", payload.Error)
return attrs
}
func setMCPAttrString(attrs map[string]any, key, value string) {
if value != "" {
attrs[key] = value
}
}
func mcpTransportType(cfg config.MCPServerConfig) string {
if cfg.Type != "" {
return cfg.Type
}
if cfg.URL != "" {
return "sse"
}
if cfg.Command != "" {
return "stdio"
}
return ""
}
+46 -6
View File
@@ -16,6 +16,7 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -127,19 +128,47 @@ type ServerConnection struct {
// Manager manages multiple MCP server connections
type Manager struct {
servers map[string]*ServerConnection
mu sync.RWMutex
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
wg sync.WaitGroup // tracks in-flight CallTool calls
servers map[string]*ServerConnection
runtimeEvents runtimeevents.Bus
mu sync.RWMutex
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
wg sync.WaitGroup // tracks in-flight CallTool calls
}
var connectServerFunc = connectServer
// ManagerOption configures an MCP manager.
type ManagerOption func(*Manager)
// WithRuntimeEvents injects the runtime event bus used for MCP observations.
func WithRuntimeEvents(eventBus runtimeevents.Bus) ManagerOption {
return func(m *Manager) {
m.runtimeEvents = eventBus
}
}
// ServerEventPayload describes MCP server connection events.
type ServerEventPayload struct {
Server string `json:"server"`
Type string `json:"type,omitempty"`
URL string `json:"url,omitempty"`
Command string `json:"command,omitempty"`
Tool string `json:"tool,omitempty"`
ToolCount int `json:"tool_count,omitempty"`
Error string `json:"error,omitempty"`
}
// NewManager creates a new MCP manager
func NewManager() *Manager {
return &Manager{
func NewManager(opts ...ManagerOption) *Manager {
m := &Manager{
servers: make(map[string]*ServerConnection),
}
for _, opt := range opts {
if opt != nil {
opt(m)
}
}
return m
}
// LoadFromConfig loads MCP servers from configuration
@@ -264,8 +293,10 @@ func (m *Manager) ConnectServer(
name string,
cfg config.MCPServerConfig,
) error {
m.publishServerEvent(runtimeevents.KindMCPServerConnecting, name, cfg, 0, nil)
conn, err := connectServerFunc(ctx, name, cfg)
if err != nil {
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, err)
return err
}
@@ -274,10 +305,19 @@ func (m *Manager) ConnectServer(
if m.closed.Load() {
_ = conn.Session.Close()
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, fmt.Errorf("manager is closed"))
return fmt.Errorf("manager is closed")
}
m.servers[name] = conn
for _, tool := range conn.Tools {
toolName := ""
if tool != nil {
toolName = tool.Name
}
m.publishToolDiscovered(name, cfg, toolName)
}
m.publishServerEvent(runtimeevents.KindMCPServerConnected, name, cfg, len(conn.Tools), nil)
return nil
}
+91
View File
@@ -10,11 +10,13 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/config"
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
)
func TestLoadEnvFile(t *testing.T) {
@@ -248,6 +250,95 @@ func TestNewManager_InitialState(t *testing.T) {
}
}
func TestConnectServerPublishesRuntimeEvents(t *testing.T) {
originalConnectServerFunc := connectServerFunc
t.Cleanup(func() {
connectServerFunc = originalConnectServerFunc
})
eventBus := runtimeevents.NewBus()
defer func() {
if err := eventBus.Close(); err != nil {
t.Errorf("event bus close failed: %v", err)
}
}()
_, eventsCh, err := eventBus.Channel().OfKind(
runtimeevents.KindMCPServerConnected,
runtimeevents.KindMCPServerFailed,
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "mcp-events", Buffer: 2})
if err != nil {
t.Fatalf("SubscribeChan failed: %v", err)
}
connectServerFunc = func(
_ context.Context,
name string,
cfg config.MCPServerConfig,
) (*ServerConnection, error) {
if name == "bad" {
return nil, fmt.Errorf("connect failed")
}
return &ServerConnection{
Name: name,
Config: cfg,
Tools: []*sdkmcp.Tool{{Name: "echo"}},
}, nil
}
mgr := NewManager(WithRuntimeEvents(eventBus))
err = mgr.ConnectServer(context.Background(), "good", config.MCPServerConfig{
Type: "stdio",
Command: "echo",
})
if err != nil {
t.Fatalf("ConnectServer(good) error = %v", err)
}
connected := receiveMCPRuntimeEvent(t, eventsCh)
if connected.Kind != runtimeevents.KindMCPServerConnected ||
connected.Source.Name != "good" ||
connected.Severity != runtimeevents.SeverityInfo {
t.Fatalf("connected event = %+v", connected)
}
if connected.Attrs["server"] != "good" ||
connected.Attrs["type"] != "stdio" ||
connected.Attrs["tool_count"] != 1 {
t.Fatalf("connected attrs = %#v", connected.Attrs)
}
err = mgr.ConnectServer(context.Background(), "bad", config.MCPServerConfig{
Type: "stdio",
Command: "echo",
})
if err == nil {
t.Fatal("expected ConnectServer(bad) to fail")
}
failed := receiveMCPRuntimeEvent(t, eventsCh)
if failed.Kind != runtimeevents.KindMCPServerFailed ||
failed.Source.Name != "bad" ||
failed.Severity != runtimeevents.SeverityError {
t.Fatalf("failed event = %+v", failed)
}
if failed.Attrs["server"] != "bad" || failed.Attrs["error"] != "connect failed" {
t.Fatalf("failed attrs = %#v", failed.Attrs)
}
}
func receiveMCPRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
t.Helper()
select {
case evt, ok := <-ch:
if !ok {
t.Fatal("runtime event channel closed before expected event")
}
return evt
case <-time.After(time.Second):
t.Fatal("timed out waiting for runtime event")
return runtimeevents.Event{}
}
}
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
mgr := NewManager()
+642
View File
@@ -0,0 +1,642 @@
package common
import (
"strconv"
"strings"
)
const maxGeminiSchemaDepth = 64
var geminiSupportedTypes = map[string]bool{
"array": true,
"boolean": true,
"integer": true,
"number": true,
"object": true,
"string": true,
}
// SanitizeSchemaForGoogle reduces a JSON Schema to the conservative subset
// accepted by Google/Gemini-style function declarations. It resolves local
// refs, collapses composition keywords like anyOf/oneOf/allOf, and strips
// advanced keywords that Gemini-compatible backends often reject.
func SanitizeSchemaForGoogle(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
sanitizer := geminiSchemaSanitizer{root: schema}
result := sanitizer.sanitizeNode(schema, nil, 0)
if len(result) == 0 {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
if _, hasProps := result["properties"]; hasProps {
result["type"] = "object"
}
return result
}
// SanitizeSchemaForGemini is kept as a compatibility alias for the original
// Google/Gemini sanitizer name.
func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
return SanitizeSchemaForGoogle(schema)
}
type geminiSchemaSanitizer struct {
root map[string]any
}
func (s geminiSchemaSanitizer) sanitizeNode(
node map[string]any,
refTrail map[string]struct{},
depth int,
) map[string]any {
if node == nil || depth > maxGeminiSchemaDepth {
return map[string]any{}
}
normalized := s.normalizeNode(node, refTrail, depth)
if len(normalized) == 0 {
return map[string]any{}
}
result := make(map[string]any)
if desc, ok := normalized["description"].(string); ok && strings.TrimSpace(desc) != "" {
result["description"] = desc
}
if schemaType := sanitizeGeminiSchemaType(normalized["type"]); schemaType != "" {
result["type"] = schemaType
}
if enumValues := sanitizeGeminiEnum(normalized["enum"]); len(enumValues) > 0 {
result["enum"] = enumValues
}
if propsRaw, ok := normalized["properties"].(map[string]any); ok {
props := make(map[string]any, len(propsRaw))
for name, rawProp := range propsRaw {
propSchema, ok := rawProp.(map[string]any)
if !ok {
continue
}
sanitizedProp := s.sanitizeNode(propSchema, refTrail, depth+1)
if len(sanitizedProp) == 0 {
sanitizedProp = map[string]any{}
}
props[name] = sanitizedProp
}
result["properties"] = props
result["type"] = "object"
if required := sanitizeGeminiRequired(normalized["required"], props); len(required) > 0 {
result["required"] = required
}
}
if itemsRaw, ok := normalized["items"].(map[string]any); ok {
items := s.sanitizeNode(itemsRaw, refTrail, depth+1)
if len(items) == 0 {
items = map[string]any{}
}
result["items"] = items
if _, hasType := result["type"]; !hasType {
result["type"] = "array"
}
}
return result
}
func (s geminiSchemaSanitizer) normalizeNode(
node map[string]any,
refTrail map[string]struct{},
depth int,
) map[string]any {
if node == nil || depth > maxGeminiSchemaDepth {
return map[string]any{}
}
normalized := cloneGeminiSchemaMap(node)
if ref, ok := normalized["$ref"].(string); ok {
delete(normalized, "$ref")
if _, seen := refTrail[ref]; !seen {
if target, ok := s.resolveLocalSchemaRef(ref); ok {
nextTrail := cloneRefTrail(refTrail)
nextTrail[ref] = struct{}{}
normalized = mergeGeminiSchemaMaps(
s.normalizeNode(target, nextTrail, depth+1),
normalized,
)
}
}
}
if rawAllOf, ok := normalized["allOf"]; ok {
delete(normalized, "allOf")
for _, part := range schemaSlice(rawAllOf) {
normalized = mergeGeminiSchemaMaps(
normalized,
s.normalizeNode(part, refTrail, depth+1),
)
}
}
if rawAnyOf, ok := normalized["anyOf"]; ok {
delete(normalized, "anyOf")
normalized = mergeGeminiSchemaMaps(
s.mergeUnionBranches(schemaSlice(rawAnyOf), refTrail, depth+1),
normalized,
)
}
if rawOneOf, ok := normalized["oneOf"]; ok {
delete(normalized, "oneOf")
normalized = mergeGeminiSchemaMaps(
s.mergeUnionBranches(schemaSlice(rawOneOf), refTrail, depth+1),
normalized,
)
}
return normalized
}
func (s geminiSchemaSanitizer) mergeUnionBranches(
branches []map[string]any,
refTrail map[string]struct{},
depth int,
) map[string]any {
if len(branches) == 0 {
return map[string]any{}
}
objectBranches := make([]map[string]any, 0, len(branches))
arrayBranches := make([]map[string]any, 0, len(branches))
nonNullBranches := make([]map[string]any, 0, len(branches))
sameType := ""
sameTypeConsistent := true
for _, branch := range branches {
normalized := s.normalizeNode(branch, refTrail, depth+1)
if len(normalized) == 0 {
continue
}
branchType := geminiSchemaBranchType(normalized["type"])
if branchType == "null" {
continue
}
nonNullBranches = append(nonNullBranches, normalized)
if sameType == "" {
sameType = branchType
} else if branchType != "" && branchType != sameType {
sameTypeConsistent = false
}
if branchType == "object" || hasSchemaProperties(normalized) {
objectBranches = append(objectBranches, normalized)
continue
}
if branchType == "array" || hasSchemaItems(normalized) {
arrayBranches = append(arrayBranches, normalized)
}
}
if len(nonNullBranches) == 0 {
return map[string]any{}
}
if len(objectBranches) > 0 {
return mergeUnionObjectSchemas(objectBranches)
}
if len(arrayBranches) == len(nonNullBranches) && len(arrayBranches) > 0 {
return mergeUnionArraySchemas(arrayBranches)
}
if sameTypeConsistent && sameType != "" {
merged := map[string]any{}
for _, branch := range nonNullBranches {
merged = mergeGeminiSchemaMaps(merged, branch)
}
return merged
}
best := nonNullBranches[0]
bestScore := geminiUnionBranchScore(best)
for _, branch := range nonNullBranches[1:] {
if score := geminiUnionBranchScore(branch); score > bestScore {
best = branch
bestScore = score
}
}
return cloneGeminiSchemaMap(best)
}
func (s geminiSchemaSanitizer) resolveLocalSchemaRef(ref string) (map[string]any, bool) {
if ref == "#" {
return s.root, true
}
if !strings.HasPrefix(ref, "#/") {
return nil, false
}
var current any = s.root
for _, rawToken := range strings.Split(strings.TrimPrefix(ref, "#/"), "/") {
token := strings.ReplaceAll(strings.ReplaceAll(rawToken, "~1", "/"), "~0", "~")
switch value := current.(type) {
case map[string]any:
next, ok := value[token]
if !ok {
return nil, false
}
current = next
case []any:
index, err := strconv.Atoi(token)
if err != nil || index < 0 || index >= len(value) {
return nil, false
}
current = value[index]
default:
return nil, false
}
}
resolved, ok := current.(map[string]any)
return resolved, ok
}
func mergeUnionObjectSchemas(branches []map[string]any) map[string]any {
merged := map[string]any{
"type": "object",
"properties": map[string]any{},
}
var commonRequired map[string]struct{}
var requiredOrder []string
for i, branch := range branches {
merged = mergeGeminiSchemaMaps(merged, branch)
required := requiredStrings(branch["required"])
if i == 0 {
commonRequired = make(map[string]struct{}, len(required))
requiredOrder = append(requiredOrder, required...)
for _, name := range required {
commonRequired[name] = struct{}{}
}
continue
}
current := make(map[string]struct{}, len(required))
for _, name := range required {
current[name] = struct{}{}
}
for name := range commonRequired {
if _, ok := current[name]; !ok {
delete(commonRequired, name)
}
}
}
if len(commonRequired) > 0 {
filtered := make([]string, 0, len(commonRequired))
for _, name := range requiredOrder {
if _, ok := commonRequired[name]; ok {
filtered = append(filtered, name)
}
}
if len(filtered) > 0 {
merged["required"] = filtered
}
} else {
delete(merged, "required")
}
return merged
}
func mergeUnionArraySchemas(branches []map[string]any) map[string]any {
merged := map[string]any{
"type": "array",
}
for _, branch := range branches {
merged = mergeGeminiSchemaMaps(merged, branch)
}
return merged
}
func mergeGeminiSchemaMaps(base map[string]any, overlay map[string]any) map[string]any {
if len(base) == 0 {
return cloneGeminiSchemaMap(overlay)
}
if len(overlay) == 0 {
return cloneGeminiSchemaMap(base)
}
result := cloneGeminiSchemaMap(base)
for key, value := range overlay {
switch key {
case "properties":
overlayProps, ok := value.(map[string]any)
if !ok {
continue
}
existing, _ := result["properties"].(map[string]any)
mergedProps := cloneGeminiSchemaMap(existing)
if mergedProps == nil {
mergedProps = make(map[string]any, len(overlayProps))
}
for name, rawProp := range overlayProps {
propSchema, ok := rawProp.(map[string]any)
if !ok {
continue
}
if existingProp, ok := mergedProps[name].(map[string]any); ok {
mergedProps[name] = mergeGeminiSchemaMaps(existingProp, propSchema)
} else {
mergedProps[name] = cloneGeminiSchemaMap(propSchema)
}
}
result["properties"] = mergedProps
case "items":
overlayItems, ok := value.(map[string]any)
if !ok {
continue
}
if existingItems, ok := result["items"].(map[string]any); ok {
result["items"] = mergeGeminiSchemaMaps(existingItems, overlayItems)
} else {
result["items"] = cloneGeminiSchemaMap(overlayItems)
}
case "required":
if merged := mergeRequiredLists(result["required"], value); len(merged) > 0 {
result["required"] = merged
}
case "type":
if mergedType := mergeGeminiSchemaTypes(result["type"], value); mergedType != "" {
result["type"] = mergedType
} else {
delete(result, "type")
}
case "description":
desc, ok := value.(string)
if ok && strings.TrimSpace(desc) != "" {
result["description"] = desc
}
default:
result[key] = cloneGeminiSchemaValue(value)
}
}
return result
}
func mergeGeminiSchemaTypes(left any, right any) string {
leftType := geminiSchemaBranchType(left)
rightType := geminiSchemaBranchType(right)
switch {
case leftType == "":
return rightType
case rightType == "":
return leftType
case leftType == rightType:
return leftType
case leftType == "null":
return rightType
case rightType == "null":
return leftType
default:
return ""
}
}
func sanitizeGeminiSchemaType(raw any) string {
typeName := geminiSchemaBranchType(raw)
if typeName == "null" {
return ""
}
return typeName
}
func geminiSchemaBranchType(raw any) string {
switch value := raw.(type) {
case string:
if value == "null" {
return value
}
if geminiSupportedTypes[value] {
return value
}
return ""
case []string:
return geminiSchemaBranchType(stringSliceToAny(value))
case []any:
candidate := ""
sawNull := false
for _, item := range value {
typeName, ok := item.(string)
if !ok {
continue
}
if typeName == "null" {
sawNull = true
continue
}
if !geminiSupportedTypes[typeName] {
continue
}
if candidate == "" {
candidate = typeName
continue
}
if candidate != typeName {
return ""
}
}
if candidate == "" && sawNull {
return "null"
}
return candidate
default:
return ""
}
}
func sanitizeGeminiEnum(raw any) []any {
values, ok := raw.([]any)
if !ok {
if stringValues, ok := raw.([]string); ok {
return stringSliceToAny(stringValues)
}
return nil
}
sanitized := make([]any, 0, len(values))
for _, value := range values {
switch value.(type) {
case string, bool, float64, int, int32, int64:
sanitized = append(sanitized, value)
}
}
if len(sanitized) == 0 {
return nil
}
return sanitized
}
func sanitizeGeminiRequired(raw any, properties map[string]any) []string {
required := requiredStrings(raw)
if len(required) == 0 {
return nil
}
filtered := make([]string, 0, len(required))
seen := make(map[string]struct{}, len(required))
for _, name := range required {
if _, ok := properties[name]; !ok {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
filtered = append(filtered, name)
}
if len(filtered) == 0 {
return nil
}
return filtered
}
func requiredStrings(raw any) []string {
switch value := raw.(type) {
case []string:
return append([]string(nil), value...)
case []any:
required := make([]string, 0, len(value))
for _, item := range value {
name, ok := item.(string)
if ok {
required = append(required, name)
}
}
return required
default:
return nil
}
}
func mergeRequiredLists(left any, right any) []string {
merged := make([]string, 0)
seen := map[string]struct{}{}
for _, name := range append(requiredStrings(left), requiredStrings(right)...) {
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
merged = append(merged, name)
}
return merged
}
func geminiUnionBranchScore(schema map[string]any) int {
score := 0
if hasSchemaProperties(schema) {
score += 20
}
if hasSchemaItems(schema) {
score += 10
}
if _, ok := schema["enum"]; ok {
score += 5
}
if _, ok := schema["description"]; ok {
score += 2
}
score += len(schema)
return score
}
func hasSchemaProperties(schema map[string]any) bool {
props, ok := schema["properties"].(map[string]any)
return ok && len(props) > 0
}
func hasSchemaItems(schema map[string]any) bool {
_, ok := schema["items"].(map[string]any)
return ok
}
func schemaSlice(raw any) []map[string]any {
switch value := raw.(type) {
case []map[string]any:
return append([]map[string]any(nil), value...)
case []any:
schemas := make([]map[string]any, 0, len(value))
for _, item := range value {
schema, ok := item.(map[string]any)
if ok {
schemas = append(schemas, schema)
}
}
return schemas
default:
return nil
}
}
func cloneGeminiSchemaMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for key, value := range in {
out[key] = cloneGeminiSchemaValue(value)
}
return out
}
func cloneGeminiSchemaValue(value any) any {
switch typed := value.(type) {
case map[string]any:
return cloneGeminiSchemaMap(typed)
case []any:
out := make([]any, len(typed))
for i, item := range typed {
out[i] = cloneGeminiSchemaValue(item)
}
return out
case []string:
return append([]string(nil), typed...)
default:
return typed
}
}
func cloneRefTrail(in map[string]struct{}) map[string]struct{} {
if len(in) == 0 {
return make(map[string]struct{})
}
out := make(map[string]struct{}, len(in))
for key := range in {
out[key] = struct{}{}
}
return out
}
func stringSliceToAny(values []string) []any {
if len(values) == 0 {
return nil
}
result := make([]any, len(values))
for i, value := range values {
result[i] = value
}
return result
}
+254
View File
@@ -0,0 +1,254 @@
package common
import "testing"
func TestSanitizeSchemaForGemini_DereferencesRefsAndFlattensUnions(t *testing.T) {
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"parent": map[string]any{
"anyOf": []any{
map[string]any{"$ref": "#/$defs/pageParent"},
map[string]any{"$ref": "#/$defs/databaseParent"},
},
},
"icon": map[string]any{
"anyOf": []any{
map[string]any{"$ref": "#/$defs/emoji"},
map[string]any{"type": "null"},
},
},
"data": map[string]any{
"$ref": "#/$defs/dataPayload",
},
},
"required": []any{"parent", "icon", "missing"},
"$defs": map[string]any{
"pageParent": map[string]any{
"type": "object",
"properties": map[string]any{
"page_id": map[string]any{
"type": "string",
},
},
"required": []any{"page_id"},
},
"databaseParent": map[string]any{
"type": "object",
"properties": map[string]any{
"database_id": map[string]any{
"type": "string",
},
},
"required": []any{"database_id"},
},
"emoji": map[string]any{
"type": "string",
"pattern": "^:[a-z_]+:$",
},
"dataPayload": map[string]any{
"type": "object",
"additionalProperties": false,
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"minLength": 1,
},
"count": map[string]any{
"type": "integer",
"minimum": 1,
},
},
"required": []any{"name"},
},
},
}
got := SanitizeSchemaForGemini(schema)
assertSchemaKeyAbsent(t, got, "$defs")
assertSchemaKeyAbsent(t, got, "$ref")
assertSchemaKeyAbsent(t, got, "anyOf")
assertSchemaKeyAbsent(t, got, "oneOf")
assertSchemaKeyAbsent(t, got, "allOf")
assertSchemaKeyAbsent(t, got, "additionalProperties")
assertSchemaKeyAbsent(t, got, "pattern")
assertSchemaKeyAbsent(t, got, "minLength")
assertSchemaKeyAbsent(t, got, "minimum")
if got["type"] != "object" {
t.Fatalf("top-level type = %#v, want object", got["type"])
}
props, ok := got["properties"].(map[string]any)
if !ok {
t.Fatalf("properties = %#v, want map", got["properties"])
}
parent, ok := props["parent"].(map[string]any)
if !ok {
t.Fatalf("parent schema = %#v, want map", props["parent"])
}
if parent["type"] != "object" {
t.Fatalf("parent.type = %#v, want object", parent["type"])
}
parentProps, ok := parent["properties"].(map[string]any)
if !ok {
t.Fatalf("parent.properties = %#v, want map", parent["properties"])
}
if _, found := parentProps["page_id"]; !found {
t.Fatalf("parent.properties missing page_id: %#v", parentProps)
}
if _, found := parentProps["database_id"]; !found {
t.Fatalf("parent.properties missing database_id: %#v", parentProps)
}
if _, hasRequired := parent["required"]; hasRequired {
t.Fatalf("parent.required = %#v, want omitted for merged anyOf branches", parent["required"])
}
icon, ok := props["icon"].(map[string]any)
if !ok {
t.Fatalf("icon schema = %#v, want map", props["icon"])
}
if icon["type"] != "string" {
t.Fatalf("icon.type = %#v, want string", icon["type"])
}
data, ok := props["data"].(map[string]any)
if !ok {
t.Fatalf("data schema = %#v, want map", props["data"])
}
if data["type"] != "object" {
t.Fatalf("data.type = %#v, want object", data["type"])
}
dataProps, ok := data["properties"].(map[string]any)
if !ok {
t.Fatalf("data.properties = %#v, want map", data["properties"])
}
if _, found := dataProps["name"]; !found {
t.Fatalf("data.properties missing name: %#v", dataProps)
}
if _, found := dataProps["count"]; !found {
t.Fatalf("data.properties missing count: %#v", dataProps)
}
required, ok := got["required"].([]string)
if !ok {
t.Fatalf("required = %#v, want []string", got["required"])
}
if len(required) != 2 || required[0] != "parent" || required[1] != "icon" {
t.Fatalf("required = %#v, want [parent icon]", required)
}
}
func TestSanitizeSchemaForGemini_MergesAllOfAndFiltersRequired(t *testing.T) {
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"payload": map[string]any{
"allOf": []any{
map[string]any{
"type": "object",
"properties": map[string]any{
"id": map[string]any{
"type": "string",
},
},
"required": []any{"id"},
},
map[string]any{
"properties": map[string]any{
"name": map[string]any{
"type": "string",
},
"count": map[string]any{
"type": "integer",
"minimum": 1,
},
},
"required": []any{"name", "missing"},
},
},
},
},
}
got := SanitizeSchemaForGemini(schema)
props := got["properties"].(map[string]any)
payload := props["payload"].(map[string]any)
if payload["type"] != "object" {
t.Fatalf("payload.type = %#v, want object", payload["type"])
}
payloadProps, ok := payload["properties"].(map[string]any)
if !ok {
t.Fatalf("payload.properties = %#v, want map", payload["properties"])
}
for _, key := range []string{"id", "name", "count"} {
if _, found := payloadProps[key]; !found {
t.Fatalf("payload.properties missing %q: %#v", key, payloadProps)
}
}
required, ok := payload["required"].([]string)
if !ok {
t.Fatalf("payload.required = %#v, want []string", payload["required"])
}
if len(required) != 2 || required[0] != "id" || required[1] != "name" {
t.Fatalf("payload.required = %#v, want [id name]", required)
}
assertSchemaKeyAbsent(t, payload, "allOf")
assertSchemaKeyAbsent(t, payload, "minimum")
}
func TestSanitizeSchemaForGemini_HandlesRecursiveRefs(t *testing.T) {
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"tree": map[string]any{
"$ref": "#/$defs/node",
},
},
"$defs": map[string]any{
"node": map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
},
"child": map[string]any{
"$ref": "#/$defs/node",
},
},
},
},
}
got := SanitizeSchemaForGemini(schema)
props := got["properties"].(map[string]any)
tree := props["tree"].(map[string]any)
if tree["type"] != "object" {
t.Fatalf("tree.type = %#v, want object", tree["type"])
}
assertSchemaKeyAbsent(t, tree, "$ref")
}
func assertSchemaKeyAbsent(t *testing.T, value any, key string) {
t.Helper()
switch typed := value.(type) {
case map[string]any:
if _, found := typed[key]; found {
t.Fatalf("schema still contains key %q: %#v", key, typed)
}
for _, nested := range typed {
assertSchemaKeyAbsent(t, nested, key)
}
case []any:
for _, nested := range typed {
assertSchemaKeyAbsent(t, nested, key)
}
case []string:
return
}
}
@@ -0,0 +1,59 @@
package common
import (
"fmt"
"strings"
)
const (
ToolSchemaTransformOff = ""
ToolSchemaTransformSimple = "simple"
)
// NormalizeToolSchemaTransform resolves user-facing aliases to a canonical
// transform mode. Empty values and explicit "off"-style values disable schema
// transformation.
func NormalizeToolSchemaTransform(raw string) (string, error) {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "", "off", "none", "native":
return ToolSchemaTransformOff, nil
case "simple", "basic", "strict", "flat":
return ToolSchemaTransformSimple, nil
default:
return "", fmt.Errorf("unsupported tool_schema_transform %q (supported: off, simple)", raw)
}
}
// TransformToolDefinitions clones tool definitions and applies the configured
// schema transform to function parameter schemas. When the transform is off, the
// original slice is returned unchanged.
func TransformToolDefinitions(tools []ToolDefinition, transform string) ([]ToolDefinition, error) {
transform, err := NormalizeToolSchemaTransform(transform)
if err != nil {
return nil, err
}
if transform == ToolSchemaTransformOff || len(tools) == 0 {
return tools, nil
}
out := make([]ToolDefinition, len(tools))
for i, tool := range tools {
out[i] = tool
if tool.Type != "function" {
continue
}
out[i].Function = tool.Function
out[i].Function.Parameters = transformToolSchema(tool.Function.Parameters, transform)
}
return out, nil
}
func transformToolSchema(schema map[string]any, transform string) map[string]any {
switch transform {
case ToolSchemaTransformSimple:
return SanitizeSchemaForGoogle(schema)
default:
return cloneGeminiSchemaMap(schema)
}
}
+47 -36
View File
@@ -110,19 +110,7 @@ func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) {
if provider := strings.TrimSpace(cfg.Provider); provider != "" {
return NormalizeProvider(provider), model
}
if model == "" {
return "", ""
}
protocol, rest, found := strings.Cut(model, "/")
if !found {
return "openai", model
}
protocol = strings.TrimSpace(protocol)
if protocol == "" {
return "", strings.TrimSpace(rest)
}
return NormalizeProvider(protocol), strings.TrimSpace(rest)
return SplitModelProviderAndID(model, "openai")
}
// ResolveAPIBase returns the configured API base, or the protocol default when
@@ -154,6 +142,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
}
protocol, modelID := ExtractProtocol(cfg)
authMethod := strings.ToLower(strings.TrimSpace(cfg.AuthMethod))
userAgent := cfg.UserAgent
if userAgent == "" {
@@ -163,12 +152,12 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
switch protocol {
case "openai":
// OpenAI with OAuth/token auth (Codex-style)
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
if authMethod == "oauth" || authMethod == "token" {
provider, err := createCodexAuthProvider()
if err != nil {
return nil, "", err
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
}
// OpenAI with API key
if cfg.APIKey() == "" && cfg.APIBase == "" {
@@ -189,7 +178,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "azure", "azure-openai":
// Azure OpenAI uses deployment-based URLs, api-key header auth,
@@ -202,13 +191,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
)
}
return azure.NewProviderWithTimeout(
return finalizeProviderFromConfig(azure.NewProviderWithTimeout(
cfg.APIKey(),
cfg.APIBase,
cfg.Proxy,
userAgent,
cfg.RequestTimeout,
), modelID, nil
), modelID, cfg)
case "bedrock":
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
@@ -244,7 +233,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if err != nil {
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
@@ -270,7 +259,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "gemini":
if cfg.APIKey() == "" && cfg.APIBase == "" {
@@ -280,7 +269,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
return NewGeminiProvider(
return finalizeProviderFromConfig(NewGeminiProvider(
cfg.APIKey(),
apiBase,
cfg.Proxy,
@@ -288,7 +277,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
cfg.ExtraBody,
cfg.CustomHeaders,
), modelID, nil
), modelID, cfg)
case "minimax":
// Minimax requires reasoning_split: true in the request body
@@ -317,16 +306,16 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "anthropic":
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
if authMethod == "oauth" || authMethod == "token" {
// Use OAuth credentials from auth store
provider, err := createClaudeAuthProvider()
if err != nil {
return nil, "", err
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
}
// Use API key with HTTP API
apiBase := cfg.APIBase
@@ -347,7 +336,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.CustomHeaders,
)
provider.SetProviderName(protocol)
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
case "anthropic-messages":
// Anthropic Messages API with native format (HTTP-based, no SDK)
@@ -358,12 +347,12 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model)
}
return anthropicmessages.NewProviderWithTimeout(
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
userAgent,
cfg.RequestTimeout,
), modelID, nil
), modelID, cfg)
case "coding-plan-anthropic", "alibaba-coding-anthropic":
// Alibaba Coding Plan with Anthropic-compatible API
@@ -374,29 +363,29 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if cfg.APIKey() == "" {
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
}
return anthropicmessages.NewProviderWithTimeout(
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
userAgent,
cfg.RequestTimeout,
), modelID, nil
), modelID, cfg)
case "antigravity":
return NewAntigravityProvider(), modelID, nil
return finalizeProviderFromConfig(NewAntigravityProvider(), modelID, cfg)
case "claude-cli", "claudecli":
workspace := cfg.Workspace
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), modelID, nil
return finalizeProviderFromConfig(NewClaudeCliProvider(workspace), modelID, cfg)
case "codex-cli", "codexcli":
workspace := cfg.Workspace
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), modelID, nil
return finalizeProviderFromConfig(NewCodexCliProvider(workspace), modelID, cfg)
case "github-copilot", "copilot":
apiBase := cfg.APIBase
@@ -411,15 +400,27 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if err != nil {
return nil, "", err
}
return provider, modelID, nil
return finalizeProviderFromConfig(provider, modelID, cfg)
default:
return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
}
}
func finalizeProviderFromConfig(
provider LLMProvider,
modelID string,
cfg *config.ModelConfig,
) (LLMProvider, string, error) {
wrapped, err := wrapProviderWithToolSchemaTransform(provider, cfg.ToolSchemaTransform)
if err != nil {
return nil, "", err
}
return wrapped, modelID, nil
}
func isEmptyAPIKeyAllowed(protocol string) bool {
meta, ok := protocolMetaByName[protocol]
meta, ok := protocolMetaForName(protocol)
return ok && meta.emptyAPIKeyAllowed
}
@@ -439,9 +440,19 @@ func DefaultAPIBaseForProtocol(protocol string) string {
// getDefaultAPIBase returns the default API base URL for a given protocol.
func getDefaultAPIBase(protocol string) string {
meta, ok := protocolMetaByName[protocol]
meta, ok := protocolMetaForName(protocol)
if !ok {
return ""
}
return meta.defaultAPIBase
}
func protocolMetaForName(protocol string) (protocolMeta, bool) {
if meta, ok := protocolMetaByName[protocol]; ok {
return meta, true
}
if meta, ok := attachedModelProviderMetaByName[protocol]; ok {
return meta.protocolMeta, true
}
return protocolMeta{}, false
}
+169 -2
View File
@@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -101,6 +102,12 @@ func TestExtractProtocol(t *testing.T) {
wantProtocol: "",
wantModelID: "gpt-4o",
},
{
name: "unknown prefix falls back to openai",
config: &config.ModelConfig{Model: "meta-llama/Llama-3.1-8B-Instruct"},
wantProtocol: "openai",
wantModelID: "meta-llama/Llama-3.1-8B-Instruct",
},
{
name: "nil config",
wantProtocol: "",
@@ -605,6 +612,41 @@ func TestCreateProviderFromConfig_CodexCLI(t *testing.T) {
}
}
func TestCreateProviderFromConfig_OpenAIMixedCaseAuthMethodUsesOAuthBranch(t *testing.T) {
origGetCredential := getCredential
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "openai" {
t.Fatalf("provider = %q, want %q", provider, "openai")
}
return &auth.AuthCredential{
AccessToken: "test-token",
AccountID: "acct-test",
Provider: "openai",
AuthMethod: "oauth",
}, nil
}
t.Cleanup(func() {
getCredential = origGetCredential
})
cfg := &config.ModelConfig{
ModelName: "test-openai-oauth",
Model: "openai/gpt-5.4",
AuthMethod: "OAuth",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "gpt-5.4" {
t.Errorf("modelID = %q, want %q", modelID, "gpt-5.4")
}
}
func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-no-key",
@@ -619,8 +661,9 @@ func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) {
func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-unknown",
Model: "unknown-protocol/model",
ModelName: "test-unknown-provider",
Provider: "unknown-protocol",
Model: "model",
}
cfg.SetAPIKey("test-key")
@@ -630,6 +673,26 @@ func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
}
}
func TestCreateProviderFromConfig_UnknownModelPrefixDefaultsToOpenAI(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-unknown-model-prefix",
Model: "meta-llama/Llama-3.1-8B-Instruct",
APIBase: "https://api.example.com/v1",
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "meta-llama/Llama-3.1-8B-Instruct" {
t.Fatalf("modelID = %q, want full model ID", modelID)
}
}
func TestCreateProviderFromConfig_NilConfig(t *testing.T) {
_, _, err := CreateProviderFromConfig(nil)
if err == nil {
@@ -889,6 +952,71 @@ func TestGetDefaultAPIBase_QwenUSAliases(t *testing.T) {
}
}
func TestModelProviderOptions(t *testing.T) {
options := ModelProviderOptions()
if len(options) == 0 {
t.Fatal("ModelProviderOptions() returned no options")
}
seen := make(map[string]ModelProviderOption, len(options))
for _, option := range options {
seen[option.ID] = option
}
if _, ok := seen["openai"]; !ok {
t.Fatal("openai option missing")
}
if option, ok := seen["openai"]; ok && !option.CreateAllowed {
t.Fatal("openai should be creatable")
}
if option, ok := seen["lmstudio"]; !ok {
t.Fatal("lmstudio option missing")
} else if !option.EmptyAPIKeyAllowed {
t.Fatal("lmstudio should allow empty API keys")
}
if option, ok := seen["anthropic"]; !ok {
t.Fatal("anthropic option missing")
} else if option.DefaultAPIBase != "https://api.anthropic.com/v1" {
t.Fatalf("anthropic default_api_base = %q, want %q", option.DefaultAPIBase, "https://api.anthropic.com/v1")
}
if _, ok := seen["azure"]; !ok {
t.Fatal("azure option missing")
}
if option, ok := seen["bedrock"]; !ok {
t.Fatal("bedrock option missing")
} else if !option.CreateAllowed {
t.Fatal("bedrock should be creatable and defer credential/build errors to runtime")
}
if option, ok := seen["elevenlabs"]; !ok {
t.Fatal("elevenlabs option missing")
} else {
if option.DefaultAPIBase != "https://api.elevenlabs.io" {
t.Fatalf("elevenlabs default_api_base = %q, want %q", option.DefaultAPIBase, "https://api.elevenlabs.io")
}
if option.DefaultModelAllowed {
t.Fatal("elevenlabs should be ASR-only and therefore not allowed as a default chat model")
}
}
if option, ok := seen["antigravity"]; !ok {
t.Fatal("antigravity option missing")
} else {
if !option.CreateAllowed {
t.Fatal("antigravity should be creatable")
}
if option.DefaultAuthMethod != "oauth" {
t.Fatalf("antigravity default_auth_method = %q, want %q", option.DefaultAuthMethod, "oauth")
}
if !option.AuthMethodLocked {
t.Fatal("antigravity auth method should be locked")
}
}
if option, ok := seen["github-copilot"]; !ok {
t.Fatal("github-copilot option missing")
} else if option.DefaultAPIBase != "localhost:4321" {
t.Fatalf("github-copilot default_api_base = %q, want %q", option.DefaultAPIBase, "localhost:4321")
}
}
func TestCreateProviderFromConfig_MinimaxInjectsReasoningSplit(t *testing.T) {
var requestBody map[string]any
@@ -1202,3 +1330,42 @@ func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
// Unexpected error - fail the test
t.Errorf("unexpected error from bedrock provider: %v", err)
}
func TestCreateProviderFromConfig_ToolSchemaTransformWrapsProvider(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "claude-cli-test",
Provider: "claude-cli",
Model: "claude-sonnet-4.6",
Workspace: t.TempDir(),
ToolSchemaTransform: "simple",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if modelID != "claude-sonnet-4.6" {
t.Fatalf("modelID = %q, want %q", modelID, "claude-sonnet-4.6")
}
if _, ok := provider.(*toolSchemaTransformProvider); !ok {
t.Fatalf("provider = %T, want *toolSchemaTransformProvider", provider)
}
}
func TestCreateProviderFromConfig_InvalidToolSchemaTransform(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "claude-cli-test",
Provider: "claude-cli",
Model: "claude-sonnet-4.6",
Workspace: t.TempDir(),
ToolSchemaTransform: "invalid",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for invalid tool_schema_transform")
}
if !strings.Contains(err.Error(), "tool_schema_transform") {
t.Fatalf("error = %v, want mention tool_schema_transform", err)
}
}
-60
View File
@@ -12,66 +12,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake
return ""
}
var geminiUnsupportedKeywords = map[string]bool{
"patternProperties": true,
"additionalProperties": true,
"$schema": true,
"$id": true,
"$ref": true,
"$defs": true,
"definitions": true,
"examples": true,
"minLength": true,
"maxLength": true,
"minimum": true,
"maximum": true,
"multipleOf": true,
"pattern": true,
"format": true,
"minItems": true,
"maxItems": true,
"uniqueItems": true,
"minProperties": true,
"maxProperties": true,
}
func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
result := make(map[string]any)
for k, v := range schema {
if geminiUnsupportedKeywords[k] {
continue
}
switch val := v.(type) {
case map[string]any:
result[k] = sanitizeSchemaForGemini(val)
case []any:
sanitized := make([]any, len(val))
for i, item := range val {
if m, ok := item.(map[string]any); ok {
sanitized[i] = sanitizeSchemaForGemini(m)
} else {
sanitized[i] = item
}
}
result[k] = sanitized
default:
result[k] = v
}
}
if _, hasProps := result["properties"]; hasProps {
if _, hasType := result["type"]; !hasType {
result["type"] = "object"
}
}
return result
}
func extractProtocol(model string) (protocol, modelID string) {
model = strings.TrimSpace(model)
protocol, modelID, found := strings.Cut(model, "/")
+1 -1
View File
@@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody(
funcDecls = append(funcDecls, geminiFunctionDeclaration{
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: sanitizeSchemaForGemini(t.Function.Parameters),
Parameters: t.Function.Parameters,
})
}
if len(funcDecls) > 0 {
@@ -259,6 +259,64 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
}
}
func TestGeminiProvider_BuildRequestBody_PreservesComplexToolSchemasByDefault(t *testing.T) {
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"parent": map[string]any{
"anyOf": []any{
map[string]any{"$ref": "#/$defs/pageParent"},
map[string]any{"$ref": "#/$defs/databaseParent"},
},
},
},
"$defs": map[string]any{
"pageParent": map[string]any{
"type": "object",
"properties": map[string]any{
"page_id": map[string]any{"type": "string"},
},
"required": []any{"page_id"},
},
"databaseParent": map[string]any{
"type": "object",
"properties": map[string]any{
"database_id": map[string]any{"type": "string"},
},
"required": []any{"database_id"},
},
},
}
body := provider.buildRequestBody(
[]Message{{Role: "user", Content: "hello"}},
[]ToolDefinition{{
Type: "function",
Function: ToolFunctionDefinition{
Name: "mcp_notion_create",
Description: "Create a Notion object",
Parameters: schema,
},
}},
"gemini-3-flash-preview",
nil,
)
tools, ok := body["tools"].([]geminiTool)
if !ok || len(tools) != 1 {
t.Fatalf("tools = %#v, want one geminiTool", body["tools"])
}
got, ok := tools[0].FunctionDeclarations[0].Parameters.(map[string]any)
if !ok {
t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters)
}
if got["$defs"] == nil {
t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got)
}
}
func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
+15 -10
View File
@@ -17,18 +17,13 @@ func ParseModelRef(raw string, defaultProvider string) *ModelRef {
return nil
}
if idx := strings.Index(raw, "/"); idx > 0 {
provider := NormalizeProvider(raw[:idx])
model := strings.TrimSpace(raw[idx+1:])
if model == "" {
return nil
}
return &ModelRef{Provider: provider, Model: model}
provider, model := SplitModelProviderAndID(raw, defaultProvider)
if model == "" {
return nil
}
return &ModelRef{
Provider: NormalizeProvider(defaultProvider),
Model: raw,
Provider: provider,
Model: model,
}
}
@@ -53,6 +48,8 @@ func NormalizeProvider(provider string) string {
return "zhipu"
case "google":
return "gemini"
case "google-antigravity":
return "antigravity"
case "alibaba-coding", "qwen-coding":
return "coding-plan"
case "alibaba-coding-anthropic":
@@ -61,6 +58,14 @@ func NormalizeProvider(provider string) string {
return "qwen-intl"
case "dashscope-us":
return "qwen-us"
case "azure-openai":
return "azure"
case "claudecli":
return "claude-cli"
case "codexcli":
return "codex-cli"
case "copilot":
return "github-copilot"
}
return p

Some files were not shown because too many files have changed in this diff Show More