mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(line): classify SDK errors with HTTP status and add client timeout
Address review feedback: - Use *WithHttpInfo SDK variants to get HTTP response status codes - Map status codes via ClassifySendError (429→ErrRateLimit, 5xx→ErrTemporary, 4xx→ErrSendFailed) - Fall back to ClassifyNetError for network-level failures - Configure SDK with 30s timeout HTTP client Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -43,3 +43,9 @@ func (a *channelManagerAdapter) SendMedia(ctx context.Context, msg bus.OutboundM
|
||||
func (a *channelManagerAdapter) SendPlaceholder(ctx context.Context, channel, chatID string) bool {
|
||||
return a.inner.SendPlaceholder(ctx, channel, chatID)
|
||||
}
|
||||
|
||||
func (a *channelManagerAdapter) DismissToolFeedback(
|
||||
ctx context.Context, channel, chatID string, outboundCtx *bus.InboundContext,
|
||||
) {
|
||||
a.inner.DismissToolFeedback(ctx, channel, chatID, outboundCtx)
|
||||
}
|
||||
|
||||
+40
-11
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -37,9 +38,13 @@ type AgentLoop struct {
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
|
||||
// Event system (from Incoming)
|
||||
eventBus *EventBus
|
||||
hooks *HookManager
|
||||
// Runtime event system
|
||||
runtimeEvents runtimeevents.Bus
|
||||
ownsRuntimeEvents bool
|
||||
runtimeEventLogMu sync.RWMutex
|
||||
runtimeEventLogger *runtimeEventLogger
|
||||
runtimeEventLogSub runtimeevents.Subscription
|
||||
hooks *HookManager
|
||||
|
||||
// Runtime state
|
||||
running atomic.Bool
|
||||
@@ -53,6 +58,7 @@ type AgentLoop struct {
|
||||
hookRuntime hookRuntime
|
||||
steering *steeringQueue
|
||||
pendingSkills sync.Map
|
||||
pendingStops sync.Map
|
||||
mu sync.RWMutex
|
||||
|
||||
// workerSem limits concurrent turn processing workers.
|
||||
@@ -172,6 +178,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
phase: TurnPhaseSetup,
|
||||
}
|
||||
if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded {
|
||||
if al.tryHandleStopCommand(ctx, msg, sessionKey) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Another turn is already active (or reserved) for this session — enqueue
|
||||
if err := al.enqueueSteeringMessage(sessionKey, agentID, providers.Message{
|
||||
Role: "user",
|
||||
@@ -235,6 +245,24 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
defer al.channelManager.InvokeTypingStop(m.Channel, m.ChatID)
|
||||
}
|
||||
|
||||
if al.takePendingStop(sessionKey) {
|
||||
al.activeTurnStates.Delete(sessionKey)
|
||||
target := &continuationTarget{
|
||||
SessionKey: sessionKey,
|
||||
Channel: m.Channel,
|
||||
ChatID: m.ChatID,
|
||||
}
|
||||
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
|
||||
if continueErr != nil {
|
||||
al.maybePublishError(ctx, m.Channel, m.ChatID, sessionKey, continueErr)
|
||||
return
|
||||
}
|
||||
if continued != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, continued)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
al.runTurnWithSteering(ctx, m)
|
||||
}(msg)
|
||||
|
||||
@@ -285,8 +313,14 @@ func (al *AgentLoop) Close() {
|
||||
if al.hooks != nil {
|
||||
al.hooks.Close()
|
||||
}
|
||||
if al.eventBus != nil {
|
||||
al.eventBus.Close()
|
||||
al.closeRuntimeEventLogger()
|
||||
if al.runtimeEvents != nil && al.ownsRuntimeEvents {
|
||||
if err := al.runtimeEvents.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close runtime event bus",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -294,12 +328,6 @@ func (al *AgentLoop) Close() {
|
||||
|
||||
// UnmountHook removes a previously registered in-process hook.
|
||||
|
||||
// SubscribeEvents registers a subscriber for agent-loop events.
|
||||
|
||||
// UnsubscribeEvents removes a previously registered event subscriber.
|
||||
|
||||
// EventDrops returns the number of dropped events for the given kind.
|
||||
|
||||
type turnEventScope struct {
|
||||
agentID string
|
||||
sessionKey string
|
||||
@@ -384,6 +412,7 @@ func (al *AgentLoop) ReloadProviderAndConfig(
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL)
|
||||
|
||||
al.mu.Unlock()
|
||||
al.refreshRuntimeEventLogger(cfg)
|
||||
|
||||
oldMCPManager := al.mcp.reset()
|
||||
al.hookRuntime.reset(al)
|
||||
|
||||
@@ -274,6 +274,12 @@ func (al *AgentLoop) buildCommandsRuntime(
|
||||
return nil
|
||||
},
|
||||
}
|
||||
rt.StopActiveTurn = func() (commands.StopResult, error) {
|
||||
if opts == nil {
|
||||
return commands.StopResult{}, fmt.Errorf("process options not available")
|
||||
}
|
||||
return al.stopActiveTurnForSession(opts.Dispatch.SessionKey)
|
||||
}
|
||||
if agent != nil && agent.ContextBuilder != nil {
|
||||
rt.ListSkillNames = agent.ContextBuilder.ListSkillNames
|
||||
}
|
||||
|
||||
+31
-128
@@ -5,7 +5,7 @@ package agent
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func (al *AgentLoop) newTurnEventScope(agentID, sessionKey string, turnCtx *TurnContext) turnEventScope {
|
||||
@@ -18,8 +18,8 @@ func (al *AgentLoop) newTurnEventScope(agentID, sessionKey string, turnCtx *Turn
|
||||
}
|
||||
}
|
||||
|
||||
func (ts turnEventScope) meta(iteration int, source, tracePath string) EventMeta {
|
||||
return EventMeta{
|
||||
func (ts turnEventScope) meta(iteration int, source, tracePath string) HookMeta {
|
||||
return HookMeta{
|
||||
AgentID: ts.agentID,
|
||||
TurnID: ts.turnID,
|
||||
SessionKey: ts.sessionKey,
|
||||
@@ -30,119 +30,24 @@ func (ts turnEventScope) meta(iteration int, source, tracePath string) EventMeta
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) emitEvent(kind EventKind, meta EventMeta, payload any) {
|
||||
clonedMeta := cloneEventMeta(meta)
|
||||
evt := Event{
|
||||
Kind: kind,
|
||||
Meta: clonedMeta,
|
||||
Context: cloneTurnContext(clonedMeta.turnContext),
|
||||
Payload: payload,
|
||||
func (al *AgentLoop) emitEvent(kind runtimeevents.Kind, meta HookMeta, payload any) {
|
||||
clonedMeta := cloneHookMeta(meta)
|
||||
eventCtx := cloneTurnContext(clonedMeta.turnContext)
|
||||
evt := runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "agent", Name: clonedMeta.AgentID},
|
||||
Scope: runtimeScopeFromHookMeta(clonedMeta, eventCtx),
|
||||
Correlation: runtimeCorrelationFromHookMeta(clonedMeta),
|
||||
Severity: runtimeSeverityForAgentEvent(kind, payload),
|
||||
Payload: payload,
|
||||
Attrs: runtimeAttrsFromHookMeta(clonedMeta),
|
||||
}
|
||||
|
||||
if al == nil || al.eventBus == nil {
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
|
||||
al.logEvent(evt)
|
||||
|
||||
al.eventBus.Emit(evt)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) logEvent(evt Event) {
|
||||
fields := map[string]any{
|
||||
"event_kind": evt.Kind.String(),
|
||||
"agent_id": evt.Meta.AgentID,
|
||||
"turn_id": evt.Meta.TurnID,
|
||||
"session_key": evt.Meta.SessionKey,
|
||||
"iteration": evt.Meta.Iteration,
|
||||
}
|
||||
|
||||
if evt.Meta.TracePath != "" {
|
||||
fields["trace"] = evt.Meta.TracePath
|
||||
}
|
||||
if evt.Meta.Source != "" {
|
||||
fields["source"] = evt.Meta.Source
|
||||
}
|
||||
|
||||
appendEventContextFields(fields, evt.Context)
|
||||
|
||||
switch payload := evt.Payload.(type) {
|
||||
case TurnStartPayload:
|
||||
fields["user_len"] = len(payload.UserMessage)
|
||||
fields["media_count"] = payload.MediaCount
|
||||
case TurnEndPayload:
|
||||
fields["status"] = payload.Status
|
||||
fields["iterations_total"] = payload.Iterations
|
||||
fields["duration_ms"] = payload.Duration.Milliseconds()
|
||||
fields["final_len"] = payload.FinalContentLen
|
||||
case LLMRequestPayload:
|
||||
fields["model"] = payload.Model
|
||||
fields["messages"] = payload.MessagesCount
|
||||
fields["tools"] = payload.ToolsCount
|
||||
fields["max_tokens"] = payload.MaxTokens
|
||||
case LLMDeltaPayload:
|
||||
fields["content_delta_len"] = payload.ContentDeltaLen
|
||||
fields["reasoning_delta_len"] = payload.ReasoningDeltaLen
|
||||
case LLMResponsePayload:
|
||||
fields["content_len"] = payload.ContentLen
|
||||
fields["tool_calls"] = payload.ToolCalls
|
||||
fields["has_reasoning"] = payload.HasReasoning
|
||||
case LLMRetryPayload:
|
||||
fields["attempt"] = payload.Attempt
|
||||
fields["max_retries"] = payload.MaxRetries
|
||||
fields["reason"] = payload.Reason
|
||||
fields["error"] = payload.Error
|
||||
fields["backoff_ms"] = payload.Backoff.Milliseconds()
|
||||
case ContextCompressPayload:
|
||||
fields["reason"] = payload.Reason
|
||||
fields["dropped_messages"] = payload.DroppedMessages
|
||||
fields["remaining_messages"] = payload.RemainingMessages
|
||||
case SessionSummarizePayload:
|
||||
fields["summarized_messages"] = payload.SummarizedMessages
|
||||
fields["kept_messages"] = payload.KeptMessages
|
||||
fields["summary_len"] = payload.SummaryLen
|
||||
fields["omitted_oversized"] = payload.OmittedOversized
|
||||
case ToolExecStartPayload:
|
||||
fields["tool"] = payload.Tool
|
||||
fields["args_count"] = len(payload.Arguments)
|
||||
case ToolExecEndPayload:
|
||||
fields["tool"] = payload.Tool
|
||||
fields["duration_ms"] = payload.Duration.Milliseconds()
|
||||
fields["for_llm_len"] = payload.ForLLMLen
|
||||
fields["for_user_len"] = payload.ForUserLen
|
||||
fields["is_error"] = payload.IsError
|
||||
fields["async"] = payload.Async
|
||||
case ToolExecSkippedPayload:
|
||||
fields["tool"] = payload.Tool
|
||||
fields["reason"] = payload.Reason
|
||||
case SteeringInjectedPayload:
|
||||
fields["count"] = payload.Count
|
||||
fields["total_content_len"] = payload.TotalContentLen
|
||||
case FollowUpQueuedPayload:
|
||||
fields["source_tool"] = payload.SourceTool
|
||||
fields["content_len"] = payload.ContentLen
|
||||
case InterruptReceivedPayload:
|
||||
fields["interrupt_kind"] = payload.Kind
|
||||
fields["role"] = payload.Role
|
||||
fields["content_len"] = payload.ContentLen
|
||||
fields["queue_depth"] = payload.QueueDepth
|
||||
fields["hint_len"] = payload.HintLen
|
||||
case SubTurnSpawnPayload:
|
||||
fields["child_agent_id"] = payload.AgentID
|
||||
fields["label"] = payload.Label
|
||||
case SubTurnEndPayload:
|
||||
fields["child_agent_id"] = payload.AgentID
|
||||
fields["status"] = payload.Status
|
||||
case SubTurnResultDeliveredPayload:
|
||||
fields["target_channel"] = payload.TargetChannel
|
||||
fields["target_chat_id"] = payload.TargetChatID
|
||||
fields["content_len"] = payload.ContentLen
|
||||
case ErrorPayload:
|
||||
fields["stage"] = payload.Stage
|
||||
fields["error"] = payload.Message
|
||||
}
|
||||
|
||||
logger.InfoCF("eventbus", fmt.Sprintf("Agent event: %s", evt.Kind.String()), fields)
|
||||
al.publishRuntimeEvent(evt)
|
||||
}
|
||||
|
||||
// MountHook registers an in-process hook on the agent loop.
|
||||
@@ -161,28 +66,26 @@ func (al *AgentLoop) UnmountHook(name string) {
|
||||
al.hooks.Unmount(name)
|
||||
}
|
||||
|
||||
// SubscribeEvents registers a subscriber for agent-loop events.
|
||||
func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
|
||||
if al == nil || al.eventBus == nil {
|
||||
ch := make(chan Event)
|
||||
close(ch)
|
||||
return EventSubscription{C: ch}
|
||||
// RuntimeEvents returns the root runtime event channel.
|
||||
func (al *AgentLoop) RuntimeEvents() runtimeevents.EventChannel {
|
||||
if al == nil || al.runtimeEvents == nil {
|
||||
return nil
|
||||
}
|
||||
return al.eventBus.Subscribe(buffer)
|
||||
return al.runtimeEvents.Channel()
|
||||
}
|
||||
|
||||
// UnsubscribeEvents removes a previously registered event subscriber.
|
||||
func (al *AgentLoop) UnsubscribeEvents(id uint64) {
|
||||
if al == nil || al.eventBus == nil {
|
||||
return
|
||||
// RuntimeEventStats returns runtime event bus counters.
|
||||
func (al *AgentLoop) RuntimeEventStats() runtimeevents.Stats {
|
||||
if al == nil || al.runtimeEvents == nil {
|
||||
return runtimeevents.Stats{Closed: true}
|
||||
}
|
||||
al.eventBus.Unsubscribe(id)
|
||||
return al.runtimeEvents.Stats()
|
||||
}
|
||||
|
||||
// EventDrops returns the number of dropped events for the given kind.
|
||||
func (al *AgentLoop) EventDrops(kind EventKind) int64 {
|
||||
if al == nil || al.eventBus == nil {
|
||||
return 0
|
||||
// RuntimeEventBus returns the runtime event bus used by the agent loop.
|
||||
func (al *AgentLoop) RuntimeEventBus() runtimeevents.Bus {
|
||||
if al == nil {
|
||||
return nil
|
||||
}
|
||||
return al.eventBus.Dropped(kind)
|
||||
return al.runtimeEvents
|
||||
}
|
||||
|
||||
+40
-12
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
@@ -24,6 +25,7 @@ func NewAgentLoop(
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
provider providers.LLMProvider,
|
||||
opts ...AgentLoopOption,
|
||||
) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
@@ -47,8 +49,6 @@ func NewAgentLoop(
|
||||
stateManager = state.NewManager(defaultAgent.Workspace)
|
||||
}
|
||||
|
||||
eventBus := NewEventBus()
|
||||
|
||||
// Determine worker pool size from config (default: 1 = sequential)
|
||||
workerPoolSize := cfg.Agents.Defaults.MaxParallelTurns
|
||||
if workerPoolSize <= 0 {
|
||||
@@ -56,18 +56,28 @@ func NewAgentLoop(
|
||||
}
|
||||
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
eventBus: eventBus,
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
workerSem: make(chan struct{}, workerPoolSize),
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
workerSem: make(chan struct{}, workerPoolSize),
|
||||
ownsRuntimeEvents: true,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(al)
|
||||
}
|
||||
}
|
||||
if al.runtimeEvents == nil {
|
||||
al.runtimeEvents = runtimeevents.NewBus()
|
||||
al.ownsRuntimeEvents = true
|
||||
}
|
||||
al.refreshRuntimeEventLogger(cfg)
|
||||
al.providerFactory = providers.CreateProviderFromConfig
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
al.hooks = NewHookManager(al.runtimeEvents.Channel())
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
al.contextManager = al.resolveContextManager()
|
||||
|
||||
@@ -128,6 +138,9 @@ func registerSharedTools(
|
||||
if cfg.Tools.IsToolEnabled("spi") {
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("serial") {
|
||||
agent.Tools.Register(tools.NewSerialTool())
|
||||
}
|
||||
|
||||
// Message tool
|
||||
if cfg.Tools.IsToolEnabled("message") {
|
||||
@@ -324,5 +337,20 @@ func registerSharedTools(
|
||||
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
|
||||
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
|
||||
}
|
||||
|
||||
// Register delegate tool for multi-agent setups.
|
||||
// Auto-enabled when multiple agents exist. Delegation uses the SubTurn
|
||||
// mechanism directly (not SubagentManager) and is independent of the
|
||||
// subagent tool.
|
||||
if len(registry.ListAgentIDs()) > 1 {
|
||||
delegateTool := tools.NewDelegateTool()
|
||||
delegateTool.SetSpawner(NewSubTurnSpawner(al))
|
||||
currentAgentID := agentID
|
||||
delegateTool.SetSelfAgentID(currentAgentID)
|
||||
delegateTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
|
||||
})
|
||||
agent.Tools.Register(delegateTool)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
al.mcp.initOnce.Do(func() {
|
||||
mcpManager := mcp.NewManager()
|
||||
mcpManager := mcp.NewManager(mcp.WithRuntimeEvents(al.runtimeEvents))
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
workspacePath := al.cfg.WorkspacePath()
|
||||
@@ -164,6 +164,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
mcpTool.SetWorkspace(agent.Workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
|
||||
mcpTool.SetEventPublisher(al.runtimeEvents)
|
||||
|
||||
if registerAsHidden {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
|
||||
+134
-63
@@ -11,6 +11,7 @@ import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
@@ -20,24 +21,59 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// genericPlaceholderRegex matches generic media placeholders emitted by various
|
||||
// channels: [image], [image: photo], [image: filename.jpg] — but NOT path tags
|
||||
// like [image:/path/to/file] (path tags have no space after the colon).
|
||||
var (
|
||||
imagePlaceholderRegex = regexp.MustCompile(`\[image(:\s+[^\]]*)?\]`)
|
||||
audioPlaceholderRegex = regexp.MustCompile(`\[audio(:\s+[^\]]*)?\]`)
|
||||
videoPlaceholderRegex = regexp.MustCompile(`\[video(:\s+[^\]]*)?\]`)
|
||||
filePlaceholderRegex = regexp.MustCompile(`\[file(:\s+[^\]]*)?\]`)
|
||||
)
|
||||
|
||||
// resolveMediaRefs resolves media:// refs in messages.
|
||||
// Images are base64-encoded into the Media array for multimodal LLMs.
|
||||
// Non-image files (documents, audio, video) have their local path injected
|
||||
// into Content so the agent can access them via file tools like read_file.
|
||||
// For user messages: images get path tags only ([image:/path]) so the LLM
|
||||
// can decide whether to view them via load_image or operate on the file.
|
||||
// For tool messages: images are base64-encoded and appended as a synthetic
|
||||
// user message only after the contiguous tool-message block ends, so we don't
|
||||
// break the tool-results-must-immediately-follow-assistant constraint that
|
||||
// LLM APIs enforce.
|
||||
// Non-image files always get path tags regardless of role.
|
||||
// Returns a new slice; original messages are not mutated.
|
||||
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
|
||||
if store == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
result := make([]providers.Message, len(messages))
|
||||
copy(result, messages)
|
||||
result := make([]providers.Message, 0, len(messages))
|
||||
var pendingToolImages []string
|
||||
|
||||
for idx, m := range messages {
|
||||
// When leaving a tool-message block, flush any accumulated images
|
||||
// as a synthetic user message.
|
||||
if m.Role != "tool" && len(pendingToolImages) > 0 {
|
||||
result = append(result, providers.Message{
|
||||
Role: "user",
|
||||
Content: "[Loaded image from tool result above]",
|
||||
Media: pendingToolImages,
|
||||
})
|
||||
pendingToolImages = nil
|
||||
}
|
||||
|
||||
for i, m := range result {
|
||||
if len(m.Media) == 0 {
|
||||
result = append(result, m)
|
||||
if idx == len(messages)-1 && len(pendingToolImages) > 0 {
|
||||
result = append(result, providers.Message{
|
||||
Role: "user",
|
||||
Content: "[Loaded image from tool result above]",
|
||||
Media: pendingToolImages,
|
||||
})
|
||||
pendingToolImages = nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
msg := m
|
||||
resolved := make([]string, 0, len(m.Media))
|
||||
var pathTags []string
|
||||
|
||||
@@ -66,27 +102,77 @@ func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxS
|
||||
}
|
||||
|
||||
mime := detectMIME(localPath, meta)
|
||||
pathTags = append(pathTags, buildPathTag(mime, localPath))
|
||||
|
||||
if strings.HasPrefix(mime, "image/") {
|
||||
if m.Role == "tool" && strings.HasPrefix(mime, "image/") {
|
||||
dataURL := encodeImageToDataURL(localPath, mime, info, maxSize)
|
||||
if dataURL != "" {
|
||||
resolved = append(resolved, dataURL)
|
||||
pendingToolImages = append(pendingToolImages, dataURL)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pathTags = append(pathTags, buildPathTag(mime, localPath))
|
||||
}
|
||||
|
||||
result[i].Media = resolved
|
||||
msg.Media = resolved
|
||||
if len(pathTags) > 0 {
|
||||
result[i].Content = injectPathTags(result[i].Content, pathTags)
|
||||
msg.Content = injectPathTags(msg.Content, pathTags)
|
||||
}
|
||||
result = append(result, msg)
|
||||
|
||||
// If this is the last message and we have pending images, flush them.
|
||||
if idx == len(messages)-1 && len(pendingToolImages) > 0 {
|
||||
result = append(result, providers.Message{
|
||||
Role: "user",
|
||||
Content: "[Loaded image from tool result above]",
|
||||
Media: pendingToolImages,
|
||||
})
|
||||
pendingToolImages = nil
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeImageToDataURL base64-encodes an image file into a data URL.
|
||||
// Returns empty string if the file exceeds maxSize or encoding fails.
|
||||
func encodeImageToDataURL(localPath, mime string, info os.FileInfo, maxSize int) string {
|
||||
if info.Size() > int64(maxSize) {
|
||||
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
"size": info.Size(),
|
||||
"max_size": maxSize,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
f, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to open media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
prefix := "data:" + mime + ";base64,"
|
||||
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(len(prefix) + encodedLen)
|
||||
buf.WriteString(prefix)
|
||||
|
||||
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
encoder.Close()
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func buildArtifactTags(store media.MediaStore, refs []string) []string {
|
||||
if store == nil || len(refs) == 0 {
|
||||
return nil
|
||||
@@ -137,51 +223,12 @@ func detectMIME(localPath string, meta media.MediaMeta) string {
|
||||
return kind.MIME.Value
|
||||
}
|
||||
|
||||
// encodeImageToDataURL base64-encodes an image file into a data URL.
|
||||
// Returns empty string if the file exceeds maxSize or encoding fails.
|
||||
func encodeImageToDataURL(localPath, mime string, info os.FileInfo, maxSize int) string {
|
||||
if info.Size() > int64(maxSize) {
|
||||
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
"size": info.Size(),
|
||||
"max_size": maxSize,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
f, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to open media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
prefix := "data:" + mime + ";base64,"
|
||||
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(len(prefix) + encodedLen)
|
||||
buf.WriteString(prefix)
|
||||
|
||||
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
encoder.Close()
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// buildPathTag creates a structured tag exposing the local file path.
|
||||
// Tag type is derived from MIME: [audio:/path], [video:/path], or [file:/path].
|
||||
// Tag type is derived from MIME: [image:/path], [audio:/path], [video:/path], or [file:/path].
|
||||
func buildPathTag(mime, localPath string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(mime, "image/"):
|
||||
return "[image:" + localPath + "]"
|
||||
case strings.HasPrefix(mime, "audio/"):
|
||||
return "[audio:" + localPath + "]"
|
||||
case strings.HasPrefix(mime, "video/"):
|
||||
@@ -192,22 +239,41 @@ func buildPathTag(mime, localPath string) string {
|
||||
}
|
||||
|
||||
// injectPathTags replaces generic media tags in content with path-bearing versions,
|
||||
// or appends if no matching generic tag is found.
|
||||
// or appends if no matching generic tag is found. Channels emit a few different
|
||||
// placeholder formats — [image], [image: photo], [image: filename.jpg] — so we
|
||||
// match all of them via regex while leaving path tags ([image:/path]) untouched.
|
||||
//
|
||||
// When content is structured data (e.g., JSON from Feishu interactive cards or
|
||||
// post messages), tags are only injected via placeholder replacement — never
|
||||
// appended — to avoid corrupting the payload.
|
||||
func injectPathTags(content string, tags []string) string {
|
||||
isStructured := looksLikeJSON(content)
|
||||
for _, tag := range tags {
|
||||
var generic string
|
||||
var pattern *regexp.Regexp
|
||||
switch {
|
||||
case strings.HasPrefix(tag, "[image:"):
|
||||
pattern = imagePlaceholderRegex
|
||||
case strings.HasPrefix(tag, "[audio:"):
|
||||
generic = "[audio]"
|
||||
pattern = audioPlaceholderRegex
|
||||
case strings.HasPrefix(tag, "[video:"):
|
||||
generic = "[video]"
|
||||
pattern = videoPlaceholderRegex
|
||||
case strings.HasPrefix(tag, "[file:"):
|
||||
generic = "[file]"
|
||||
pattern = filePlaceholderRegex
|
||||
}
|
||||
|
||||
if generic != "" && strings.Contains(content, generic) {
|
||||
content = strings.Replace(content, generic, tag, 1)
|
||||
} else if content == "" {
|
||||
if pattern != nil {
|
||||
if loc := pattern.FindStringIndex(content); loc != nil {
|
||||
content = content[:loc[0]] + tag + content[loc[1]:]
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if isStructured {
|
||||
content = tag + "\n" + content
|
||||
continue
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
content = tag
|
||||
} else {
|
||||
content += " " + tag
|
||||
@@ -215,3 +281,8 @@ func injectPathTags(content string, tags []string) string {
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func looksLikeJSON(s string) bool {
|
||||
s = strings.TrimSpace(s)
|
||||
return len(s) > 1 && s[0] == '{'
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package agent
|
||||
|
||||
import runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
|
||||
// AgentLoopOption configures an AgentLoop at construction time.
|
||||
type AgentLoopOption func(*AgentLoop)
|
||||
|
||||
// WithRuntimeEvents injects the runtime event bus used for new observation APIs.
|
||||
//
|
||||
// The injected bus is treated as externally owned and will not be closed by
|
||||
// AgentLoop.Close. Passing nil leaves the default owned runtime bus enabled.
|
||||
func WithRuntimeEvents(bus runtimeevents.Bus) AgentLoopOption {
|
||||
return func(al *AgentLoop) {
|
||||
if bus == nil {
|
||||
return
|
||||
}
|
||||
al.runtimeEvents = bus
|
||||
al.ownsRuntimeEvents = false
|
||||
}
|
||||
}
|
||||
+31
-15
@@ -44,11 +44,36 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
return
|
||||
}
|
||||
|
||||
// Drain steering queue using existing Continue mechanism
|
||||
continued, continueErr := al.drainQueuedSteeringContinuations(ctx, target)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
})
|
||||
} else if continued != "" {
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
// Publish final response
|
||||
if finalResponse != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) drainQueuedSteeringContinuations(
|
||||
ctx context.Context,
|
||||
target *continuationTarget,
|
||||
) (string, error) {
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
finalResponse := ""
|
||||
for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
|
||||
// Check for context cancellation between iterations
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
if err := ctx.Err(); err != nil {
|
||||
return finalResponse, err
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Continuing queued steering after turn end",
|
||||
@@ -61,13 +86,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
|
||||
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
|
||||
if continueErr != nil {
|
||||
logger.WarnCF("agent", "Failed to continue queued steering",
|
||||
map[string]any{
|
||||
"channel": target.Channel,
|
||||
"chat_id": target.ChatID,
|
||||
"error": continueErr.Error(),
|
||||
})
|
||||
break
|
||||
return finalResponse, continueErr
|
||||
}
|
||||
if continued == "" {
|
||||
break
|
||||
@@ -75,10 +94,7 @@ func (al *AgentLoop) runTurnWithSteering(ctx context.Context, initialMsg bus.Inb
|
||||
finalResponse = continued
|
||||
}
|
||||
|
||||
// Publish final response
|
||||
if finalResponse != "" {
|
||||
al.PublishResponseIfNeeded(ctx, target.Channel, target.ChatID, target.SessionKey, finalResponse)
|
||||
}
|
||||
return finalResponse, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
func (al *AgentLoop) tryHandleStopCommand(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
sessionKey string,
|
||||
) bool {
|
||||
cmdName, ok := commands.CommandName(msg.Content)
|
||||
if !ok || cmdName != "stop" {
|
||||
return false
|
||||
}
|
||||
|
||||
result, err := al.stopActiveTurnForSession(sessionKey)
|
||||
|
||||
// This function is only called when loaded=true (another turn already
|
||||
// claimed this session). If stopActiveTurnForSession found a pending
|
||||
// placeholder but didn't stop it, that placeholder belongs to the other
|
||||
// message's worker which hasn't started yet — arm a pending stop so the
|
||||
// worker will bail when it checks before running.
|
||||
if err == nil && !result.Stopped {
|
||||
if ts := al.getActiveTurnState(sessionKey); ts != nil {
|
||||
snap := ts.snapshot()
|
||||
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
|
||||
al.markPendingStop(sessionKey)
|
||||
result.Stopped = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reply := commands.FormatStopReply(result)
|
||||
if err != nil {
|
||||
reply = "Failed to stop task: " + err.Error()
|
||||
}
|
||||
|
||||
if al.channelManager != nil {
|
||||
al.channelManager.InvokeTypingStop(msg.Channel, msg.ChatID)
|
||||
}
|
||||
al.resetMessageToolRound(sessionKey)
|
||||
al.PublishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, sessionKey, reply)
|
||||
return true
|
||||
}
|
||||
|
||||
func (al *AgentLoop) stopActiveTurnForSession(sessionKey string) (commands.StopResult, error) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return commands.StopResult{}, fmt.Errorf("session key is required")
|
||||
}
|
||||
|
||||
result := commands.StopResult{}
|
||||
cleared := al.clearSteeringMessagesForScope(sessionKey)
|
||||
al.clearPendingSkills(sessionKey)
|
||||
|
||||
ts := al.getActiveTurnState(sessionKey)
|
||||
if ts == nil {
|
||||
result.Stopped = cleared > 0
|
||||
return result, nil
|
||||
}
|
||||
|
||||
snap := ts.snapshot()
|
||||
result.TaskName = snap.UserMessage
|
||||
|
||||
if strings.HasPrefix(snap.TurnID, pendingTurnPrefix) {
|
||||
// A pending placeholder means this session is either idle (our own
|
||||
// placeholder from the /stop command) or another message is queued but
|
||||
// hasn't started yet. In both cases, we don't arm a pending stop here;
|
||||
// the caller (tryHandleStopCommand) handles the "another message queued"
|
||||
// case explicitly, since it knows loaded=true.
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if err := al.HardAbort(sessionKey); err != nil {
|
||||
if al.getActiveTurnState(sessionKey) == nil {
|
||||
result.Stopped = cleared > 0
|
||||
return result, nil
|
||||
}
|
||||
return commands.StopResult{}, err
|
||||
}
|
||||
|
||||
result.Stopped = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) markPendingStop(sessionKey string) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return
|
||||
}
|
||||
al.pendingStops.Store(sessionKey, struct{}{})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) takePendingStop(sessionKey string) bool {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := al.pendingStops.LoadAndDelete(sessionKey)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resetMessageToolRound(sessionKey string) {
|
||||
if strings.TrimSpace(sessionKey) == "" {
|
||||
return
|
||||
}
|
||||
if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound(sessionKey string) }); ok {
|
||||
resetter.ResetSentInRound(sessionKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+238
-30
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -1781,17 +1782,22 @@ func (m *artifactThenSendProvider) Chat(
|
||||
if messages[i].Role != "tool" {
|
||||
continue
|
||||
}
|
||||
start := strings.Index(messages[i].Content, "[file:")
|
||||
if start < 0 {
|
||||
continue
|
||||
for _, prefix := range []string{"[image:", "[file:", "[audio:", "[video:"} {
|
||||
start := strings.Index(messages[i].Content, prefix)
|
||||
if start < 0 {
|
||||
continue
|
||||
}
|
||||
rest := messages[i].Content[start+len(prefix):]
|
||||
end := strings.Index(rest, "]")
|
||||
if end < 0 {
|
||||
continue
|
||||
}
|
||||
artifactPath = rest[:end]
|
||||
break
|
||||
}
|
||||
rest := messages[i].Content[start+len("[file:"):]
|
||||
end := strings.Index(rest, "]")
|
||||
if end < 0 {
|
||||
continue
|
||||
if artifactPath != "" {
|
||||
break
|
||||
}
|
||||
artifactPath = rest[:end]
|
||||
break
|
||||
}
|
||||
if artifactPath == "" {
|
||||
return nil, fmt.Errorf("provider did not receive artifact path in tool result")
|
||||
@@ -4656,7 +4662,7 @@ func TestRun_PicoToolFeedbackSuppressesDuplicateInterimAssistantContent(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
|
||||
func TestResolveMediaRefs_ImageInjectsPathTag(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
@@ -4684,15 +4690,110 @@ func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (images use path tags), got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
|
||||
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
|
||||
localPath, _, _ := store.ResolveWithMeta(ref)
|
||||
expectedContent := "describe this [image:" + localPath + "]"
|
||||
if result[0].Content != expectedContent {
|
||||
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
|
||||
func TestResolveMediaRefs_ToolRoleImageAppendedAsUserMessage(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
pngPath := filepath.Join(dir, "tool-result.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||
0x00, 0x00, 0x00, 0x0D, // IHDR length
|
||||
0x49, 0x48, 0x44, 0x52, // "IHDR"
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
|
||||
0x00, 0x00, 0x00, // no interlace
|
||||
0x90, 0x77, 0x53, 0xDE, // CRC
|
||||
}
|
||||
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "tool", Content: "Image loaded", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
// Tool message should have path tag but no base64
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media in tool message, got %d", len(result[0].Media))
|
||||
}
|
||||
localPath, _, _ := store.ResolveWithMeta(ref)
|
||||
if !strings.Contains(result[0].Content, "[image:"+localPath+"]") {
|
||||
t.Fatalf("expected image path tag in tool content, got %q", result[0].Content)
|
||||
}
|
||||
|
||||
// A synthetic user message with base64 should follow
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 messages (tool + synthetic user), got %d", len(result))
|
||||
}
|
||||
if result[1].Role != "user" {
|
||||
t.Fatalf("expected synthetic message role=user, got %q", result[1].Role)
|
||||
}
|
||||
if len(result[1].Media) != 1 {
|
||||
t.Fatalf("expected 1 base64 media in synthetic user message, got %d", len(result[1].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[1].Media[0], "data:image/png;base64,") {
|
||||
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[1].Media[0][:40])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_MultiToolCallPreservesOrdering(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create image for tool #1
|
||||
pngPath := filepath.Join(dir, "loaded.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
||||
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
||||
}
|
||||
os.WriteFile(pngPath, pngHeader, 0o644)
|
||||
imgRef, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
|
||||
// Simulate: assistant called load_image + read_file, two tool results follow
|
||||
messages := []providers.Message{
|
||||
{Role: "assistant", Content: "Let me load the image and read the file."},
|
||||
{Role: "tool", Content: "Image loaded [image: photo]", Media: []string{imgRef}},
|
||||
{Role: "tool", Content: "file contents here"},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
// assistant, tool#1, tool#2 must remain contiguous — no user in between
|
||||
if result[0].Role != "assistant" {
|
||||
t.Fatalf("result[0] expected assistant, got %q", result[0].Role)
|
||||
}
|
||||
if result[1].Role != "tool" {
|
||||
t.Fatalf("result[1] expected tool, got %q", result[1].Role)
|
||||
}
|
||||
if result[2].Role != "tool" {
|
||||
t.Fatalf("result[2] expected tool, got %q", result[2].Role)
|
||||
}
|
||||
|
||||
// Synthetic user message should come AFTER the tool block
|
||||
if len(result) != 4 {
|
||||
t.Fatalf("expected 4 messages (assistant + 2 tool + synthetic user), got %d", len(result))
|
||||
}
|
||||
if result[3].Role != "user" {
|
||||
t.Fatalf("result[3] expected user, got %q", result[3].Role)
|
||||
}
|
||||
if len(result[3].Media) != 1 || !strings.HasPrefix(result[3].Media[0], "data:image/png;base64,") {
|
||||
t.Fatal("expected synthetic user message to contain base64 image")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_OversizedImageSkipsBase64KeepsPathTag(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
@@ -4714,6 +4815,11 @@ func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
|
||||
}
|
||||
localPath, _, _ := store.ResolveWithMeta(ref)
|
||||
expected := "hi [image:" + localPath + "]"
|
||||
if result[0].Content != expected {
|
||||
t.Fatalf("expected content %q, got %q", expected, result[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_UnknownTypeInjectsPath(t *testing.T) {
|
||||
@@ -4791,11 +4897,13 @@ func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 media, got %d", len(result[0].Media))
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (images use path tags), got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
|
||||
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
|
||||
localPath, _, _ := store.ResolveWithMeta(ref)
|
||||
expectedContent := "hi [image:" + localPath + "]"
|
||||
if result[0].Content != expectedContent {
|
||||
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4885,6 +4993,98 @@ func TestResolveMediaRefs_NoGenericTagAppendsPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectPathTags_HandlesVariousChannelPlaceholders(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
content string
|
||||
tag string
|
||||
want string
|
||||
}{
|
||||
// Telegram / Feishu format
|
||||
{"image_photo", "[image: photo]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
|
||||
// WeCom / WeChat / Line format
|
||||
{"bare_image", "[image]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
|
||||
// QQ / Discord format with filename
|
||||
{"image_filename", "[image: pic.jpg]", "[image:/tmp/p.png]", "[image:/tmp/p.png]"},
|
||||
{"audio_with_filename", "[audio: voice.m4a]", "[audio:/tmp/a.m4a]", "[audio:/tmp/a.m4a]"},
|
||||
{"bare_audio", "[audio]", "[audio:/tmp/a.m4a]", "[audio:/tmp/a.m4a]"},
|
||||
{"bare_video", "[video]", "[video:/tmp/v.mp4]", "[video:/tmp/v.mp4]"},
|
||||
{"bare_file", "[file]", "[file:/tmp/f.pdf]", "[file:/tmp/f.pdf]"},
|
||||
// Mixed surrounding text
|
||||
{
|
||||
"with_text",
|
||||
"hello [image] world",
|
||||
"[image:/tmp/p.png]",
|
||||
"hello [image:/tmp/p.png] world",
|
||||
},
|
||||
// No placeholder — append
|
||||
{"no_placeholder", "hello world", "[image:/tmp/p.png]", "hello world [image:/tmp/p.png]"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := injectPathTags(tc.content, []string{tc.tag})
|
||||
if got != tc.want {
|
||||
t.Errorf("expected %q, got %q", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectPathTags_DoesNotReplacePathTag(t *testing.T) {
|
||||
// If content already contains a path tag, we must not touch it.
|
||||
content := "see [image:/already/placed.png] thanks"
|
||||
got := injectPathTags(content, []string{"[image:/new/path.png]"})
|
||||
want := "see [image:/already/placed.png] thanks [image:/new/path.png]"
|
||||
if got != want {
|
||||
t.Fatalf("expected %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectPathTags_PrependsForJSONContent(t *testing.T) {
|
||||
jsonContent := `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`
|
||||
got := injectPathTags(jsonContent, []string{"[image:/tmp/photo.png]"})
|
||||
want := "[image:/tmp/photo.png]\n" + jsonContent
|
||||
if got != want {
|
||||
t.Fatalf("expected tag prepended to JSON, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectPathTags_BracketTextNotTreatedAsJSON(t *testing.T) {
|
||||
content := "[update] see attached report"
|
||||
got := injectPathTags(content, []string{"[file:/tmp/report.pdf]"})
|
||||
want := "[update] see attached report [file:/tmp/report.pdf]"
|
||||
if got != want {
|
||||
t.Fatalf("expected tag appended to bracket text, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_JSONContentPrependsPathTag(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
pngPath := filepath.Join(dir, "card_img.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
||||
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
||||
}
|
||||
os.WriteFile(pngPath, pngHeader, 0o644)
|
||||
ref, _ := store.Store(pngPath, media.MediaMeta{ContentType: "image/png"}, "test")
|
||||
|
||||
jsonContent := `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: jsonContent, Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
want := "[image:" + pngPath + "]\n" + jsonContent
|
||||
if result[0].Content != want {
|
||||
t.Fatalf("expected path tag prepended to JSON content, got %q", result[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_EmptyContentGetsPathTag(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
@@ -4928,13 +5128,12 @@ func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) {
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 media (image only), got %d", len(result[0].Media))
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (all types use path tags), got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
|
||||
t.Fatal("expected image to be base64 encoded")
|
||||
}
|
||||
expectedContent := "check these [file:" + pdfPath + "]"
|
||||
imgLocalPath, _, _ := store.ResolveWithMeta(imgRef)
|
||||
pdfLocalPath, _, _ := store.ResolveWithMeta(fileRef)
|
||||
expectedContent := "check these [file:" + pdfLocalPath + "] [image:" + imgLocalPath + "]"
|
||||
if result[0].Content != expectedContent {
|
||||
t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content)
|
||||
}
|
||||
@@ -5258,6 +5457,7 @@ func TestParallelMessageProcessing_SameSessionProcessedSequentially(t *testing.T
|
||||
var mu sync.Mutex
|
||||
turnIDs := make(map[string]bool)
|
||||
var wg sync.WaitGroup
|
||||
var firstResponse sync.Once
|
||||
wg.Add(1) // Only 1 turn should be created for same session
|
||||
|
||||
cfg := &config.Config{
|
||||
@@ -5280,19 +5480,27 @@ func TestParallelMessageProcessing_SameSessionProcessedSequentially(t *testing.T
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, &concurrentMockProvider{
|
||||
responseFunc: func(callID int) string {
|
||||
wg.Done()
|
||||
firstResponse.Do(func() {
|
||||
wg.Done()
|
||||
})
|
||||
return "ok"
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
sub := al.SubscribeEvents(64)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
64,
|
||||
runtimeevents.KindAgentTurnStart,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
go func() {
|
||||
for evt := range sub.C {
|
||||
if evt.Kind == EventKindTurnStart {
|
||||
for evt := range runtimeCh {
|
||||
if evt.Kind == runtimeevents.KindAgentTurnStart {
|
||||
mu.Lock()
|
||||
turnIDs[evt.Meta.TurnID] = true
|
||||
turnIDs[evt.Scope.TurnID] = true
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"path/filepath"
|
||||
@@ -171,15 +170,8 @@ func toolFeedbackExplanationFromMessages(messages []providers.Message) string {
|
||||
}
|
||||
|
||||
func toolFeedbackArgsPreview(args map[string]any, maxLen int) string {
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
|
||||
argsJSON, err := json.MarshalIndent(args, "", " ")
|
||||
if err != nil {
|
||||
return utils.Truncate(fmt.Sprintf("%v", args), maxLen)
|
||||
}
|
||||
return utils.Truncate(string(argsJSON), maxLen)
|
||||
argsJSON := utils.FormatArgsJSON(args, true, false)
|
||||
return utils.Truncate(argsJSON, maxLen)
|
||||
}
|
||||
|
||||
func shouldPublishToolFeedback(cfg *config.Config, ts *turnState) bool {
|
||||
@@ -293,6 +285,12 @@ func inferMediaType(filename, contentType string) string {
|
||||
ct := strings.ToLower(contentType)
|
||||
fn := strings.ToLower(filename)
|
||||
|
||||
// SVG is an image MIME type, but raster-only delivery endpoints such as
|
||||
// Telegram SendPhoto reject it. Treat it as a file/document instead.
|
||||
if strings.HasPrefix(ct, "image/svg") || filepath.Ext(fn) == ".svg" {
|
||||
return "file"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(ct, "image/") {
|
||||
return "image"
|
||||
}
|
||||
@@ -306,7 +304,7 @@ func inferMediaType(filename, contentType string) string {
|
||||
// Fallback: infer from extension
|
||||
ext := filepath.Ext(fn)
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg":
|
||||
case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp":
|
||||
return "image"
|
||||
case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus":
|
||||
return "audio"
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestInferMediaType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
contentType string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "png content type",
|
||||
filename: "diagram",
|
||||
contentType: "image/png",
|
||||
want: "image",
|
||||
},
|
||||
{
|
||||
name: "jpeg extension fallback",
|
||||
filename: "photo.JPG",
|
||||
contentType: "",
|
||||
want: "image",
|
||||
},
|
||||
{
|
||||
name: "svg content type is file",
|
||||
filename: "diagram",
|
||||
contentType: "image/svg+xml",
|
||||
want: "file",
|
||||
},
|
||||
{
|
||||
name: "svg content type with parameters is file",
|
||||
filename: "diagram",
|
||||
contentType: "image/svg+xml; charset=utf-8",
|
||||
want: "file",
|
||||
},
|
||||
{
|
||||
name: "svg extension fallback is file",
|
||||
filename: "diagram.SVG",
|
||||
contentType: "",
|
||||
want: "file",
|
||||
},
|
||||
{
|
||||
name: "audio content type",
|
||||
filename: "voice",
|
||||
contentType: "audio/ogg",
|
||||
want: "audio",
|
||||
},
|
||||
{
|
||||
name: "ogg application content type",
|
||||
filename: "voice.ogg",
|
||||
contentType: "application/ogg",
|
||||
want: "audio",
|
||||
},
|
||||
{
|
||||
name: "video extension fallback",
|
||||
filename: "clip.MP4",
|
||||
contentType: "",
|
||||
want: "video",
|
||||
},
|
||||
{
|
||||
name: "unknown type",
|
||||
filename: "archive.bin",
|
||||
contentType: "application/octet-stream",
|
||||
want: "file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := inferMediaType(tt.filename, tt.contentType)
|
||||
if got != tt.want {
|
||||
t.Fatalf("inferMediaType(%q, %q) = %q, want %q", tt.filename, tt.contentType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
@@ -41,7 +42,7 @@ func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) e
|
||||
// Sync emergency compression — budget exceeded.
|
||||
if result, ok := m.forceCompression(req.SessionKey); ok {
|
||||
m.al.emitEvent(
|
||||
EventKindContextCompress,
|
||||
runtimeevents.KindAgentContextCompress,
|
||||
m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
|
||||
ContextCompressPayload{
|
||||
Reason: req.Reason,
|
||||
@@ -246,7 +247,7 @@ func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey
|
||||
agent.Sessions.TruncateHistory(sessionKey, keepCount)
|
||||
agent.Sessions.Save(sessionKey)
|
||||
m.al.emitEvent(
|
||||
EventKindSessionSummarize,
|
||||
runtimeevents.KindAgentSessionSummarize,
|
||||
m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
|
||||
SessionSummarizePayload{
|
||||
SummarizedMessages: len(validMessages),
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -305,8 +306,13 @@ func TestLegacyCompact_Overflow(t *testing.T) {
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-overflow", history)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentContextCompress,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-overflow",
|
||||
@@ -329,8 +335,8 @@ func TestLegacyCompact_Overflow(t *testing.T) {
|
||||
}
|
||||
|
||||
// Event should carry the proactive reason
|
||||
events := collectEventStream(sub.C)
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
compressEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
@@ -361,8 +367,13 @@ func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-proactive", history)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentContextCompress,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-proactive",
|
||||
@@ -372,8 +383,8 @@ func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
compressEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
@@ -483,6 +494,14 @@ func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory("session-threshold", history)
|
||||
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentSessionSummarize,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
err := al.contextManager.Compact(context.Background(), &CompactRequest{
|
||||
SessionKey: "session-threshold",
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
@@ -491,12 +510,8 @@ func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Wait for async summarization to complete via event
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
waitForEvent(t, sub.C, 5*time.Second, func(evt Event) bool {
|
||||
return evt.Kind == EventKindSessionSummarize
|
||||
waitForRuntimeEvent(t, runtimeCh, 5*time.Second, func(evt runtimeevents.Event) bool {
|
||||
return evt.Kind == runtimeevents.KindAgentSessionSummarize
|
||||
})
|
||||
|
||||
newHistory := defaultAgent.Sessions.GetHistory("session-threshold")
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
package agent
|
||||
|
||||
import "time"
|
||||
|
||||
// TurnEndStatus describes the terminal state of a turn.
|
||||
type TurnEndStatus string
|
||||
|
||||
const (
|
||||
// TurnEndStatusCompleted indicates the turn finished normally.
|
||||
TurnEndStatusCompleted TurnEndStatus = "completed"
|
||||
// TurnEndStatusError indicates the turn ended because of an error.
|
||||
TurnEndStatusError TurnEndStatus = "error"
|
||||
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
|
||||
TurnEndStatusAborted TurnEndStatus = "aborted"
|
||||
)
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
type TurnStartPayload struct {
|
||||
UserMessage string
|
||||
MediaCount int
|
||||
}
|
||||
|
||||
// TurnEndPayload describes the completion of a turn.
|
||||
type TurnEndPayload struct {
|
||||
Status TurnEndStatus
|
||||
Iterations int
|
||||
Duration time.Duration
|
||||
FinalContentLen int
|
||||
}
|
||||
|
||||
// LLMRequestPayload describes an outbound LLM request.
|
||||
type LLMRequestPayload struct {
|
||||
Model string
|
||||
MessagesCount int
|
||||
ToolsCount int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
}
|
||||
|
||||
// LLMResponsePayload describes an inbound LLM response.
|
||||
type LLMResponsePayload struct {
|
||||
ContentLen int
|
||||
ToolCalls int
|
||||
HasReasoning bool
|
||||
}
|
||||
|
||||
// LLMDeltaPayload describes a streamed LLM delta.
|
||||
type LLMDeltaPayload struct {
|
||||
ContentDeltaLen int
|
||||
ReasoningDeltaLen int
|
||||
}
|
||||
|
||||
// LLMRetryPayload describes a retry of an LLM request.
|
||||
type LLMRetryPayload struct {
|
||||
Attempt int
|
||||
MaxRetries int
|
||||
Reason string
|
||||
Error string
|
||||
Backoff time.Duration
|
||||
}
|
||||
|
||||
// ContextCompressReason identifies why emergency compression ran.
|
||||
type ContextCompressReason string
|
||||
|
||||
const (
|
||||
// ContextCompressReasonProactive indicates compression before the first LLM call.
|
||||
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
|
||||
// ContextCompressReasonRetry indicates compression during context-error retry handling.
|
||||
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
|
||||
// ContextCompressReasonSummarize indicates post-turn async summarization.
|
||||
ContextCompressReasonSummarize ContextCompressReason = "summarize"
|
||||
)
|
||||
|
||||
// ContextCompressPayload describes a forced history compression.
|
||||
type ContextCompressPayload struct {
|
||||
Reason ContextCompressReason
|
||||
DroppedMessages int
|
||||
RemainingMessages int
|
||||
}
|
||||
|
||||
// SessionSummarizePayload describes a completed async session summarization.
|
||||
type SessionSummarizePayload struct {
|
||||
SummarizedMessages int
|
||||
KeptMessages int
|
||||
SummaryLen int
|
||||
OmittedOversized bool
|
||||
}
|
||||
|
||||
// ToolExecStartPayload describes a tool execution request.
|
||||
type ToolExecStartPayload struct {
|
||||
Tool string
|
||||
Arguments map[string]any
|
||||
}
|
||||
|
||||
// ToolExecEndPayload describes the outcome of a tool execution.
|
||||
type ToolExecEndPayload struct {
|
||||
Tool string
|
||||
Duration time.Duration
|
||||
ForLLMLen int
|
||||
ForUserLen int
|
||||
IsError bool
|
||||
Async bool
|
||||
}
|
||||
|
||||
// ToolExecSkippedPayload describes a skipped tool call.
|
||||
type ToolExecSkippedPayload struct {
|
||||
Tool string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// SteeringInjectedPayload describes steering messages appended before the next LLM call.
|
||||
type SteeringInjectedPayload struct {
|
||||
Count int
|
||||
TotalContentLen int
|
||||
}
|
||||
|
||||
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
|
||||
type FollowUpQueuedPayload struct {
|
||||
SourceTool string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
type InterruptKind string
|
||||
|
||||
const (
|
||||
InterruptKindSteering InterruptKind = "steering"
|
||||
InterruptKindGraceful InterruptKind = "graceful"
|
||||
InterruptKindHard InterruptKind = "hard_abort"
|
||||
)
|
||||
|
||||
// InterruptReceivedPayload describes accepted turn-control input.
|
||||
type InterruptReceivedPayload struct {
|
||||
Kind InterruptKind
|
||||
Role string
|
||||
ContentLen int
|
||||
QueueDepth int
|
||||
HintLen int
|
||||
}
|
||||
|
||||
// SubTurnSpawnPayload describes the creation of a child turn.
|
||||
type SubTurnSpawnPayload struct {
|
||||
AgentID string
|
||||
Label string
|
||||
ParentTurnID string
|
||||
}
|
||||
|
||||
// SubTurnEndPayload describes the completion of a child turn.
|
||||
type SubTurnEndPayload struct {
|
||||
AgentID string
|
||||
Status string
|
||||
}
|
||||
|
||||
// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
|
||||
type SubTurnResultDeliveredPayload struct {
|
||||
TargetChannel string
|
||||
TargetChatID string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
// SubTurnOrphanPayload describes a sub-turn result that could not be delivered.
|
||||
type SubTurnOrphanPayload struct {
|
||||
ParentTurnID string
|
||||
ChildTurnID string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// ErrorPayload describes an execution error inside the agent loop.
|
||||
type ErrorPayload struct {
|
||||
Stage string
|
||||
Message string
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultEventSubscriberBuffer = 16
|
||||
|
||||
// EventSubscription identifies a subscriber channel returned by EventBus.Subscribe.
|
||||
type EventSubscription struct {
|
||||
ID uint64
|
||||
C <-chan Event
|
||||
}
|
||||
|
||||
type eventSubscriber struct {
|
||||
ch chan Event
|
||||
}
|
||||
|
||||
// EventBus is a lightweight multi-subscriber broadcaster for agent-loop events.
|
||||
type EventBus struct {
|
||||
mu sync.RWMutex
|
||||
subs map[uint64]eventSubscriber
|
||||
nextID uint64
|
||||
closed bool
|
||||
dropped [eventKindCount]atomic.Int64
|
||||
}
|
||||
|
||||
// NewEventBus creates a new in-process event broadcaster.
|
||||
func NewEventBus() *EventBus {
|
||||
return &EventBus{
|
||||
subs: make(map[uint64]eventSubscriber),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe registers a new subscriber with the requested channel buffer size.
|
||||
// A non-positive buffer uses the default size.
|
||||
func (b *EventBus) Subscribe(buffer int) EventSubscription {
|
||||
if buffer <= 0 {
|
||||
buffer = defaultEventSubscriberBuffer
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
ch := make(chan Event)
|
||||
close(ch)
|
||||
return EventSubscription{C: ch}
|
||||
}
|
||||
|
||||
b.nextID++
|
||||
id := b.nextID
|
||||
ch := make(chan Event, buffer)
|
||||
b.subs[id] = eventSubscriber{ch: ch}
|
||||
return EventSubscription{ID: id, C: ch}
|
||||
}
|
||||
|
||||
// Unsubscribe removes a subscriber and closes its channel.
|
||||
func (b *EventBus) Unsubscribe(id uint64) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
sub, ok := b.subs[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(b.subs, id)
|
||||
close(sub.ch)
|
||||
}
|
||||
|
||||
// Emit broadcasts an event to all current subscribers without blocking.
|
||||
// When a subscriber channel is full, the event is dropped for that subscriber.
|
||||
func (b *EventBus) Emit(evt Event) {
|
||||
if evt.Time.IsZero() {
|
||||
evt.Time = time.Now()
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range b.subs {
|
||||
select {
|
||||
case sub.ch <- evt:
|
||||
default:
|
||||
if evt.Kind < eventKindCount {
|
||||
b.dropped[evt.Kind].Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dropped returns the number of dropped events for a given kind.
|
||||
func (b *EventBus) Dropped(kind EventKind) int64 {
|
||||
if kind >= eventKindCount {
|
||||
return 0
|
||||
}
|
||||
return b.dropped[kind].Load()
|
||||
}
|
||||
|
||||
// Close closes all subscriber channels and stops future broadcasts.
|
||||
func (b *EventBus) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
b.closed = true
|
||||
for id, sub := range b.subs {
|
||||
close(sub.ch)
|
||||
delete(b.subs, id)
|
||||
}
|
||||
}
|
||||
+155
-132
@@ -9,61 +9,94 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
|
||||
eventBus := NewEventBus()
|
||||
sub := eventBus.Subscribe(1)
|
||||
|
||||
eventBus.Emit(Event{
|
||||
Kind: EventKindTurnStart,
|
||||
Meta: EventMeta{TurnID: "turn-1"},
|
||||
})
|
||||
|
||||
select {
|
||||
case evt := <-sub.C:
|
||||
if evt.Kind != EventKindTurnStart {
|
||||
t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
|
||||
func TestAgentLoop_PublishesRuntimeEvents(t *testing.T) {
|
||||
runtimeBus := runtimeevents.NewBus()
|
||||
al := &AgentLoop{
|
||||
runtimeEvents: runtimeBus,
|
||||
}
|
||||
defer func() {
|
||||
if err := runtimeBus.Close(); err != nil {
|
||||
t.Errorf("runtime bus close failed: %v", err)
|
||||
}
|
||||
if evt.Meta.TurnID != "turn-1" {
|
||||
t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
|
||||
}()
|
||||
|
||||
runtimeSub, runtimeCh, err := al.RuntimeEvents().OfKind(runtimeevents.KindAgentToolExecStart).SubscribeChan(
|
||||
context.Background(),
|
||||
runtimeevents.SubscribeOptions{Name: "runtime", Buffer: 1},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := runtimeSub.Close(); err != nil {
|
||||
t.Errorf("runtime subscription close failed: %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
}()
|
||||
|
||||
al.emitEvent(
|
||||
runtimeevents.KindAgentToolExecStart,
|
||||
HookMeta{
|
||||
AgentID: "main",
|
||||
TurnID: "turn-1",
|
||||
ParentTurnID: "parent-turn",
|
||||
SessionKey: "session-1",
|
||||
Iteration: 2,
|
||||
TracePath: "trace/root",
|
||||
Source: "pipeline_execute",
|
||||
turnContext: &TurnContext{
|
||||
Inbound: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
Account: "default",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "tester",
|
||||
MessageID: "msg-1",
|
||||
TopicID: "topic-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
ToolExecStartPayload{Tool: "mock_custom", Arguments: map[string]any{"task": "ping"}},
|
||||
)
|
||||
|
||||
runtimeEvt := receiveRuntimeEvent(t, runtimeCh)
|
||||
if runtimeEvt.Kind != runtimeevents.KindAgentToolExecStart {
|
||||
t.Fatalf("runtime kind = %q, want %q", runtimeEvt.Kind, runtimeevents.KindAgentToolExecStart)
|
||||
}
|
||||
|
||||
eventBus.Unsubscribe(sub.ID)
|
||||
if _, ok := <-sub.C; ok {
|
||||
t.Fatal("expected subscriber channel to be closed after unsubscribe")
|
||||
if runtimeEvt.Source != (runtimeevents.Source{Component: "agent", Name: "main"}) {
|
||||
t.Fatalf("runtime source = %+v", runtimeEvt.Source)
|
||||
}
|
||||
|
||||
eventBus.Close()
|
||||
closedSub := eventBus.Subscribe(1)
|
||||
if _, ok := <-closedSub.C; ok {
|
||||
t.Fatal("expected closed bus to return a closed subscriber channel")
|
||||
if runtimeEvt.Scope.AgentID != "main" ||
|
||||
runtimeEvt.Scope.SessionKey != "session-1" ||
|
||||
runtimeEvt.Scope.TurnID != "turn-1" ||
|
||||
runtimeEvt.Scope.Channel != "cli" ||
|
||||
runtimeEvt.Scope.Account != "default" ||
|
||||
runtimeEvt.Scope.ChatID != "direct" ||
|
||||
runtimeEvt.Scope.TopicID != "topic-1" ||
|
||||
runtimeEvt.Scope.ChatType != "direct" ||
|
||||
runtimeEvt.Scope.SenderID != "tester" ||
|
||||
runtimeEvt.Scope.MessageID != "msg-1" {
|
||||
t.Fatalf("runtime scope = %+v", runtimeEvt.Scope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
|
||||
eventBus := NewEventBus()
|
||||
sub := eventBus.Subscribe(1)
|
||||
defer eventBus.Unsubscribe(sub.ID)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
eventBus.Emit(Event{Kind: EventKindLLMRequest})
|
||||
if runtimeEvt.Correlation.TraceID != "trace/root" ||
|
||||
runtimeEvt.Correlation.ParentTurnID != "parent-turn" {
|
||||
t.Fatalf("runtime correlation = %+v", runtimeEvt.Correlation)
|
||||
}
|
||||
|
||||
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
|
||||
t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
|
||||
if runtimeEvt.Attrs["agent_source"] != "pipeline_execute" || runtimeEvt.Attrs["iteration"] != 2 {
|
||||
t.Fatalf("runtime attrs = %+v", runtimeEvt.Attrs)
|
||||
}
|
||||
|
||||
if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
|
||||
t.Fatalf("expected 999 dropped events, got %d", got)
|
||||
payload, ok := runtimeEvt.Payload.(ToolExecStartPayload)
|
||||
if !ok {
|
||||
t.Fatalf("runtime payload = %T, want ToolExecStartPayload", runtimeEvt.Payload)
|
||||
}
|
||||
if payload.Tool != "mock_custom" {
|
||||
t.Fatalf("runtime payload tool = %q, want mock_custom", payload.Tool)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,8 +160,18 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
expectedKinds := []runtimeevents.Kind{
|
||||
runtimeevents.KindAgentTurnStart,
|
||||
runtimeevents.KindAgentLLMRequest,
|
||||
runtimeevents.KindAgentLLMResponse,
|
||||
runtimeevents.KindAgentToolExecStart,
|
||||
runtimeevents.KindAgentToolExecEnd,
|
||||
runtimeevents.KindAgentLLMRequest,
|
||||
runtimeevents.KindAgentLLMResponse,
|
||||
runtimeevents.KindAgentTurnEnd,
|
||||
}
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(t, al, 16, expectedKinds...)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
@@ -171,49 +214,36 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
||||
t.Fatalf("expected final response 'done', got %q", response)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
if len(events) != 8 {
|
||||
t.Fatalf("expected 8 events, got %d", len(events))
|
||||
}
|
||||
|
||||
kinds := make([]EventKind, 0, len(events))
|
||||
kinds := make([]runtimeevents.Kind, 0, len(events))
|
||||
for _, evt := range events {
|
||||
kinds = append(kinds, evt.Kind)
|
||||
}
|
||||
|
||||
expectedKinds := []EventKind{
|
||||
EventKindTurnStart,
|
||||
EventKindLLMRequest,
|
||||
EventKindLLMResponse,
|
||||
EventKindToolExecStart,
|
||||
EventKindToolExecEnd,
|
||||
EventKindLLMRequest,
|
||||
EventKindLLMResponse,
|
||||
EventKindTurnEnd,
|
||||
}
|
||||
if !slices.Equal(kinds, expectedKinds) {
|
||||
t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
|
||||
}
|
||||
|
||||
turnID := events[0].Meta.TurnID
|
||||
turnID := events[0].Scope.TurnID
|
||||
if turnID == "" {
|
||||
t.Fatal("expected runtime events to include turn id")
|
||||
}
|
||||
for i, evt := range events {
|
||||
if evt.Meta.TurnID != turnID {
|
||||
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
|
||||
if evt.Scope.TurnID != turnID {
|
||||
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Scope.TurnID, turnID)
|
||||
}
|
||||
if evt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
|
||||
if evt.Scope.SessionKey != "session-1" {
|
||||
t.Fatalf("event %d has session key %q, want session-1", i, evt.Scope.SessionKey)
|
||||
}
|
||||
if evt.Context == nil || evt.Context.Inbound == nil {
|
||||
t.Fatalf("event %d missing inbound turn context", i)
|
||||
if evt.Scope.Channel != "cli" || evt.Scope.ChatID != "direct" || evt.Scope.SenderID != "tester" {
|
||||
t.Fatalf("event %d scope = %+v", i, evt.Scope)
|
||||
}
|
||||
if evt.Context.Inbound.Channel != "cli" || evt.Context.Inbound.SenderID != "tester" {
|
||||
t.Fatalf("event %d inbound context = %+v", i, evt.Context.Inbound)
|
||||
}
|
||||
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
|
||||
t.Fatalf("event %d missing route context: %+v", i, evt.Context.Route)
|
||||
}
|
||||
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "tester" {
|
||||
t.Fatalf("event %d missing session scope: %+v", i, evt.Context.Scope)
|
||||
if evt.Scope.AgentID != "main" {
|
||||
t.Fatalf("event %d has agent id %q, want main", i, evt.Scope.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,8 +339,15 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
32,
|
||||
runtimeevents.KindAgentSteeringInjected,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
runtimeevents.KindAgentInterruptReceived,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
resultCh := make(chan string, 1)
|
||||
go func() {
|
||||
@@ -337,8 +374,8 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
t.Fatal("timeout waiting for steered response")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
steeringEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentSteeringInjected)
|
||||
if !ok {
|
||||
t.Fatal("expected steering injected event")
|
||||
}
|
||||
@@ -350,7 +387,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
|
||||
}
|
||||
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
skippedEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected skipped tool event")
|
||||
}
|
||||
@@ -362,7 +399,7 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
|
||||
}
|
||||
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
interruptEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
@@ -420,8 +457,14 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentLLMRetry,
|
||||
runtimeevents.KindAgentContextCompress,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
@@ -439,8 +482,8 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
|
||||
t.Fatalf("expected retry success, got %q", resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
retryEvt, ok := findEvent(events, EventKindLLMRetry)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
retryEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentLLMRetry)
|
||||
if !ok {
|
||||
t.Fatal("expected llm retry event")
|
||||
}
|
||||
@@ -455,7 +498,7 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
|
||||
t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
|
||||
}
|
||||
|
||||
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
||||
compressEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentContextCompress)
|
||||
if !ok {
|
||||
t.Fatal("expected context compress event")
|
||||
}
|
||||
@@ -508,14 +551,19 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
||||
{Role: "assistant", Content: "Answer three"},
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentSessionSummarize,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
lcm := &legacyContextManager{al: al}
|
||||
lcm.summarizeSession(defaultAgent, "session-1")
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
summaryEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentSessionSummarize)
|
||||
if !ok {
|
||||
t.Fatal("expected session summarize event")
|
||||
}
|
||||
@@ -575,8 +623,13 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
32,
|
||||
runtimeevents.KindAgentFollowUpQueued,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
@@ -600,8 +653,8 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
t.Fatal("timeout waiting for async tool completion")
|
||||
}
|
||||
|
||||
followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
|
||||
return evt.Kind == EventKindFollowUpQueued
|
||||
followUpEvt := waitForRuntimeEvent(t, runtimeCh, 2*time.Second, func(evt runtimeevents.Event) bool {
|
||||
return evt.Kind == runtimeevents.KindAgentFollowUpQueued
|
||||
})
|
||||
payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
|
||||
if !ok {
|
||||
@@ -613,59 +666,29 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
if payload.ContentLen != len("background result") {
|
||||
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
|
||||
}
|
||||
if followUpEvt.Meta.SessionKey != "session-1" {
|
||||
t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
|
||||
if followUpEvt.Scope.SessionKey != "session-1" {
|
||||
t.Fatalf("expected session key session-1, got %q", followUpEvt.Scope.SessionKey)
|
||||
}
|
||||
if followUpEvt.Meta.TurnID == "" {
|
||||
if followUpEvt.Scope.TurnID == "" {
|
||||
t.Fatal("expected follow-up event to include turn id")
|
||||
}
|
||||
}
|
||||
|
||||
func collectEventStream(ch <-chan Event) []Event {
|
||||
var events []Event
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
return events
|
||||
}
|
||||
events = append(events, evt)
|
||||
default:
|
||||
return events
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
|
||||
func receiveRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("event stream closed before expected event arrived")
|
||||
}
|
||||
if match(evt) {
|
||||
return evt
|
||||
}
|
||||
case <-timer.C:
|
||||
t.Fatal("timed out waiting for expected event")
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event stream closed before expected event arrived")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func findEvent(events []Event, kind EventKind) (Event, bool) {
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
return evt, true
|
||||
}
|
||||
}
|
||||
return Event{}, false
|
||||
}
|
||||
|
||||
type stringError string
|
||||
|
||||
func (e stringError) Error() string {
|
||||
|
||||
+3
-260
@@ -1,97 +1,8 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EventKind identifies a structured agent-loop event.
|
||||
type EventKind uint8
|
||||
|
||||
const (
|
||||
// EventKindTurnStart is emitted when a turn begins processing.
|
||||
EventKindTurnStart EventKind = iota
|
||||
// EventKindTurnEnd is emitted when a turn finishes, successfully or with an error.
|
||||
EventKindTurnEnd
|
||||
// EventKindLLMRequest is emitted before a provider chat request is made.
|
||||
EventKindLLMRequest
|
||||
// EventKindLLMDelta is emitted when a streaming provider yields a partial delta.
|
||||
EventKindLLMDelta
|
||||
// EventKindLLMResponse is emitted after a provider chat response is received.
|
||||
EventKindLLMResponse
|
||||
// EventKindLLMRetry is emitted when an LLM request is retried.
|
||||
EventKindLLMRetry
|
||||
// EventKindContextCompress is emitted when session history is forcibly compressed.
|
||||
EventKindContextCompress
|
||||
// EventKindSessionSummarize is emitted when asynchronous summarization completes.
|
||||
EventKindSessionSummarize
|
||||
// EventKindToolExecStart is emitted immediately before a tool executes.
|
||||
EventKindToolExecStart
|
||||
// EventKindToolExecEnd is emitted immediately after a tool finishes executing.
|
||||
EventKindToolExecEnd
|
||||
// EventKindToolExecSkipped is emitted when a queued tool call is skipped.
|
||||
EventKindToolExecSkipped
|
||||
// EventKindSteeringInjected is emitted when queued steering is injected into context.
|
||||
EventKindSteeringInjected
|
||||
// EventKindFollowUpQueued is emitted when an async tool queues a follow-up system message.
|
||||
EventKindFollowUpQueued
|
||||
// EventKindInterruptReceived is emitted when a soft interrupt message is accepted.
|
||||
EventKindInterruptReceived
|
||||
// EventKindSubTurnSpawn is emitted when a sub-turn is spawned.
|
||||
EventKindSubTurnSpawn
|
||||
// EventKindSubTurnEnd is emitted when a sub-turn finishes.
|
||||
EventKindSubTurnEnd
|
||||
// EventKindSubTurnResultDelivered is emitted when a sub-turn result is delivered.
|
||||
EventKindSubTurnResultDelivered
|
||||
// EventKindSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
|
||||
EventKindSubTurnOrphan
|
||||
// EventKindError is emitted when a turn encounters an execution error.
|
||||
EventKindError
|
||||
|
||||
eventKindCount
|
||||
)
|
||||
|
||||
var eventKindNames = [...]string{
|
||||
"turn_start",
|
||||
"turn_end",
|
||||
"llm_request",
|
||||
"llm_delta",
|
||||
"llm_response",
|
||||
"llm_retry",
|
||||
"context_compress",
|
||||
"session_summarize",
|
||||
"tool_exec_start",
|
||||
"tool_exec_end",
|
||||
"tool_exec_skipped",
|
||||
"steering_injected",
|
||||
"follow_up_queued",
|
||||
"interrupt_received",
|
||||
"subturn_spawn",
|
||||
"subturn_end",
|
||||
"subturn_result_delivered",
|
||||
"subturn_orphan",
|
||||
"error",
|
||||
}
|
||||
|
||||
// String returns the stable string form of an EventKind.
|
||||
func (k EventKind) String() string {
|
||||
if k >= eventKindCount {
|
||||
return fmt.Sprintf("event_kind(%d)", k)
|
||||
}
|
||||
return eventKindNames[k]
|
||||
}
|
||||
|
||||
// Event is the structured envelope broadcast by the agent EventBus.
|
||||
type Event struct {
|
||||
Kind EventKind
|
||||
Time time.Time
|
||||
Meta EventMeta
|
||||
Context *TurnContext
|
||||
Payload any
|
||||
}
|
||||
|
||||
// EventMeta contains correlation fields shared by all agent-loop events.
|
||||
type EventMeta struct {
|
||||
// HookMeta contains correlation fields shared by agent hook requests and
|
||||
// runtime events emitted from turn processing.
|
||||
type HookMeta struct {
|
||||
AgentID string
|
||||
TurnID string
|
||||
ParentTurnID string
|
||||
@@ -101,171 +12,3 @@ type EventMeta struct {
|
||||
Source string
|
||||
turnContext *TurnContext
|
||||
}
|
||||
|
||||
// TurnEndStatus describes the terminal state of a turn.
|
||||
type TurnEndStatus string
|
||||
|
||||
const (
|
||||
// TurnEndStatusCompleted indicates the turn finished normally.
|
||||
TurnEndStatusCompleted TurnEndStatus = "completed"
|
||||
// TurnEndStatusError indicates the turn ended because of an error.
|
||||
TurnEndStatusError TurnEndStatus = "error"
|
||||
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
|
||||
TurnEndStatusAborted TurnEndStatus = "aborted"
|
||||
)
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
type TurnStartPayload struct {
|
||||
UserMessage string
|
||||
MediaCount int
|
||||
}
|
||||
|
||||
// TurnEndPayload describes the completion of a turn.
|
||||
type TurnEndPayload struct {
|
||||
Status TurnEndStatus
|
||||
Iterations int
|
||||
Duration time.Duration
|
||||
FinalContentLen int
|
||||
}
|
||||
|
||||
// LLMRequestPayload describes an outbound LLM request.
|
||||
type LLMRequestPayload struct {
|
||||
Model string
|
||||
MessagesCount int
|
||||
ToolsCount int
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
}
|
||||
|
||||
// LLMResponsePayload describes an inbound LLM response.
|
||||
type LLMResponsePayload struct {
|
||||
ContentLen int
|
||||
ToolCalls int
|
||||
HasReasoning bool
|
||||
}
|
||||
|
||||
// LLMDeltaPayload describes a streamed LLM delta.
|
||||
type LLMDeltaPayload struct {
|
||||
ContentDeltaLen int
|
||||
ReasoningDeltaLen int
|
||||
}
|
||||
|
||||
// LLMRetryPayload describes a retry of an LLM request.
|
||||
type LLMRetryPayload struct {
|
||||
Attempt int
|
||||
MaxRetries int
|
||||
Reason string
|
||||
Error string
|
||||
Backoff time.Duration
|
||||
}
|
||||
|
||||
// ContextCompressReason identifies why emergency compression ran.
|
||||
type ContextCompressReason string
|
||||
|
||||
const (
|
||||
// ContextCompressReasonProactive indicates compression before the first LLM call.
|
||||
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
|
||||
// ContextCompressReasonRetry indicates compression during context-error retry handling.
|
||||
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
|
||||
// ContextCompressReasonSummarize indicates post-turn async summarization.
|
||||
ContextCompressReasonSummarize ContextCompressReason = "summarize"
|
||||
)
|
||||
|
||||
// ContextCompressPayload describes a forced history compression.
|
||||
type ContextCompressPayload struct {
|
||||
Reason ContextCompressReason
|
||||
DroppedMessages int
|
||||
RemainingMessages int
|
||||
}
|
||||
|
||||
// SessionSummarizePayload describes a completed async session summarization.
|
||||
type SessionSummarizePayload struct {
|
||||
SummarizedMessages int
|
||||
KeptMessages int
|
||||
SummaryLen int
|
||||
OmittedOversized bool
|
||||
}
|
||||
|
||||
// ToolExecStartPayload describes a tool execution request.
|
||||
type ToolExecStartPayload struct {
|
||||
Tool string
|
||||
Arguments map[string]any
|
||||
}
|
||||
|
||||
// ToolExecEndPayload describes the outcome of a tool execution.
|
||||
type ToolExecEndPayload struct {
|
||||
Tool string
|
||||
Duration time.Duration
|
||||
ForLLMLen int
|
||||
ForUserLen int
|
||||
IsError bool
|
||||
Async bool
|
||||
}
|
||||
|
||||
// ToolExecSkippedPayload describes a skipped tool call.
|
||||
type ToolExecSkippedPayload struct {
|
||||
Tool string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// SteeringInjectedPayload describes steering messages appended before the next LLM call.
|
||||
type SteeringInjectedPayload struct {
|
||||
Count int
|
||||
TotalContentLen int
|
||||
}
|
||||
|
||||
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
|
||||
type FollowUpQueuedPayload struct {
|
||||
SourceTool string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
type InterruptKind string
|
||||
|
||||
const (
|
||||
InterruptKindSteering InterruptKind = "steering"
|
||||
InterruptKindGraceful InterruptKind = "graceful"
|
||||
InterruptKindHard InterruptKind = "hard_abort"
|
||||
)
|
||||
|
||||
// InterruptReceivedPayload describes accepted turn-control input.
|
||||
type InterruptReceivedPayload struct {
|
||||
Kind InterruptKind
|
||||
Role string
|
||||
ContentLen int
|
||||
QueueDepth int
|
||||
HintLen int
|
||||
}
|
||||
|
||||
// SubTurnSpawnPayload describes the creation of a child turn.
|
||||
type SubTurnSpawnPayload struct {
|
||||
AgentID string
|
||||
Label string
|
||||
ParentTurnID string
|
||||
}
|
||||
|
||||
// SubTurnEndPayload describes the completion of a child turn.
|
||||
type SubTurnEndPayload struct {
|
||||
AgentID string
|
||||
Status string
|
||||
}
|
||||
|
||||
// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
|
||||
type SubTurnResultDeliveredPayload struct {
|
||||
TargetChannel string
|
||||
TargetChatID string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
// SubTurnOrphanPayload describes a sub-turn result that could not be delivered.
|
||||
type SubTurnOrphanPayload struct {
|
||||
ParentTurnID string
|
||||
ChildTurnID string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// ErrorPayload describes an execution error inside the agent loop.
|
||||
type ErrorPayload struct {
|
||||
Stage string
|
||||
Message string
|
||||
}
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package agent
|
||||
|
||||
import runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
|
||||
func (al *AgentLoop) publishRuntimeEvent(evt runtimeevents.Event) {
|
||||
if al == nil || al.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
|
||||
al.runtimeEvents.PublishNonBlocking(evt)
|
||||
}
|
||||
|
||||
func runtimeScopeFromHookMeta(meta HookMeta, eventCtx *TurnContext) runtimeevents.Scope {
|
||||
scope := runtimeevents.Scope{
|
||||
AgentID: meta.AgentID,
|
||||
SessionKey: meta.SessionKey,
|
||||
TurnID: meta.TurnID,
|
||||
}
|
||||
|
||||
if eventCtx == nil || eventCtx.Inbound == nil {
|
||||
return scope
|
||||
}
|
||||
|
||||
inbound := eventCtx.Inbound
|
||||
scope.Channel = inbound.Channel
|
||||
scope.Account = inbound.Account
|
||||
scope.ChatID = inbound.ChatID
|
||||
scope.TopicID = inbound.TopicID
|
||||
scope.SpaceID = inbound.SpaceID
|
||||
scope.SpaceType = inbound.SpaceType
|
||||
scope.ChatType = inbound.ChatType
|
||||
scope.SenderID = inbound.SenderID
|
||||
scope.MessageID = inbound.MessageID
|
||||
return scope
|
||||
}
|
||||
|
||||
func runtimeCorrelationFromHookMeta(meta HookMeta) runtimeevents.Correlation {
|
||||
return runtimeevents.Correlation{
|
||||
TraceID: meta.TracePath,
|
||||
ParentTurnID: meta.ParentTurnID,
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeSeverityForAgentEvent(kind runtimeevents.Kind, payload any) runtimeevents.Severity {
|
||||
switch kind {
|
||||
case runtimeevents.KindAgentError, runtimeevents.KindAgentSubTurnOrphan:
|
||||
return runtimeevents.SeverityError
|
||||
case runtimeevents.KindAgentLLMRetry,
|
||||
runtimeevents.KindAgentContextCompress,
|
||||
runtimeevents.KindAgentToolExecSkipped:
|
||||
return runtimeevents.SeverityWarn
|
||||
case runtimeevents.KindAgentTurnEnd:
|
||||
payload, ok := payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
return runtimeevents.SeverityInfo
|
||||
}
|
||||
switch payload.Status {
|
||||
case TurnEndStatusError:
|
||||
return runtimeevents.SeverityError
|
||||
case TurnEndStatusAborted:
|
||||
return runtimeevents.SeverityWarn
|
||||
default:
|
||||
return runtimeevents.SeverityInfo
|
||||
}
|
||||
case runtimeevents.KindAgentToolExecEnd:
|
||||
payload, ok := payload.(ToolExecEndPayload)
|
||||
if ok && payload.IsError {
|
||||
return runtimeevents.SeverityWarn
|
||||
}
|
||||
return runtimeevents.SeverityInfo
|
||||
default:
|
||||
return runtimeevents.SeverityInfo
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeAttrsFromHookMeta(meta HookMeta) map[string]any {
|
||||
attrs := make(map[string]any, 2)
|
||||
if meta.Source != "" {
|
||||
attrs["agent_source"] = meta.Source
|
||||
}
|
||||
if meta.Iteration != 0 {
|
||||
attrs["iteration"] = meta.Iteration
|
||||
}
|
||||
if len(attrs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
+28
-6
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
type hookRuntime struct {
|
||||
@@ -295,10 +296,11 @@ func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error)
|
||||
case "", "*", "all":
|
||||
return nil, true, nil
|
||||
default:
|
||||
if _, ok := validKinds[kind]; !ok {
|
||||
normalizedKind, ok := validKinds[kind]
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("unsupported observe event %q", kind)
|
||||
}
|
||||
normalized = append(normalized, kind)
|
||||
normalized = append(normalized, normalizedKind)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,10 +310,30 @@ func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error)
|
||||
return normalized, true, nil
|
||||
}
|
||||
|
||||
func validHookEventKinds() map[string]struct{} {
|
||||
kinds := make(map[string]struct{}, int(eventKindCount))
|
||||
for kind := EventKind(0); kind < eventKindCount; kind++ {
|
||||
kinds[kind.String()] = struct{}{}
|
||||
func validHookEventKinds() map[string]string {
|
||||
runtimeKinds := runtimeevents.KnownKinds()
|
||||
kinds := make(map[string]string, len(runtimeKinds)*2)
|
||||
for _, kind := range runtimeKinds {
|
||||
kinds[kind.String()] = kind.String()
|
||||
}
|
||||
kinds["turn_start"] = runtimeevents.KindAgentTurnStart.String()
|
||||
kinds["turn_end"] = runtimeevents.KindAgentTurnEnd.String()
|
||||
kinds["llm_request"] = runtimeevents.KindAgentLLMRequest.String()
|
||||
kinds["llm_delta"] = runtimeevents.KindAgentLLMDelta.String()
|
||||
kinds["llm_response"] = runtimeevents.KindAgentLLMResponse.String()
|
||||
kinds["llm_retry"] = runtimeevents.KindAgentLLMRetry.String()
|
||||
kinds["context_compress"] = runtimeevents.KindAgentContextCompress.String()
|
||||
kinds["session_summarize"] = runtimeevents.KindAgentSessionSummarize.String()
|
||||
kinds["tool_exec_start"] = runtimeevents.KindAgentToolExecStart.String()
|
||||
kinds["tool_exec_end"] = runtimeevents.KindAgentToolExecEnd.String()
|
||||
kinds["tool_exec_skipped"] = runtimeevents.KindAgentToolExecSkipped.String()
|
||||
kinds["steering_injected"] = runtimeevents.KindAgentSteeringInjected.String()
|
||||
kinds["follow_up_queued"] = runtimeevents.KindAgentFollowUpQueued.String()
|
||||
kinds["interrupt_received"] = runtimeevents.KindAgentInterruptReceived.String()
|
||||
kinds["subturn_spawn"] = runtimeevents.KindAgentSubTurnSpawn.String()
|
||||
kinds["subturn_end"] = runtimeevents.KindAgentSubTurnEnd.String()
|
||||
kinds["subturn_result_delivered"] = runtimeevents.KindAgentSubTurnResultDelivered.String()
|
||||
kinds["subturn_orphan"] = runtimeevents.KindAgentSubTurnOrphan.String()
|
||||
kinds["error"] = runtimeevents.KindAgentError.String()
|
||||
return kinds
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -155,7 +156,27 @@ func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T)
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
waitForFileContains(t, eventLog, "agent.turn.end")
|
||||
}
|
||||
|
||||
func TestProcessHookObserveKindsFromConfigAcceptsRuntimeNames(t *testing.T) {
|
||||
kinds, enabled, err := processHookObserveKindsFromConfig([]string{
|
||||
"tool_exec_start",
|
||||
"agent.tool.exec_end",
|
||||
"gateway.ready",
|
||||
"mcp.server.failed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processHookObserveKindsFromConfig failed: %v", err)
|
||||
}
|
||||
if !enabled {
|
||||
t.Fatal("expected observe to be enabled")
|
||||
}
|
||||
|
||||
want := []string{"agent.tool.exec_start", "agent.tool.exec_end", "gateway.ready", "mcp.server.failed"}
|
||||
if !slices.Equal(kinds, want) {
|
||||
t.Fatalf("observe kinds = %v, want %v", kinds, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -183,7 +184,7 @@ func (ph *ProcessHook) Close() error {
|
||||
return ph.closeErr
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
func (ph *ProcessHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
|
||||
if ph == nil || !ph.opts.Observe {
|
||||
return nil
|
||||
}
|
||||
@@ -192,7 +193,7 @@ func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ph.notify(ctx, "hook.event", evt)
|
||||
return ph.notify(ctx, "hook.runtime_event", evt)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeLLM(
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
@@ -66,7 +67,7 @@ func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
waitForFileContains(t, eventLog, "agent.turn.end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
|
||||
@@ -146,8 +147,13 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
@@ -167,8 +173,8 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
skippedEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
@@ -350,12 +356,11 @@ func runProcessHookHelper() error {
|
||||
}
|
||||
|
||||
if msg.ID == 0 {
|
||||
if msg.Method == "hook.event" && eventLog != "" {
|
||||
if msg.Method == "hook.runtime_event" && eventLog != "" {
|
||||
var evt map[string]any
|
||||
if err := json.Unmarshal(msg.Params, &evt); err == nil {
|
||||
if rawKind, ok := evt["Kind"].(float64); ok {
|
||||
kind := EventKind(rawKind)
|
||||
_ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
|
||||
if kind, ok := evt["kind"].(string); ok {
|
||||
_ = os.WriteFile(eventLog, []byte(kind+"\n"), 0o644)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+56
-36
@@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -71,8 +72,8 @@ func NamedHook(name string, hook any) HookRegistration {
|
||||
}
|
||||
}
|
||||
|
||||
type EventObserver interface {
|
||||
OnEvent(ctx context.Context, evt Event) error
|
||||
type RuntimeEventObserver interface {
|
||||
OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error
|
||||
}
|
||||
|
||||
type LLMInterceptor interface {
|
||||
@@ -90,7 +91,7 @@ type ToolApprover interface {
|
||||
}
|
||||
|
||||
type LLMHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Meta HookMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []providers.Message `json:"messages,omitempty"`
|
||||
@@ -104,7 +105,7 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Meta = cloneHookMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Messages = cloneProviderMessages(r.Messages)
|
||||
cloned.Tools = cloneToolDefinitions(r.Tools)
|
||||
@@ -113,7 +114,7 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
}
|
||||
|
||||
type LLMHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Meta HookMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Response *providers.LLMResponse `json:"response,omitempty"`
|
||||
@@ -124,14 +125,14 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Meta = cloneHookMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Response = cloneLLMResponse(r.Response)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Meta HookMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
@@ -145,7 +146,7 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Meta = cloneHookMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.HookResult = cloneToolResult(r.HookResult)
|
||||
@@ -153,7 +154,7 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
}
|
||||
|
||||
type ToolApprovalRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Meta HookMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
@@ -164,14 +165,14 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Meta = cloneHookMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolResultHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Meta HookMeta `json:"meta"`
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
@@ -184,7 +185,7 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Meta = cloneEventMeta(r.Meta)
|
||||
cloned.Meta = cloneHookMeta(r.Meta)
|
||||
cloned.Context = cloneTurnContext(r.Context)
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.Result = cloneToolResult(r.Result)
|
||||
@@ -192,7 +193,7 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
}
|
||||
|
||||
type HookManager struct {
|
||||
eventBus *EventBus
|
||||
runtimeEvents runtimeevents.EventChannel
|
||||
observerTimeout time.Duration
|
||||
interceptorTimeout time.Duration
|
||||
approvalTimeout time.Duration
|
||||
@@ -201,28 +202,39 @@ type HookManager struct {
|
||||
hooks map[string]HookRegistration
|
||||
ordered []HookRegistration
|
||||
|
||||
sub EventSubscription
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
runtimeSub runtimeevents.Subscription
|
||||
runtimeDone chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHookManager(eventBus *EventBus) *HookManager {
|
||||
func NewHookManager(runtimeEvents runtimeevents.EventChannel) *HookManager {
|
||||
hm := &HookManager{
|
||||
eventBus: eventBus,
|
||||
runtimeEvents: runtimeEvents,
|
||||
observerTimeout: defaultHookObserverTimeout,
|
||||
interceptorTimeout: defaultHookInterceptorTimeout,
|
||||
approvalTimeout: defaultHookApprovalTimeout,
|
||||
hooks: make(map[string]HookRegistration),
|
||||
done: make(chan struct{}),
|
||||
runtimeDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
if eventBus == nil {
|
||||
close(hm.done)
|
||||
return hm
|
||||
if runtimeEvents != nil {
|
||||
sub, ch, err := runtimeEvents.SubscribeChan(context.Background(), runtimeevents.SubscribeOptions{
|
||||
Name: "hook-manager-observer",
|
||||
Buffer: hookObserverBufferSize,
|
||||
})
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Failed to subscribe runtime events for hooks", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
close(hm.runtimeDone)
|
||||
} else {
|
||||
hm.runtimeSub = sub
|
||||
go hm.dispatchRuntimeEvents(ch)
|
||||
}
|
||||
} else {
|
||||
close(hm.runtimeDone)
|
||||
}
|
||||
|
||||
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
|
||||
go hm.dispatchEvents()
|
||||
return hm
|
||||
}
|
||||
|
||||
@@ -232,10 +244,14 @@ func (hm *HookManager) Close() {
|
||||
}
|
||||
|
||||
hm.closeOnce.Do(func() {
|
||||
if hm.eventBus != nil {
|
||||
hm.eventBus.Unsubscribe(hm.sub.ID)
|
||||
if hm.runtimeSub != nil {
|
||||
if err := hm.runtimeSub.Close(); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to close runtime event hook subscription", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
<-hm.done
|
||||
<-hm.runtimeDone
|
||||
hm.closeAllHooks()
|
||||
})
|
||||
}
|
||||
@@ -292,16 +308,16 @@ func (hm *HookManager) Unmount(name string) {
|
||||
hm.rebuildOrdered()
|
||||
}
|
||||
|
||||
func (hm *HookManager) dispatchEvents() {
|
||||
defer close(hm.done)
|
||||
func (hm *HookManager) dispatchRuntimeEvents(ch <-chan runtimeevents.Event) {
|
||||
defer close(hm.runtimeDone)
|
||||
|
||||
for evt := range hm.sub.C {
|
||||
for evt := range ch {
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
observer, ok := reg.Hook.(EventObserver)
|
||||
observer, ok := reg.Hook.(RuntimeEventObserver)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hm.runObserver(reg.Name, observer, evt)
|
||||
hm.runRuntimeObserver(reg.Name, observer, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -581,26 +597,30 @@ func (hm *HookManager) closeAllHooks() {
|
||||
hm.ordered = nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
|
||||
func (hm *HookManager) runRuntimeObserver(
|
||||
name string,
|
||||
observer RuntimeEventObserver,
|
||||
evt runtimeevents.Event,
|
||||
) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- observer.OnEvent(ctx, evt)
|
||||
done <- observer.OnRuntimeEvent(ctx, evt)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Event observer failed", map[string]any{
|
||||
logger.WarnCF("hooks", "Runtime event observer failed", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Event observer timed out", map[string]any{
|
||||
logger.WarnCF("hooks", "Runtime event observer timed out", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"timeout_ms": hm.observerTimeout.Milliseconds(),
|
||||
|
||||
+137
-51
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
@@ -111,14 +112,14 @@ func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
}
|
||||
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
eventCh chan runtimeevents.Event
|
||||
lastInbound *bus.InboundContext
|
||||
lastRoute *routing.ResolvedRoute
|
||||
lastScope *session.SessionScope
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if evt.Kind == EventKindTurnEnd {
|
||||
func (h *llmObserverHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
|
||||
if evt.Kind == runtimeevents.KindAgentTurnEnd {
|
||||
select {
|
||||
case h.eventCh <- evt:
|
||||
default:
|
||||
@@ -150,6 +151,20 @@ func (h *llmObserverHook) AfterLLM(
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
type dualRuntimeObserverHook struct {
|
||||
runtimeCh chan runtimeevents.Event
|
||||
}
|
||||
|
||||
func (h *dualRuntimeObserverHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
|
||||
if evt.Kind == runtimeevents.KindAgentTurnEnd {
|
||||
select {
|
||||
case h.runtimeCh <- evt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type llmSystemRewriteHook struct{}
|
||||
|
||||
func (h *llmSystemRewriteHook) BeforeLLM(
|
||||
@@ -417,7 +432,7 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
||||
hook := &llmObserverHook{eventCh: make(chan runtimeevents.Event, 1)}
|
||||
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
@@ -481,30 +496,80 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
|
||||
select {
|
||||
case evt := <-hook.eventCh:
|
||||
if evt.Kind != EventKindTurnEnd {
|
||||
if evt.Kind != runtimeevents.KindAgentTurnEnd {
|
||||
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
||||
}
|
||||
if evt.Context == nil || evt.Context.Inbound == nil {
|
||||
t.Fatal("expected observer event to carry inbound context")
|
||||
}
|
||||
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
|
||||
t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route)
|
||||
}
|
||||
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" {
|
||||
t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope)
|
||||
if evt.Scope.AgentID != "main" ||
|
||||
evt.Scope.SessionKey != "session-1" ||
|
||||
evt.Scope.Channel != "cli" ||
|
||||
evt.Scope.ChatID != "direct" ||
|
||||
evt.Scope.SenderID != "hook-user" {
|
||||
t.Fatalf("runtime observer scope = %+v", evt.Scope)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for hook observer event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_RuntimeObserverReceivesEvents(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &dualRuntimeObserverHook{
|
||||
runtimeCh: make(chan runtimeevents.Event, 1),
|
||||
}
|
||||
if err := al.MountHook(NamedHook("runtime-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
Account: "default",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "hook-user",
|
||||
MessageID: "msg-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "provider content" {
|
||||
t.Fatalf("expected provider content, got %q", resp)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.runtimeCh:
|
||||
if evt.Kind != runtimeevents.KindAgentTurnEnd {
|
||||
t.Fatalf("runtime observer kind = %q", evt.Kind)
|
||||
}
|
||||
if evt.Scope.SessionKey != "session-1" ||
|
||||
evt.Scope.Channel != "cli" ||
|
||||
evt.Scope.ChatID != "direct" ||
|
||||
evt.Scope.MessageID != "msg-1" {
|
||||
t.Fatalf("runtime observer scope = %+v", evt.Scope)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for runtime observer event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_BtwCommand_UsesLLMHooks(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
||||
hook := &llmObserverHook{eventCh: make(chan runtimeevents.Event, 1)}
|
||||
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
@@ -800,8 +865,13 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
@@ -820,8 +890,8 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
skippedEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
@@ -876,8 +946,13 @@ func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentToolExecEnd,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
@@ -899,8 +974,8 @@ func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify event stream has ToolExecEnd, not actual tool execution
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected tool exec end event")
|
||||
}
|
||||
@@ -1065,8 +1140,13 @@ func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
||||
sendErr: errors.New("channel unavailable"),
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentToolExecEnd,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-media-err",
|
||||
@@ -1081,8 +1161,8 @@ func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
@@ -1120,8 +1200,13 @@ func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentToolExecEnd,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-bus-fallback",
|
||||
@@ -1136,8 +1221,8 @@ func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
@@ -1282,8 +1367,13 @@ func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
32,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
@@ -1322,9 +1412,9 @@ func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
skippedEvts := filterRuntimeEvents(events, runtimeevents.KindAgentToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
|
||||
}
|
||||
@@ -1362,8 +1452,14 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
32,
|
||||
runtimeevents.KindAgentToolExecEnd,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
@@ -1383,14 +1479,14 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
collectedEvents := make([]Event, 0, 8)
|
||||
collectedEvents := make([]runtimeevents.Event, 0, 8)
|
||||
steered := false
|
||||
deadline := time.After(3 * time.Second)
|
||||
for !steered {
|
||||
select {
|
||||
case evt := <-sub.C:
|
||||
case evt := <-runtimeCh:
|
||||
collectedEvents = append(collectedEvents, evt)
|
||||
if evt.Kind != EventKindToolExecEnd {
|
||||
if evt.Kind != runtimeevents.KindAgentToolExecEnd {
|
||||
continue
|
||||
}
|
||||
payload, ok := evt.Payload.(ToolExecEndPayload)
|
||||
@@ -1413,9 +1509,9 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := append(collectedEvents, collectEventStream(sub.C)...)
|
||||
events := append(collectedEvents, collectRuntimeEventStream(runtimeCh)...)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
skippedEvts := filterRuntimeEvents(events, runtimeevents.KindAgentToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after steering")
|
||||
}
|
||||
@@ -1480,13 +1576,3 @@ func TestCloneStringAnyMap_EmptyMapReturnsNonNil(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func filterEvents(events []Event, kind EventKind) []Event {
|
||||
var result []Event
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
result = append(result, evt)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -44,4 +44,11 @@ type ChannelManager interface {
|
||||
|
||||
// SendPlaceholder sends a placeholder message (e.g., for audio transcription).
|
||||
SendPlaceholder(ctx context.Context, channel, chatID string) bool
|
||||
|
||||
// DismissToolFeedback clears any tracked tool feedback animation for the
|
||||
// given channel/chat. Call this when a turn ends without a final response
|
||||
// (e.g., ResponseHandled tools) to avoid orphaned animation goroutines.
|
||||
// outboundCtx carries topic/thread info needed for channels that use
|
||||
// scoped tracker keys (e.g., Telegram forum topics); may be nil.
|
||||
DismissToolFeedback(ctx context.Context, channel, chatID string, outboundCtx *bus.InboundContext)
|
||||
}
|
||||
|
||||
@@ -56,5 +56,12 @@ func isVisionUnsupportedError(err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// DeepSeek and other strict providers reject the image_url field at the
|
||||
// JSON schema level with an "unknown variant" error rather than a semantic
|
||||
// "not supported" message.
|
||||
if strings.Contains(msg, "unknown variant") && strings.Contains(msg, "image_url") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -72,7 +73,7 @@ toolLoop:
|
||||
})
|
||||
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
runtimeevents.KindAgentToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: toolName,
|
||||
@@ -191,7 +192,7 @@ toolLoop:
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
runtimeevents.KindAgentToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolName,
|
||||
@@ -237,7 +238,7 @@ toolLoop:
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: skippedTC.Name,
|
||||
@@ -284,7 +285,7 @@ toolLoop:
|
||||
exec.allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
@@ -323,7 +324,7 @@ toolLoop:
|
||||
exec.allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
@@ -353,7 +354,7 @@ toolLoop:
|
||||
"iteration": iteration,
|
||||
})
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
runtimeevents.KindAgentToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: toolName,
|
||||
@@ -401,7 +402,7 @@ toolLoop:
|
||||
"channel": ts.channel,
|
||||
})
|
||||
al.emitEvent(
|
||||
EventKindFollowUpQueued,
|
||||
runtimeevents.KindAgentFollowUpQueued,
|
||||
ts.scope.meta(iteration, "runTurn", "turn.follow_up.queued"),
|
||||
FollowUpQueuedPayload{
|
||||
SourceTool: asyncToolName,
|
||||
@@ -567,7 +568,7 @@ toolLoop:
|
||||
toolResultMsg.Media = append(toolResultMsg.Media, toolResult.Media...)
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
runtimeevents.KindAgentToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolName,
|
||||
@@ -612,7 +613,7 @@ toolLoop:
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
runtimeevents.KindAgentToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: skippedTC.Name,
|
||||
@@ -704,6 +705,9 @@ toolLoop:
|
||||
}
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
ts.setFinalContent("")
|
||||
if al.channelManager != nil && ts.channel != "" {
|
||||
al.channelManager.DismissToolFeedback(ctx, ts.channel, ts.chatID, ts.opts.InboundContext)
|
||||
}
|
||||
logger.InfoCF("agent", "Tool output satisfied delivery; ending turn without follow-up LLM",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -50,7 +51,7 @@ func (p *Pipeline) Finalize(
|
||||
ts.ingestMessage(turnCtx, al, finalMsg)
|
||||
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
runtimeevents.KindAgentError,
|
||||
ts.eventMeta("runTurn", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "session_save",
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
@@ -113,7 +114,7 @@ func (p *Pipeline) CallLLM(
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
runtimeevents.KindAgentLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: exec.llmModel,
|
||||
@@ -184,7 +185,14 @@ func (p *Pipeline) CallLLM(
|
||||
|
||||
// Retry loop
|
||||
var err error
|
||||
maxRetries := 2
|
||||
maxRetries := p.Cfg.Agents.Defaults.MaxLLMRetries
|
||||
if maxRetries <= 0 {
|
||||
maxRetries = 2
|
||||
}
|
||||
backoffSecs := p.Cfg.Agents.Defaults.LLMRetryBackoffSecs
|
||||
if backoffSecs <= 0 {
|
||||
backoffSecs = 2
|
||||
}
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
exec.response, err = callLLM(exec.callMessages, exec.providerToolDefs)
|
||||
if err == nil {
|
||||
@@ -199,7 +207,7 @@ func (p *Pipeline) CallLLM(
|
||||
// Retry without media if vision is unsupported
|
||||
if hasMediaRefs(exec.callMessages) && isVisionUnsupportedError(err) && retry < maxRetries {
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
runtimeevents.KindAgentLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
@@ -232,6 +240,15 @@ func (p *Pipeline) CallLLM(
|
||||
strings.Contains(errMsg, "timed out") ||
|
||||
strings.Contains(errMsg, "timeout exceeded")
|
||||
|
||||
isNetworkError := !isTimeoutError && (strings.Contains(errMsg, "connection reset") ||
|
||||
strings.Contains(errMsg, "connection refused") ||
|
||||
strings.Contains(errMsg, "broken pipe") ||
|
||||
strings.Contains(errMsg, "no such host") ||
|
||||
strings.Contains(errMsg, "network is unreachable") ||
|
||||
strings.Contains(errMsg, "read tcp") ||
|
||||
strings.Contains(errMsg, "write tcp") ||
|
||||
strings.Contains(errMsg, "eof"))
|
||||
|
||||
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
|
||||
strings.Contains(errMsg, "context window") ||
|
||||
strings.Contains(errMsg, "context_window") ||
|
||||
@@ -244,9 +261,9 @@ func (p *Pipeline) CallLLM(
|
||||
strings.Contains(errMsg, "request too large"))
|
||||
|
||||
if isTimeoutError && retry < maxRetries {
|
||||
backoff := time.Duration(retry+1) * 5 * time.Second
|
||||
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
runtimeevents.KindAgentLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
@@ -272,9 +289,38 @@ func (p *Pipeline) CallLLM(
|
||||
continue
|
||||
}
|
||||
|
||||
if isNetworkError && retry < maxRetries {
|
||||
backoff := time.Duration(retry+1) * time.Duration(backoffSecs) * time.Second
|
||||
al.emitEvent(
|
||||
runtimeevents.KindAgentLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
MaxRetries: maxRetries,
|
||||
Reason: "network",
|
||||
Error: err.Error(),
|
||||
Backoff: backoff,
|
||||
},
|
||||
)
|
||||
logger.WarnCF("agent", "Network error, retrying after backoff", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
"backoff": backoff.String(),
|
||||
})
|
||||
if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil {
|
||||
if ts.hardAbortRequested() {
|
||||
_ = ts.requestHardAbort()
|
||||
return ControlBreak, nil
|
||||
}
|
||||
err = sleepErr
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isContextError && retry < maxRetries && !ts.opts.NoHistory {
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
runtimeevents.KindAgentLLMRetry,
|
||||
ts.eventMeta("runTurn", "turn.llm.retry"),
|
||||
LLMRetryPayload{
|
||||
Attempt: retry + 1,
|
||||
@@ -333,7 +379,7 @@ func (p *Pipeline) CallLLM(
|
||||
|
||||
if err != nil {
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
runtimeevents.KindAgentError,
|
||||
ts.eventMeta("runTurn", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "llm",
|
||||
@@ -397,7 +443,7 @@ func (p *Pipeline) CallLLM(
|
||||
)
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindLLMResponse,
|
||||
runtimeevents.KindAgentLLMResponse,
|
||||
ts.eventMeta("runTurn", "turn.llm.response"),
|
||||
LLMResponsePayload{
|
||||
ContentLen: len(exec.response.Content),
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
runtimeEventLoggerBuffer = 256
|
||||
runtimeEventLoggerDrainTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
type runtimeEventLogger struct {
|
||||
mu sync.RWMutex
|
||||
cfg config.EventLoggingConfig
|
||||
}
|
||||
|
||||
func (al *AgentLoop) refreshRuntimeEventLogger(cfg *config.Config) {
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
logCfg := config.EffectiveEventLoggingConfig(cfg)
|
||||
|
||||
al.runtimeEventLogMu.Lock()
|
||||
if !logCfg.Enabled {
|
||||
oldSub := al.runtimeEventLogSub
|
||||
al.runtimeEventLogger = nil
|
||||
al.runtimeEventLogSub = nil
|
||||
al.runtimeEventLogMu.Unlock()
|
||||
closeRuntimeEventLoggerSubscription(oldSub)
|
||||
return
|
||||
}
|
||||
|
||||
if al.runtimeEventLogger != nil && al.runtimeEventLogSub != nil {
|
||||
al.runtimeEventLogger.updateConfig(logCfg)
|
||||
al.runtimeEventLogMu.Unlock()
|
||||
return
|
||||
}
|
||||
al.runtimeEventLogMu.Unlock()
|
||||
|
||||
eventLogger := newRuntimeEventLoggerFromConfig(logCfg)
|
||||
sub, err := eventLogger.subscribe(context.Background(), al.runtimeEvents)
|
||||
if err != nil {
|
||||
logger.WarnCF("events", "Failed to subscribe runtime event logger", map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
al.runtimeEventLogMu.Lock()
|
||||
oldSub := al.runtimeEventLogSub
|
||||
al.runtimeEventLogger = eventLogger
|
||||
al.runtimeEventLogSub = sub
|
||||
al.runtimeEventLogMu.Unlock()
|
||||
closeRuntimeEventLoggerSubscription(oldSub)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) closeRuntimeEventLogger() {
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
al.runtimeEventLogMu.Lock()
|
||||
oldSub := al.runtimeEventLogSub
|
||||
al.runtimeEventLogger = nil
|
||||
al.runtimeEventLogSub = nil
|
||||
al.runtimeEventLogMu.Unlock()
|
||||
closeRuntimeEventLoggerSubscription(oldSub)
|
||||
}
|
||||
|
||||
func closeRuntimeEventLoggerSubscription(sub runtimeevents.Subscription) {
|
||||
if sub == nil {
|
||||
return
|
||||
}
|
||||
if err := sub.Close(); err != nil {
|
||||
logger.WarnCF("events", "Failed to close runtime event logger subscription", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
timer := time.NewTimer(runtimeEventLoggerDrainTimeout)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-sub.Done():
|
||||
case <-timer.C:
|
||||
logger.WarnCF("events", "Timed out waiting for runtime event logger to drain", map[string]any{
|
||||
"timeout": runtimeEventLoggerDrainTimeout.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newRuntimeEventLogger(cfg *config.Config) *runtimeEventLogger {
|
||||
logCfg := config.EffectiveEventLoggingConfig(cfg)
|
||||
if !logCfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
return newRuntimeEventLoggerFromConfig(logCfg)
|
||||
}
|
||||
|
||||
func newRuntimeEventLoggerFromConfig(logCfg config.EventLoggingConfig) *runtimeEventLogger {
|
||||
return &runtimeEventLogger{cfg: logCfg}
|
||||
}
|
||||
|
||||
func (l *runtimeEventLogger) updateConfig(cfg config.EventLoggingConfig) {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.cfg = cfg
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
func (l *runtimeEventLogger) configSnapshot() config.EventLoggingConfig {
|
||||
if l == nil {
|
||||
return config.EventLoggingConfig{}
|
||||
}
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
return l.cfg
|
||||
}
|
||||
|
||||
func (l *runtimeEventLogger) subscribe(
|
||||
ctx context.Context,
|
||||
eventBus runtimeevents.Bus,
|
||||
) (runtimeevents.Subscription, error) {
|
||||
if l == nil || eventBus == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return eventBus.Channel().Subscribe(ctx, runtimeevents.SubscribeOptions{
|
||||
Name: "runtime-event-logger",
|
||||
Buffer: runtimeEventLoggerBuffer,
|
||||
Concurrency: runtimeevents.Locked,
|
||||
Backpressure: runtimeevents.DropNewest,
|
||||
PanicPolicy: runtimeevents.RecoverAndLog,
|
||||
}, l.handle)
|
||||
}
|
||||
|
||||
func (l *runtimeEventLogger) handle(_ context.Context, evt runtimeevents.Event) error {
|
||||
if l == nil || !l.shouldLog(evt) {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields := runtimeEventLogFields(evt)
|
||||
if l.configSnapshot().IncludePayload && evt.Payload != nil {
|
||||
fields["payload"] = evt.Payload
|
||||
}
|
||||
|
||||
logRuntimeEvent(evt, fields)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *runtimeEventLogger) shouldLog(evt runtimeevents.Event) bool {
|
||||
if l == nil {
|
||||
return false
|
||||
}
|
||||
cfg := l.configSnapshot()
|
||||
if !cfg.Enabled {
|
||||
return false
|
||||
}
|
||||
if runtimeEventSeverityRank(evt.Severity) < runtimeEventSeverityRank(parseRuntimeEventSeverity(cfg.MinSeverity)) {
|
||||
return false
|
||||
}
|
||||
|
||||
kind := evt.Kind.String()
|
||||
if !matchAnyRuntimeEventPattern(cfg.Include, kind, true) {
|
||||
return false
|
||||
}
|
||||
return !matchAnyRuntimeEventPattern(cfg.Exclude, kind, false)
|
||||
}
|
||||
|
||||
func logRuntimeEvent(evt runtimeevents.Event, fields map[string]any) {
|
||||
message := fmt.Sprintf("Runtime event: %s", evt.Kind.String())
|
||||
switch normalizeRuntimeEventSeverity(evt.Severity) {
|
||||
case runtimeevents.SeverityDebug:
|
||||
logger.DebugCF("events", message, fields)
|
||||
case runtimeevents.SeverityWarn:
|
||||
logger.WarnCF("events", message, fields)
|
||||
case runtimeevents.SeverityError:
|
||||
logger.ErrorCF("events", message, fields)
|
||||
default:
|
||||
logger.InfoCF("events", message, fields)
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeEventLogFields(evt runtimeevents.Event) map[string]any {
|
||||
fields := map[string]any{
|
||||
"event_id": evt.ID,
|
||||
"event_kind": evt.Kind.String(),
|
||||
"severity": string(normalizeRuntimeEventSeverity(evt.Severity)),
|
||||
}
|
||||
if !evt.Time.IsZero() {
|
||||
fields["event_time"] = evt.Time.Format(time.RFC3339Nano)
|
||||
}
|
||||
appendRuntimeEventSourceFields(fields, evt.Source)
|
||||
appendRuntimeEventScopeFields(fields, evt.Scope)
|
||||
appendRuntimeEventCorrelationFields(fields, evt.Correlation)
|
||||
appendRuntimeEventAttrs(fields, evt.Attrs)
|
||||
appendRuntimeEventPayloadSummary(fields, evt.Payload)
|
||||
return fields
|
||||
}
|
||||
|
||||
func appendRuntimeEventSourceFields(fields map[string]any, source runtimeevents.Source) {
|
||||
if source.Component != "" {
|
||||
fields["source_component"] = source.Component
|
||||
}
|
||||
if source.Name != "" {
|
||||
fields["source_name"] = source.Name
|
||||
}
|
||||
}
|
||||
|
||||
func appendRuntimeEventScopeFields(fields map[string]any, scope runtimeevents.Scope) {
|
||||
setStringField(fields, "runtime_id", scope.RuntimeID)
|
||||
setStringField(fields, "agent_id", scope.AgentID)
|
||||
setStringField(fields, "session_key", scope.SessionKey)
|
||||
setStringField(fields, "turn_id", scope.TurnID)
|
||||
setStringField(fields, "channel", scope.Channel)
|
||||
setStringField(fields, "account", scope.Account)
|
||||
setStringField(fields, "chat_id", scope.ChatID)
|
||||
setStringField(fields, "topic_id", scope.TopicID)
|
||||
setStringField(fields, "space_id", scope.SpaceID)
|
||||
setStringField(fields, "space_type", scope.SpaceType)
|
||||
setStringField(fields, "chat_type", scope.ChatType)
|
||||
setStringField(fields, "sender_id", scope.SenderID)
|
||||
setStringField(fields, "message_id", scope.MessageID)
|
||||
}
|
||||
|
||||
func appendRuntimeEventCorrelationFields(fields map[string]any, correlation runtimeevents.Correlation) {
|
||||
setStringField(fields, "trace_id", correlation.TraceID)
|
||||
setStringField(fields, "parent_turn_id", correlation.ParentTurnID)
|
||||
setStringField(fields, "request_id", correlation.RequestID)
|
||||
setStringField(fields, "reply_to_id", correlation.ReplyToID)
|
||||
}
|
||||
|
||||
func appendRuntimeEventAttrs(fields map[string]any, attrs map[string]any) {
|
||||
for key, value := range attrs {
|
||||
if key == "" || value == nil {
|
||||
continue
|
||||
}
|
||||
if _, exists := fields[key]; exists {
|
||||
fields["attr_"+key] = value
|
||||
continue
|
||||
}
|
||||
fields[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func appendRuntimeEventPayloadSummary(fields map[string]any, payload any) {
|
||||
switch payload := payload.(type) {
|
||||
case TurnStartPayload:
|
||||
fields["user_len"] = len(payload.UserMessage)
|
||||
fields["media_count"] = payload.MediaCount
|
||||
case TurnEndPayload:
|
||||
fields["status"] = payload.Status
|
||||
fields["iterations_total"] = payload.Iterations
|
||||
fields["duration_ms"] = payload.Duration.Milliseconds()
|
||||
fields["final_len"] = payload.FinalContentLen
|
||||
case LLMRequestPayload:
|
||||
fields["model"] = payload.Model
|
||||
fields["messages"] = payload.MessagesCount
|
||||
fields["tools"] = payload.ToolsCount
|
||||
fields["max_tokens"] = payload.MaxTokens
|
||||
case LLMDeltaPayload:
|
||||
fields["content_delta_len"] = payload.ContentDeltaLen
|
||||
fields["reasoning_delta_len"] = payload.ReasoningDeltaLen
|
||||
case LLMResponsePayload:
|
||||
fields["content_len"] = payload.ContentLen
|
||||
fields["tool_calls"] = payload.ToolCalls
|
||||
fields["has_reasoning"] = payload.HasReasoning
|
||||
case LLMRetryPayload:
|
||||
fields["attempt"] = payload.Attempt
|
||||
fields["max_retries"] = payload.MaxRetries
|
||||
fields["reason"] = payload.Reason
|
||||
fields["error"] = payload.Error
|
||||
fields["backoff_ms"] = payload.Backoff.Milliseconds()
|
||||
case ContextCompressPayload:
|
||||
fields["reason"] = payload.Reason
|
||||
fields["dropped_messages"] = payload.DroppedMessages
|
||||
fields["remaining_messages"] = payload.RemainingMessages
|
||||
case SessionSummarizePayload:
|
||||
fields["summarized_messages"] = payload.SummarizedMessages
|
||||
fields["kept_messages"] = payload.KeptMessages
|
||||
fields["summary_len"] = payload.SummaryLen
|
||||
fields["omitted_oversized"] = payload.OmittedOversized
|
||||
case ToolExecStartPayload:
|
||||
fields["tool"] = payload.Tool
|
||||
fields["args_count"] = len(payload.Arguments)
|
||||
case ToolExecEndPayload:
|
||||
fields["tool"] = payload.Tool
|
||||
fields["duration_ms"] = payload.Duration.Milliseconds()
|
||||
fields["for_llm_len"] = payload.ForLLMLen
|
||||
fields["for_user_len"] = payload.ForUserLen
|
||||
fields["is_error"] = payload.IsError
|
||||
fields["async"] = payload.Async
|
||||
case ToolExecSkippedPayload:
|
||||
fields["tool"] = payload.Tool
|
||||
fields["reason"] = payload.Reason
|
||||
case SteeringInjectedPayload:
|
||||
fields["count"] = payload.Count
|
||||
fields["total_content_len"] = payload.TotalContentLen
|
||||
case FollowUpQueuedPayload:
|
||||
fields["source_tool"] = payload.SourceTool
|
||||
fields["content_len"] = payload.ContentLen
|
||||
case InterruptReceivedPayload:
|
||||
fields["interrupt_kind"] = payload.Kind
|
||||
fields["role"] = payload.Role
|
||||
fields["content_len"] = payload.ContentLen
|
||||
fields["queue_depth"] = payload.QueueDepth
|
||||
fields["hint_len"] = payload.HintLen
|
||||
case SubTurnSpawnPayload:
|
||||
fields["child_agent_id"] = payload.AgentID
|
||||
fields["label"] = payload.Label
|
||||
case SubTurnEndPayload:
|
||||
fields["child_agent_id"] = payload.AgentID
|
||||
fields["status"] = payload.Status
|
||||
case SubTurnResultDeliveredPayload:
|
||||
fields["target_channel"] = payload.TargetChannel
|
||||
fields["target_chat_id"] = payload.TargetChatID
|
||||
fields["content_len"] = payload.ContentLen
|
||||
case SubTurnOrphanPayload:
|
||||
fields["parent_turn_id"] = payload.ParentTurnID
|
||||
fields["child_turn_id"] = payload.ChildTurnID
|
||||
fields["reason"] = payload.Reason
|
||||
case ErrorPayload:
|
||||
fields["stage"] = payload.Stage
|
||||
fields["error"] = payload.Message
|
||||
}
|
||||
}
|
||||
|
||||
func setStringField(fields map[string]any, key, value string) {
|
||||
if value != "" {
|
||||
fields[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func matchAnyRuntimeEventPattern(patterns []string, kind string, emptyMatches bool) bool {
|
||||
if len(patterns) == 0 {
|
||||
return emptyMatches
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if matchRuntimeEventPattern(pattern, kind) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchRuntimeEventPattern(pattern, kind string) bool {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
if pattern == "" {
|
||||
return false
|
||||
}
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(pattern, ".*") {
|
||||
return strings.HasPrefix(kind, strings.TrimSuffix(pattern, "*"))
|
||||
}
|
||||
matched, err := path.Match(pattern, kind)
|
||||
if err == nil {
|
||||
return matched
|
||||
}
|
||||
return pattern == kind
|
||||
}
|
||||
|
||||
func parseRuntimeEventSeverity(severity string) runtimeevents.Severity {
|
||||
switch strings.ToLower(strings.TrimSpace(severity)) {
|
||||
case "debug":
|
||||
return runtimeevents.SeverityDebug
|
||||
case "warn", "warning":
|
||||
return runtimeevents.SeverityWarn
|
||||
case "error":
|
||||
return runtimeevents.SeverityError
|
||||
default:
|
||||
return runtimeevents.SeverityInfo
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeRuntimeEventSeverity(severity runtimeevents.Severity) runtimeevents.Severity {
|
||||
switch severity {
|
||||
case runtimeevents.SeverityDebug,
|
||||
runtimeevents.SeverityInfo,
|
||||
runtimeevents.SeverityWarn,
|
||||
runtimeevents.SeverityError:
|
||||
return severity
|
||||
default:
|
||||
return runtimeevents.SeverityInfo
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeEventSeverityRank(severity runtimeevents.Severity) int {
|
||||
switch normalizeRuntimeEventSeverity(severity) {
|
||||
case runtimeevents.SeverityDebug:
|
||||
return 0
|
||||
case runtimeevents.SeverityInfo:
|
||||
return 1
|
||||
case runtimeevents.SeverityWarn:
|
||||
return 2
|
||||
case runtimeevents.SeverityError:
|
||||
return 3
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func TestRuntimeEventLoggerFiltering(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
eventLogger := newRuntimeEventLogger(cfg)
|
||||
if eventLogger == nil {
|
||||
t.Fatal("default runtime event logger is nil")
|
||||
}
|
||||
|
||||
if !eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindAgentTurnStart,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("default config should log agent events")
|
||||
}
|
||||
if eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindChannelLifecycleStarted,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("default config should not log non-agent events")
|
||||
}
|
||||
|
||||
cfg.Events.Logging.Include = []string{"*"}
|
||||
cfg.Events.Logging.Exclude = []string{"mcp.*"}
|
||||
eventLogger = newRuntimeEventLogger(cfg)
|
||||
if !eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindGatewayReady,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("include * should log gateway events")
|
||||
}
|
||||
if eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindMCPServerConnected,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("exclude mcp.* should suppress MCP events")
|
||||
}
|
||||
|
||||
cfg.Events.Logging.Exclude = nil
|
||||
cfg.Events.Logging.MinSeverity = "warn"
|
||||
eventLogger = newRuntimeEventLogger(cfg)
|
||||
if eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindGatewayReady,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("min severity warn should suppress info events")
|
||||
}
|
||||
if !eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindGatewayReloadFailed,
|
||||
Severity: runtimeevents.SeverityError,
|
||||
}) {
|
||||
t.Fatal("min severity warn should allow error events")
|
||||
}
|
||||
|
||||
cfg.Events.Logging.Enabled = false
|
||||
if newRuntimeEventLogger(cfg) != nil {
|
||||
t.Fatal("disabled config should not create runtime event logger")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeEventLogFieldsSummarizeAgentPayload(t *testing.T) {
|
||||
fields := runtimeEventLogFields(runtimeevents.Event{
|
||||
ID: "evt-test",
|
||||
Kind: runtimeevents.KindAgentToolExecStart,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
Source: runtimeevents.Source{
|
||||
Component: "agent",
|
||||
Name: "main",
|
||||
},
|
||||
Scope: runtimeevents.Scope{
|
||||
AgentID: "main",
|
||||
SessionKey: "session-1",
|
||||
TurnID: "turn-1",
|
||||
},
|
||||
Payload: ToolExecStartPayload{
|
||||
Tool: "exec",
|
||||
Arguments: map[string]any{
|
||||
"secret": "should-not-be-logged-by-default",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if fields["event_id"] != "evt-test" || fields["source_component"] != "agent" {
|
||||
t.Fatalf("missing common event fields: %#v", fields)
|
||||
}
|
||||
if fields["tool"] != "exec" || fields["args_count"] != 1 {
|
||||
t.Fatalf("missing safe agent payload summary fields: %#v", fields)
|
||||
}
|
||||
if _, ok := fields["payload"]; ok {
|
||||
t.Fatalf("raw payload should not be included by runtimeEventLogFields: %#v", fields)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeEventLogFieldsIncludeSafeAttrs(t *testing.T) {
|
||||
fields := runtimeEventLogFields(runtimeevents.Event{
|
||||
ID: "evt-gateway",
|
||||
Kind: runtimeevents.KindGatewayReady,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
Attrs: map[string]any{
|
||||
"duration_ms": 42,
|
||||
"error": "startup failed",
|
||||
"event_kind": "conflict",
|
||||
},
|
||||
})
|
||||
|
||||
if fields["duration_ms"] != 42 || fields["error"] != "startup failed" {
|
||||
t.Fatalf("missing safe attrs: %#v", fields)
|
||||
}
|
||||
if fields["event_kind"] != runtimeevents.KindGatewayReady.String() {
|
||||
t.Fatalf("event_kind overwritten by attrs: %#v", fields)
|
||||
}
|
||||
if fields["attr_event_kind"] != "conflict" {
|
||||
t.Fatalf("conflicting attr not preserved with prefix: %#v", fields)
|
||||
}
|
||||
if _, ok := fields["payload"]; ok {
|
||||
t.Fatalf("raw payload should not be included by runtimeEventLogFields: %#v", fields)
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeEventLoggerStateForTest(
|
||||
al *AgentLoop,
|
||||
) (*runtimeEventLogger, runtimeevents.Subscription) {
|
||||
al.runtimeEventLogMu.RLock()
|
||||
defer al.runtimeEventLogMu.RUnlock()
|
||||
return al.runtimeEventLogger, al.runtimeEventLogSub
|
||||
}
|
||||
|
||||
func TestReloadProviderAndConfigRefreshesRuntimeEventLogger(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = t.TempDir()
|
||||
cfg.Events.Logging.Include = []string{"agent.*"}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
|
||||
defer al.Close()
|
||||
|
||||
eventLogger, logSub := runtimeEventLoggerStateForTest(al)
|
||||
if eventLogger == nil || logSub == nil {
|
||||
t.Fatal("expected initial runtime event logger subscription")
|
||||
}
|
||||
if eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindGatewayReloadCompleted,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("initial agent-only logging should not log gateway reload events")
|
||||
}
|
||||
|
||||
reloaded := config.DefaultConfig()
|
||||
reloaded.Agents.Defaults.Workspace = cfg.Agents.Defaults.Workspace
|
||||
reloaded.Events.Logging.Include = []string{"gateway.*"}
|
||||
if err := al.ReloadProviderAndConfig(context.Background(), &mockProvider{}, reloaded); err != nil {
|
||||
t.Fatalf("ReloadProviderAndConfig() error = %v", err)
|
||||
}
|
||||
|
||||
eventLogger, logSub = runtimeEventLoggerStateForTest(al)
|
||||
if eventLogger == nil || logSub == nil {
|
||||
t.Fatal("expected runtime event logger subscription after reload")
|
||||
}
|
||||
if !eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindGatewayReloadCompleted,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("reloaded gateway logging should log gateway reload events")
|
||||
}
|
||||
if eventLogger.shouldLog(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindAgentTurnStart,
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
}) {
|
||||
t.Fatal("reloaded gateway-only logging should not log agent events")
|
||||
}
|
||||
|
||||
disabled := config.DefaultConfig()
|
||||
disabled.Agents.Defaults.Workspace = cfg.Agents.Defaults.Workspace
|
||||
disabled.Events.Logging.Enabled = false
|
||||
if err := al.ReloadProviderAndConfig(context.Background(), &mockProvider{}, disabled); err != nil {
|
||||
t.Fatalf("ReloadProviderAndConfig() with disabled logging error = %v", err)
|
||||
}
|
||||
eventLogger, logSub = runtimeEventLoggerStateForTest(al)
|
||||
if eventLogger != nil || logSub != nil {
|
||||
t.Fatal("expected runtime event logger to be disabled after reload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseRuntimeEventLoggerSubscriptionWaitsForDrain(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var handled atomic.Uint64
|
||||
firstStarted := make(chan struct{})
|
||||
releaseFirst := make(chan struct{})
|
||||
sub, err := eventBus.Channel().Subscribe(
|
||||
context.Background(),
|
||||
runtimeevents.SubscribeOptions{
|
||||
Name: "runtime-event-logger",
|
||||
Buffer: 2,
|
||||
Concurrency: runtimeevents.Locked,
|
||||
},
|
||||
func(context.Context, runtimeevents.Event) error {
|
||||
if handled.Add(1) == 1 {
|
||||
close(firstStarted)
|
||||
<-releaseFirst
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
|
||||
first := eventBus.Publish(context.Background(), runtimeevents.Event{Kind: runtimeevents.Kind("test.first")})
|
||||
if first.Delivered != 1 {
|
||||
t.Fatalf("first Publish = %+v, want one delivered event", first)
|
||||
}
|
||||
select {
|
||||
case <-firstStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for first handler to start")
|
||||
}
|
||||
second := eventBus.Publish(context.Background(), runtimeevents.Event{Kind: runtimeevents.Kind("test.second")})
|
||||
if second.Delivered != 1 {
|
||||
t.Fatalf("second Publish = %+v, want one delivered event", second)
|
||||
}
|
||||
|
||||
closeReturned := make(chan struct{})
|
||||
go func() {
|
||||
closeRuntimeEventLoggerSubscription(sub)
|
||||
close(closeReturned)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-closeReturned:
|
||||
t.Fatal("runtime event logger close returned before buffered events drained")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseFirst)
|
||||
select {
|
||||
case <-closeReturned:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event logger close to return")
|
||||
}
|
||||
if got := handled.Load(); got != 2 {
|
||||
t.Fatalf("handled = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func subscribeRuntimeEventsForTest(
|
||||
t *testing.T,
|
||||
al *AgentLoop,
|
||||
buffer int,
|
||||
kinds ...runtimeevents.Kind,
|
||||
) (<-chan runtimeevents.Event, func()) {
|
||||
t.Helper()
|
||||
|
||||
if al == nil {
|
||||
t.Fatal("agent loop is nil")
|
||||
}
|
||||
channel := al.RuntimeEvents()
|
||||
if channel == nil {
|
||||
t.Fatal("runtime event channel is nil")
|
||||
}
|
||||
if len(kinds) > 0 {
|
||||
channel = channel.OfKind(kinds...)
|
||||
}
|
||||
sub, ch, err := channel.SubscribeChan(
|
||||
t.Context(),
|
||||
runtimeevents.SubscribeOptions{Name: "agent-runtime-test", Buffer: buffer},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
return ch, func() {
|
||||
if err := sub.Close(); err != nil {
|
||||
t.Errorf("runtime subscription close failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForRuntimeEvent(
|
||||
t *testing.T,
|
||||
ch <-chan runtimeevents.Event,
|
||||
timeout time.Duration,
|
||||
match func(runtimeevents.Event) bool,
|
||||
) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event stream closed before expected event arrived")
|
||||
}
|
||||
if match(evt) {
|
||||
return evt
|
||||
}
|
||||
case <-timer.C:
|
||||
t.Fatal("timed out waiting for expected runtime event")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectRuntimeEventStream(ch <-chan runtimeevents.Event) []runtimeevents.Event {
|
||||
var events []runtimeevents.Event
|
||||
for {
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
return events
|
||||
}
|
||||
events = append(events, evt)
|
||||
default:
|
||||
return events
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findRuntimeEvent(
|
||||
events []runtimeevents.Event,
|
||||
kind runtimeevents.Kind,
|
||||
) (runtimeevents.Event, bool) {
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
return evt, true
|
||||
}
|
||||
}
|
||||
return runtimeevents.Event{}, false
|
||||
}
|
||||
|
||||
func filterRuntimeEvents(events []runtimeevents.Event, kind runtimeevents.Kind) []runtimeevents.Event {
|
||||
var filtered []runtimeevents.Event
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
filtered = append(filtered, evt)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
+28
-4
@@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
@@ -155,6 +156,18 @@ func (sq *steeringQueue) lenScope(scope string) int {
|
||||
return len(sq.queues[normalizeSteeringScope(scope)])
|
||||
}
|
||||
|
||||
func (sq *steeringQueue) clearScope(scope string) int {
|
||||
sq.mu.Lock()
|
||||
defer sq.mu.Unlock()
|
||||
|
||||
scope = normalizeSteeringScope(scope)
|
||||
count := len(sq.queues[scope])
|
||||
if count > 0 {
|
||||
delete(sq.queues, scope)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// setMode updates the steering mode.
|
||||
func (sq *steeringQueue) setMode(mode SteeringMode) {
|
||||
sq.mu.Lock()
|
||||
@@ -206,7 +219,7 @@ func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers
|
||||
"scope": normalizeSteeringScope(scope),
|
||||
})
|
||||
|
||||
meta := EventMeta{
|
||||
meta := HookMeta{
|
||||
Source: "Steer",
|
||||
TracePath: "turn.interrupt.received",
|
||||
}
|
||||
@@ -230,7 +243,7 @@ func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
runtimeevents.KindAgentInterruptReceived,
|
||||
meta,
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindSteering,
|
||||
@@ -289,6 +302,13 @@ func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
|
||||
return al.steering.lenScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) clearSteeringMessagesForScope(scope string) int {
|
||||
if al.steering == nil {
|
||||
return 0
|
||||
}
|
||||
return al.steering.clearScope(scope)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) continueWithSteeringMessages(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
@@ -410,7 +430,7 @@ func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
runtimeevents.KindAgentInterruptReceived,
|
||||
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindGraceful,
|
||||
@@ -438,7 +458,7 @@ func (al *AgentLoop) InterruptHard() error {
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
runtimeevents.KindAgentInterruptReceived,
|
||||
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindHard,
|
||||
@@ -510,6 +530,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
"initial_history_length": ts.initialHistoryLength,
|
||||
})
|
||||
|
||||
// Cancel the active provider/tool turn contexts immediately so long-running
|
||||
// execution stops as soon as possible on the root turn.
|
||||
_ = ts.requestHardAbort()
|
||||
|
||||
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
|
||||
// from adding more messages to the session. This prevents race conditions
|
||||
// where rollback happens while children are still writing.
|
||||
|
||||
+354
-13
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -839,6 +840,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Run_PendingStopStillContinuesQueuedFollowUp(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
MaxParallelTurns: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &lateSteeringProvider{
|
||||
firstCallStarted: make(chan struct{}),
|
||||
releaseFirstCall: make(chan struct{}),
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
defer func() {
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}()
|
||||
|
||||
blockerSessionKey := session.BuildOpaqueSessionKey("agent:main:test:blocker")
|
||||
targetSessionKey := session.BuildOpaqueSessionKey("agent:main:test:target")
|
||||
blockerCtx := bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "blocker-chat",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
}
|
||||
targetCtx := bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "target-chat",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: blockerCtx,
|
||||
Content: "block worker pool",
|
||||
SessionKey: blockerSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(blocker) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstCallStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for blocker turn to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "skip this turn",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(target start) error = %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
ts := al.getActiveTurnState(targetSessionKey)
|
||||
if ts != nil && strings.HasPrefix(ts.turnID, pendingTurnPrefix) {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for pending placeholder")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "/stop",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(/stop) error = %v", err)
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
stopSeen := false
|
||||
for !stopSeen {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.ChatID == "target-chat" && outbound.Content == "Task stopped. Current task was canceled." {
|
||||
stopSeen = true
|
||||
}
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for /stop reply")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: targetCtx,
|
||||
Content: "run this instead",
|
||||
SessionKey: targetSessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(follow-up) error = %v", err)
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for al.pendingSteeringCountForScope(targetSessionKey) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for follow-up to enter scoped steering queue")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
close(provider.releaseFirstCall)
|
||||
|
||||
deadline = time.Now().Add(5 * time.Second)
|
||||
followUpSeen := false
|
||||
for !followUpSeen {
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.ChatID == "target-chat" && outbound.Content == "continued response" {
|
||||
followUpSeen = true
|
||||
}
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for queued follow-up continuation")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if al.GetActiveTurnBySession(targetSessionKey) == nil &&
|
||||
al.pendingSteeringCountForScope(targetSessionKey) == 0 {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for target session to go idle")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls (blocker + continuation), got %d", calls)
|
||||
}
|
||||
|
||||
foundFollowUp := false
|
||||
for _, msg := range secondMessages {
|
||||
if msg.Role == "user" && msg.Content == "run this instead" {
|
||||
foundFollowUp = true
|
||||
}
|
||||
if msg.Role == "user" && msg.Content == "skip this turn" {
|
||||
t.Fatalf("unexpected canceled message in continuation context: %q", msg.Content)
|
||||
}
|
||||
}
|
||||
if !foundFollowUp {
|
||||
t.Fatal("expected queued follow-up to be processed after pending stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
@@ -1051,16 +1237,16 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
|
||||
foundResolvedMedia := false
|
||||
for _, msg := range msgs {
|
||||
if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 {
|
||||
if msg.Role != "user" || !strings.Contains(msg.Content, "describe this image") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
|
||||
if strings.Contains(msg.Content, "[image:") {
|
||||
foundResolvedMedia = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundResolvedMedia {
|
||||
t.Fatal("expected continue path to inject steering media into the provider request")
|
||||
t.Fatal("expected continue path to inject image path tag into the provider request")
|
||||
}
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
@@ -1134,8 +1320,14 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
al.RegisterTool(tool2)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
32,
|
||||
runtimeevents.KindAgentInterruptReceived,
|
||||
runtimeevents.KindAgentTurnEnd,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
@@ -1222,8 +1414,8 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
interruptEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
@@ -1235,7 +1427,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
turnEndEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
@@ -1299,8 +1491,14 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentInterruptReceived,
|
||||
runtimeevents.KindAgentTurnEnd,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
@@ -1353,8 +1551,8 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
events := collectRuntimeEventStream(runtimeCh)
|
||||
interruptEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
@@ -1366,7 +1564,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
turnEndEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
@@ -1379,6 +1577,149 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_StopCommand_AbortsActiveTurnAndClearsQueuedSteering(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "cancel_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "cancel_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "should not continue",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
defer func() {
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}()
|
||||
|
||||
baseMsg := testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
SessionKey: sessionKey,
|
||||
})
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "do work",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(start) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for interruptible tool to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "follow up after cancel",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(follow-up) error = %v", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for al.pendingSteeringCountForScope(sessionKey) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for follow-up message to enter steering queue")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(context.Background(), bus.InboundMessage{
|
||||
Context: baseMsg.Context,
|
||||
Content: "/stop",
|
||||
SessionKey: sessionKey,
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishInbound(/stop) error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
want := "Task stopped. \"do work\" was canceled."
|
||||
if outbound.Content != want {
|
||||
t.Fatalf("stop reply = %q, want %q", outbound.Content, want)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /stop reply")
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(5 * time.Second)
|
||||
for al.GetActiveTurnBySession(sessionKey) != nil {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("timeout waiting for active turn to stop")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
if got := al.pendingSteeringCountForScope(sessionKey); got != 0 {
|
||||
t.Fatalf("expected cleared steering queue, got %d pending message(s)", got)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("unexpected outbound after stop: %q", outbound.Content)
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if calls != 1 {
|
||||
t.Fatalf("expected provider to stop before follow-up turn, got %d calls", calls)
|
||||
}
|
||||
}
|
||||
|
||||
// capturingMockProvider captures messages sent to Chat for inspection.
|
||||
type capturingMockProvider struct {
|
||||
response string
|
||||
|
||||
+35
-19
@@ -8,6 +8,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/messageutil"
|
||||
@@ -173,7 +174,10 @@ type SubTurnConfig struct {
|
||||
// Used by team tool to enforce token limits across all team members.
|
||||
InitialTokenBudget *atomic.Int64
|
||||
|
||||
// Can be extended with temperature, topP, etc.
|
||||
// TargetAgentID, when set, runs the sub-turn as the specified agent.
|
||||
// The target agent's workspace, model, tools, and system prompt are used
|
||||
// instead of the caller's. If empty, the sub-turn runs as the parent agent.
|
||||
TargetAgentID string
|
||||
}
|
||||
|
||||
// ====================== Context Keys ======================
|
||||
@@ -231,6 +235,7 @@ func (s *AgentLoopSpawner) SpawnSubTurn(
|
||||
Critical: cfg.Critical,
|
||||
Timeout: cfg.Timeout,
|
||||
MaxContextRunes: cfg.MaxContextRunes,
|
||||
TargetAgentID: cfg.TargetAgentID,
|
||||
}
|
||||
|
||||
return spawnSubTurn(ctx, s.al, parentTS, agentCfg)
|
||||
@@ -313,8 +318,9 @@ func spawnSubTurn(
|
||||
return nil, ErrDepthLimitExceeded
|
||||
}
|
||||
|
||||
// 2. Config validation
|
||||
if cfg.Model == "" {
|
||||
// 2. Config validation: Model is required unless TargetAgentID is set
|
||||
// (the target agent provides its own model).
|
||||
if cfg.Model == "" && cfg.TargetAgentID == "" {
|
||||
return nil, ErrInvalidSubTurnConfig
|
||||
}
|
||||
|
||||
@@ -332,12 +338,22 @@ func spawnSubTurn(
|
||||
|
||||
childID := al.generateSubTurnID()
|
||||
|
||||
// Get the agent instance from parent, falling back to the default agent.
|
||||
// Wrap it in a shallow copy that uses an ephemeral (in-memory only) session store
|
||||
// so that child turns never pollute or persist to the parent's session history.
|
||||
baseAgent := parentTS.agent
|
||||
if baseAgent == nil {
|
||||
baseAgent = al.registry.GetDefaultAgent()
|
||||
// Resolve the agent instance for the child turn.
|
||||
// When TargetAgentID is set, look up that agent from the registry so the
|
||||
// child runs with the target's workspace, model, tools, and system prompt.
|
||||
// Otherwise fall back to the parent's agent (existing behavior).
|
||||
var baseAgent *AgentInstance
|
||||
if cfg.TargetAgentID != "" {
|
||||
var ok bool
|
||||
baseAgent, ok = al.registry.GetAgent(cfg.TargetAgentID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("target agent %q not found in registry", cfg.TargetAgentID)
|
||||
}
|
||||
} else {
|
||||
baseAgent = parentTS.agent
|
||||
if baseAgent == nil {
|
||||
baseAgent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
}
|
||||
if baseAgent == nil {
|
||||
return nil, errors.New("parent turnState has no agent instance")
|
||||
@@ -422,7 +438,7 @@ func spawnSubTurn(
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// 6. Emit Spawn event
|
||||
al.emitEvent(EventKindSubTurnSpawn,
|
||||
al.emitEvent(runtimeevents.KindAgentSubTurnSpawn,
|
||||
childTS.eventMeta("spawnSubTurn", "subturn.spawn"),
|
||||
SubTurnSpawnPayload{
|
||||
AgentID: childTS.agentID,
|
||||
@@ -453,7 +469,7 @@ func spawnSubTurn(
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
al.emitEvent(EventKindSubTurnEnd,
|
||||
al.emitEvent(runtimeevents.KindAgentSubTurnEnd,
|
||||
childTS.eventMeta("spawnSubTurn", "subturn.end"),
|
||||
SubTurnEndPayload{
|
||||
AgentID: childTS.agentID,
|
||||
@@ -504,16 +520,16 @@ func spawnSubTurn(
|
||||
//
|
||||
// Delivery behavior:
|
||||
// - If parent turn is still running: attempts to deliver to pendingResults channel
|
||||
// - If channel is full: emits SubTurnOrphanResultEvent (result is lost from channel but tracked)
|
||||
// - If parent turn has finished: emits SubTurnOrphanResultEvent (late arrival)
|
||||
// - If channel is full: emits agent.subturn.orphan (result is lost from channel but tracked)
|
||||
// - If parent turn has finished: emits agent.subturn.orphan (late arrival)
|
||||
//
|
||||
// Thread safety:
|
||||
// - Reads parent state under lock, then releases lock before channel send
|
||||
// - Small race window exists but is acceptable (worst case: result becomes orphan)
|
||||
//
|
||||
// Event emissions:
|
||||
// - SubTurnResultDeliveredEvent: successful delivery to channel
|
||||
// - SubTurnOrphanResultEvent: delivery failed (parent finished or channel full)
|
||||
// - agent.subturn.result_delivered: successful delivery to channel
|
||||
// - agent.subturn.orphan: delivery failed (parent finished or channel full)
|
||||
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
|
||||
// Let GC clean up the pendingResults channel; parent Finish will no longer close it.
|
||||
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
|
||||
@@ -526,7 +542,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
|
||||
"recover": r,
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
al.emitEvent(runtimeevents.KindAgentSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "panic"},
|
||||
)
|
||||
@@ -541,7 +557,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
|
||||
// If parent turn has already finished, treat this as an orphan result
|
||||
if isFinished || resultChan == nil {
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(EventKindSubTurnOrphan,
|
||||
al.emitEvent(runtimeevents.KindAgentSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "parent_finished"},
|
||||
)
|
||||
@@ -557,7 +573,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
|
||||
case resultChan <- result:
|
||||
// Successfully delivered
|
||||
if al != nil {
|
||||
al.emitEvent(EventKindSubTurnResultDelivered,
|
||||
al.emitEvent(runtimeevents.KindAgentSubTurnResultDelivered,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.result_delivered"),
|
||||
SubTurnResultDeliveredPayload{ContentLen: len(result.ForLLM)},
|
||||
)
|
||||
@@ -571,7 +587,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
|
||||
})
|
||||
if result != nil && al != nil {
|
||||
al.emitEvent(
|
||||
EventKindSubTurnOrphan,
|
||||
runtimeevents.KindAgentSubTurnOrphan,
|
||||
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
|
||||
SubTurnOrphanPayload{
|
||||
ParentTurnID: parentTS.turnID,
|
||||
|
||||
+258
-27
@@ -4,12 +4,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
@@ -22,30 +26,38 @@ const (
|
||||
// ====================== Test Helper: Event Collector ======================
|
||||
type eventCollector struct {
|
||||
mu sync.Mutex
|
||||
events []Event
|
||||
events []runtimeevents.Event
|
||||
}
|
||||
|
||||
func newEventCollector(t *testing.T, al *AgentLoop) (*eventCollector, func()) {
|
||||
t.Helper()
|
||||
c := &eventCollector{}
|
||||
sub := al.SubscribeEvents(16)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentSubTurnSpawn,
|
||||
runtimeevents.KindAgentSubTurnEnd,
|
||||
runtimeevents.KindAgentSubTurnResultDelivered,
|
||||
runtimeevents.KindAgentSubTurnOrphan,
|
||||
)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for evt := range sub.C {
|
||||
for evt := range runtimeCh {
|
||||
c.mu.Lock()
|
||||
c.events = append(c.events, evt)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
cleanup := func() {
|
||||
al.UnsubscribeEvents(sub.ID)
|
||||
closeRuntimeEvents()
|
||||
<-done
|
||||
}
|
||||
return c, cleanup
|
||||
}
|
||||
|
||||
func (c *eventCollector) hasEventOfKind(kind EventKind) bool {
|
||||
func (c *eventCollector) hasEventOfKind(kind runtimeevents.Kind) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for _, e := range c.events {
|
||||
@@ -131,7 +143,7 @@ func TestSpawnSubTurn(t *testing.T) {
|
||||
agent: al.registry.GetDefaultAgent(),
|
||||
}
|
||||
|
||||
// Subscribe to real EventBus to capture events
|
||||
// Subscribe to runtime events to capture sub-turn lifecycle.
|
||||
collector, collectCleanup := newEventCollector(t, al)
|
||||
defer collectCleanup()
|
||||
|
||||
@@ -158,12 +170,12 @@ func TestSpawnSubTurn(t *testing.T) {
|
||||
// Verify event emission
|
||||
time.Sleep(10 * time.Millisecond) // let event goroutine flush
|
||||
if tt.wantSpawn {
|
||||
if !collector.hasEventOfKind(EventKindSubTurnSpawn) {
|
||||
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnSpawn) {
|
||||
t.Error("SubTurnSpawnEvent not emitted")
|
||||
}
|
||||
}
|
||||
if tt.wantEnd {
|
||||
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
|
||||
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnEnd) {
|
||||
t.Error("SubTurnEndEvent not emitted")
|
||||
}
|
||||
}
|
||||
@@ -316,8 +328,8 @@ func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // let event goroutine flush
|
||||
// Verify Orphan event is emitted
|
||||
if !collector.hasEventOfKind(EventKindSubTurnOrphan) {
|
||||
t.Error("SubTurnOrphanResultEvent not emitted for finished parent")
|
||||
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnOrphan) {
|
||||
t.Error("agent.subturn.orphan not emitted for finished parent")
|
||||
}
|
||||
|
||||
// Verify history is NOT polluted
|
||||
@@ -591,12 +603,16 @@ func TestNestedSubTurnHierarchy(t *testing.T) {
|
||||
var spawnedTurns []turnInfo
|
||||
var mu sync.Mutex
|
||||
|
||||
// Subscribe to real EventBus to capture spawn events
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
16,
|
||||
runtimeevents.KindAgentSubTurnSpawn,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
go func() {
|
||||
for evt := range sub.C {
|
||||
if evt.Kind == EventKindSubTurnSpawn {
|
||||
for evt := range runtimeCh {
|
||||
if evt.Kind == runtimeevents.KindAgentSubTurnSpawn {
|
||||
p, _ := evt.Payload.(SubTurnSpawnPayload)
|
||||
mu.Lock()
|
||||
spawnedTurns = append(spawnedTurns, turnInfo{
|
||||
@@ -879,7 +895,7 @@ func TestSpawnSubTurn_PanicRecovery(t *testing.T) {
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // let event goroutine flush
|
||||
// SubTurnEndEvent should still be emitted
|
||||
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
|
||||
if !collector.hasEventOfKind(runtimeevents.KindAgentSubTurnEnd) {
|
||||
t.Error("SubTurnEndEvent not emitted after panic")
|
||||
}
|
||||
|
||||
@@ -1229,18 +1245,23 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
|
||||
al, _, _, _, cleanup := newTestAgentLoop(t) //nolint:dogsled
|
||||
defer cleanup()
|
||||
|
||||
// Collect events via real EventBus
|
||||
var mu sync.Mutex
|
||||
var deliveredCount, orphanCount int
|
||||
sub := al.SubscribeEvents(64)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
64,
|
||||
runtimeevents.KindAgentSubTurnResultDelivered,
|
||||
runtimeevents.KindAgentSubTurnOrphan,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
go func() {
|
||||
for evt := range sub.C {
|
||||
for evt := range runtimeCh {
|
||||
mu.Lock()
|
||||
switch evt.Kind {
|
||||
case EventKindSubTurnResultDelivered:
|
||||
case runtimeevents.KindAgentSubTurnResultDelivered:
|
||||
deliveredCount++
|
||||
case EventKindSubTurnOrphan:
|
||||
case runtimeevents.KindAgentSubTurnOrphan:
|
||||
orphanCount++
|
||||
}
|
||||
mu.Unlock()
|
||||
@@ -1795,13 +1816,20 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
provider := &slowMockProvider{delay: 5 * time.Second} // SubTurn takes 5 seconds
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Capture events via real EventBus
|
||||
var mu sync.Mutex
|
||||
var events []Event
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
var events []runtimeevents.Event
|
||||
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
||||
t,
|
||||
al,
|
||||
32,
|
||||
runtimeevents.KindAgentSubTurnSpawn,
|
||||
runtimeevents.KindAgentSubTurnEnd,
|
||||
runtimeevents.KindAgentSubTurnResultDelivered,
|
||||
runtimeevents.KindAgentSubTurnOrphan,
|
||||
)
|
||||
defer closeRuntimeEvents()
|
||||
go func() {
|
||||
for evt := range sub.C {
|
||||
for evt := range runtimeCh {
|
||||
mu.Lock()
|
||||
events = append(events, evt)
|
||||
mu.Unlock()
|
||||
@@ -2097,3 +2125,206 @@ func TestSubTurn_IndependentContext(t *testing.T) {
|
||||
t.Log("✓ SubTurn completed successfully (independent context)")
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== TargetAgentID Tests ======================
|
||||
|
||||
// modelRecordingProvider captures the model passed to Chat for test assertions.
|
||||
type modelRecordingProvider struct {
|
||||
mu sync.Mutex
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (rp *modelRecordingProvider) Chat(
|
||||
_ context.Context,
|
||||
_ []providers.Message,
|
||||
_ []providers.ToolDefinition,
|
||||
model string,
|
||||
_ map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
rp.mu.Lock()
|
||||
rp.lastModel = model
|
||||
rp.mu.Unlock()
|
||||
return &providers.LLMResponse{Content: "Mock response"}, nil
|
||||
}
|
||||
|
||||
func (rp *modelRecordingProvider) GetDefaultModel() string { return "mock-model" }
|
||||
|
||||
func (rp *modelRecordingProvider) getLastModel() string {
|
||||
rp.mu.Lock()
|
||||
defer rp.mu.Unlock()
|
||||
return rp.lastModel
|
||||
}
|
||||
|
||||
// newMultiAgentLoop creates an AgentLoop with two named agents for testing
|
||||
// cross-agent delegation via TargetAgentID.
|
||||
func newMultiAgentLoop(t *testing.T, provider providers.LLMProvider) (*AgentLoop, func()) {
|
||||
t.Helper()
|
||||
tmpDir, err := os.MkdirTemp("", "multiagent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp dir: %v", err)
|
||||
}
|
||||
|
||||
alphaDir := filepath.Join(tmpDir, "alpha")
|
||||
betaDir := filepath.Join(tmpDir, "beta")
|
||||
os.MkdirAll(alphaDir, 0o755)
|
||||
os.MkdirAll(betaDir, 0o755)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "default-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
List: []config.AgentConfig{
|
||||
{
|
||||
ID: "alpha",
|
||||
Workspace: alphaDir,
|
||||
Model: &config.AgentModelConfig{Primary: "model-alpha"},
|
||||
},
|
||||
{
|
||||
ID: "beta",
|
||||
Workspace: betaDir,
|
||||
Model: &config.AgentModelConfig{Primary: "model-beta"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
return al, func() { os.RemoveAll(tmpDir) }
|
||||
}
|
||||
|
||||
func TestSpawnSubTurn_TargetAgentID_UsesTargetAgent(t *testing.T) {
|
||||
rp := &modelRecordingProvider{}
|
||||
al, cleanup := newMultiAgentLoop(t, rp)
|
||||
defer cleanup()
|
||||
|
||||
alphaAgent, ok := al.registry.GetAgent("alpha")
|
||||
if !ok {
|
||||
t.Fatal("alpha agent not in registry")
|
||||
}
|
||||
|
||||
// Parent is alpha, target is beta
|
||||
parent := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: "parent-alpha",
|
||||
depth: 0,
|
||||
childTurnIDs: []string{},
|
||||
pendingResults: make(chan *tools.ToolResult, 4),
|
||||
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
|
||||
session: &ephemeralSessionStore{},
|
||||
agent: alphaAgent,
|
||||
}
|
||||
|
||||
result, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
|
||||
TargetAgentID: "beta",
|
||||
SystemPrompt: "task for beta",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("spawnSubTurn failed: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// The recording provider captures the model passed to Chat().
|
||||
// If TargetAgentID works correctly, the child turn should have
|
||||
// used beta's model, not alpha's.
|
||||
if got := rp.getLastModel(); got != "model-beta" {
|
||||
t.Errorf("child turn used model %q, want %q", got, "model-beta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnSubTurn_TargetAgentID_NotFound(t *testing.T) {
|
||||
al, cleanup := newMultiAgentLoop(t, &mockProvider{})
|
||||
defer cleanup()
|
||||
|
||||
alphaAgent, _ := al.registry.GetAgent("alpha")
|
||||
parent := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: "parent-alpha",
|
||||
depth: 0,
|
||||
childTurnIDs: []string{},
|
||||
pendingResults: make(chan *tools.ToolResult, 4),
|
||||
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
|
||||
session: &ephemeralSessionStore{},
|
||||
agent: alphaAgent,
|
||||
}
|
||||
|
||||
_, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
|
||||
TargetAgentID: "nonexistent",
|
||||
SystemPrompt: "task",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent agent")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("error should mention 'not found', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnSubTurn_TargetAgentID_EmptyModelAccepted(t *testing.T) {
|
||||
al, cleanup := newMultiAgentLoop(t, &mockProvider{})
|
||||
defer cleanup()
|
||||
|
||||
alphaAgent, _ := al.registry.GetAgent("alpha")
|
||||
parent := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: "parent-alpha",
|
||||
depth: 0,
|
||||
childTurnIDs: []string{},
|
||||
pendingResults: make(chan *tools.ToolResult, 4),
|
||||
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
|
||||
session: &ephemeralSessionStore{},
|
||||
agent: alphaAgent,
|
||||
}
|
||||
|
||||
// Model is empty but TargetAgentID is set — should NOT fail validation
|
||||
result, err := spawnSubTurn(context.Background(), al, parent, SubTurnConfig{
|
||||
Model: "", // intentionally empty
|
||||
TargetAgentID: "beta",
|
||||
SystemPrompt: "task for beta",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("should accept empty Model when TargetAgentID is set, got: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelegateToolNotRegistered_SingleAgent(t *testing.T) {
|
||||
// Single-agent setup: delegate should not be registered
|
||||
al, _, _, provider, cleanup := newTestAgentLoop(t)
|
||||
_ = provider
|
||||
defer cleanup()
|
||||
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("default agent should exist")
|
||||
}
|
||||
if _, has := agent.Tools.Get("delegate"); has {
|
||||
t.Error("delegate tool should not be registered in single-agent setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelegateToolRegistered_MultiAgent(t *testing.T) {
|
||||
al, cleanup := newMultiAgentLoop(t, &mockProvider{})
|
||||
defer cleanup()
|
||||
|
||||
// Both agents should have the delegate tool
|
||||
for _, id := range []string{"alpha", "beta"} {
|
||||
agent, ok := al.registry.GetAgent(id)
|
||||
if !ok {
|
||||
t.Fatalf("agent %q not found", id)
|
||||
}
|
||||
if _, has := agent.Tools.Get("delegate"); !has {
|
||||
t.Errorf("agent %q should have delegate tool in multi-agent setup", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func cloneStringMap(src map[string]string) map[string]string {
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneEventMeta(meta EventMeta) EventMeta {
|
||||
func cloneHookMeta(meta HookMeta) HookMeta {
|
||||
meta.turnContext = cloneTurnContext(meta.turnContext)
|
||||
return meta
|
||||
}
|
||||
|
||||
+18
-8
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
@@ -25,10 +26,14 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
|
||||
al.registerActiveTurn(ts)
|
||||
defer al.clearActiveTurn(ts)
|
||||
|
||||
if al.takePendingStop(ts.sessionKey) {
|
||||
_ = ts.requestHardAbort()
|
||||
}
|
||||
|
||||
turnStatus := TurnEndStatusCompleted
|
||||
defer func() {
|
||||
al.emitEvent(
|
||||
EventKindTurnEnd,
|
||||
runtimeevents.KindAgentTurnEnd,
|
||||
ts.eventMeta("runTurn", "turn.end"),
|
||||
TurnEndPayload{
|
||||
Status: turnStatus,
|
||||
@@ -39,8 +44,13 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
|
||||
)
|
||||
}()
|
||||
|
||||
if ts.hardAbortRequested() {
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindTurnStart,
|
||||
runtimeevents.KindAgentTurnStart,
|
||||
ts.eventMeta("runTurn", "turn.start"),
|
||||
TurnStartPayload{
|
||||
UserMessage: ts.userMessage,
|
||||
@@ -140,7 +150,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState, pipeline *Pipel
|
||||
})
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindSteeringInjected,
|
||||
runtimeevents.KindAgentSteeringInjected,
|
||||
ts.eventMeta("runTurn", "turn.steering.injected"),
|
||||
SteeringInjectedPayload{
|
||||
Count: len(pendingMessages),
|
||||
@@ -249,7 +259,7 @@ func (al *AgentLoop) abortTurn(ts *turnState) (turnResult, error) {
|
||||
if !ts.opts.NoHistory {
|
||||
if err := ts.restoreSession(ts.agent); err != nil {
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
runtimeevents.KindAgentError,
|
||||
ts.eventMeta("abortTurn", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "session_restore",
|
||||
@@ -414,7 +424,7 @@ func (al *AgentLoop) askSideQuestion(
|
||||
llmModel := activeModel
|
||||
if al.hooks != nil {
|
||||
llmReq, decision := al.hooks.BeforeLLM(ctx, &LLMHookRequest{
|
||||
Meta: EventMeta{
|
||||
Meta: HookMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.request",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
@@ -494,8 +504,8 @@ func (al *AgentLoop) askSideQuestion(
|
||||
resp, err = callSideLLM(messages)
|
||||
if err != nil && hasMediaRefs(messages) && isVisionUnsupportedError(err) {
|
||||
al.emitEvent(
|
||||
EventKindLLMRetry,
|
||||
EventMeta{
|
||||
runtimeevents.KindAgentLLMRetry,
|
||||
HookMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.retry",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
@@ -521,7 +531,7 @@ func (al *AgentLoop) askSideQuestion(
|
||||
// Apply after_llm hooks
|
||||
if al.hooks != nil {
|
||||
llmResp, decision := al.hooks.AfterLLM(ctx, &LLMHookResponse{
|
||||
Meta: EventMeta{
|
||||
Meta: HookMeta{
|
||||
Source: "askSideQuestion",
|
||||
TracePath: "turn.llm.response",
|
||||
turnContext: cloneTurnContext(turnCtx),
|
||||
|
||||
@@ -135,6 +135,16 @@ func (p *errorProvider) Chat(
|
||||
return nil, errors.New("context_length_exceeded")
|
||||
case "vision":
|
||||
return nil, errors.New("vision_unsupported")
|
||||
case "connection_reset":
|
||||
return nil, errors.New("connection reset by peer")
|
||||
case "broken_pipe":
|
||||
return nil, errors.New("broken pipe")
|
||||
case "read_tcp":
|
||||
return nil, errors.New("read tcp 127.0.0.1:8080: connection reset")
|
||||
case "eof":
|
||||
return nil, errors.New("EOF")
|
||||
case "connection_refused":
|
||||
return nil, errors.New("connection refused")
|
||||
default:
|
||||
return nil, errors.New("unknown error")
|
||||
}
|
||||
@@ -366,6 +376,163 @@ func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
|
||||
t.Logf("CallLLM result after context error: err=%v", err)
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_NetworkErrorRetry(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
errType string
|
||||
}{
|
||||
{"connection_reset", "connection_reset"},
|
||||
{"broken_pipe", "broken_pipe"},
|
||||
{"read_tcp", "read_tcp"},
|
||||
{"eof", "eof"},
|
||||
{"connection_refused", "connection_refused"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
errorPrv := &errorProvider{errType: tc.errType}
|
||||
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
||||
defer cleanup()
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err == nil {
|
||||
t.Error("expected error after network error retries")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_RetryConfigRespected(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
MaxLLMRetries: 3,
|
||||
LLMRetryBackoffSecs: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &errorProvider{errType: "connection_reset"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
defer al.Close()
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error after retries")
|
||||
}
|
||||
|
||||
expectedMinTime := 3 * time.Second
|
||||
if elapsed < expectedMinTime {
|
||||
t.Errorf("expected at least %v of backoff, got %v", expectedMinTime, elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeline_CallLLM_RetryCountLimit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
counterPrv := &countingErrorProvider{errType: "connection_reset", targetCalls: 5}
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
MaxLLMRetries: 2,
|
||||
LLMRetryBackoffSecs: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, counterPrv)
|
||||
defer al.Close()
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
pipeline := NewPipeline(al)
|
||||
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
||||
turnID: "turn-1",
|
||||
context: newTurnContext(nil, nil, nil),
|
||||
})
|
||||
|
||||
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
||||
if err != nil {
|
||||
t.Fatalf("SetupTurn failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
||||
if err == nil {
|
||||
t.Error("expected error after retries")
|
||||
}
|
||||
|
||||
if counterPrv.callCount != 3 {
|
||||
t.Errorf("expected exactly 3 calls (1 initial + 2 retries), got %d", counterPrv.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
type countingErrorProvider struct {
|
||||
errType string
|
||||
targetCalls int
|
||||
callCount int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (p *countingErrorProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.callCount++
|
||||
p.mu.Unlock()
|
||||
return nil, errors.New("connection reset by peer")
|
||||
}
|
||||
|
||||
func (p *countingErrorProvider) GetDefaultModel() string {
|
||||
return "counting-error-model"
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Pipeline Method Tests: ExecuteTools
|
||||
// =============================================================================
|
||||
|
||||
@@ -256,7 +256,10 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop
|
||||
// Bind session store and capture initial history length for rollback logic
|
||||
if agent != nil && agent.Sessions != nil {
|
||||
ts.session = agent.Sessions
|
||||
ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey))
|
||||
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
|
||||
ts.initialHistoryLength = len(history)
|
||||
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
||||
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
|
||||
}
|
||||
|
||||
return ts
|
||||
@@ -442,9 +445,9 @@ func (ts *turnState) hardAbortRequested() bool {
|
||||
return ts.hardAbort
|
||||
}
|
||||
|
||||
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
||||
func (ts *turnState) eventMeta(source, tracePath string) HookMeta {
|
||||
snap := ts.snapshot()
|
||||
return EventMeta{
|
||||
return HookMeta{
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
|
||||
@@ -82,7 +82,8 @@ Notes:
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "elevenlabs-asr",
|
||||
"model": "elevenlabs/scribe_v1"
|
||||
"provider": "elevenlabs",
|
||||
"model": "scribe_v1"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -130,7 +131,7 @@ PicoClaw currently supports three main ASR routes:
|
||||
|
||||
| Route | Example models | Behavior |
|
||||
| --- | --- | --- |
|
||||
| ElevenLabs ASR | `elevenlabs/scribe_v1` | Uses the ElevenLabs transcription API. |
|
||||
| ElevenLabs ASR | `provider: elevenlabs`, `model: scribe_v1` | Uses the ElevenLabs transcription API. |
|
||||
| Whisper endpoint models | `openai/whisper-1`, `groq/whisper-large-v3` | Uses an OpenAI-compatible `/audio/transcriptions` endpoint. |
|
||||
| Audio-capable chat models **(Under construction)** | `openai/gpt-4o-audio-preview`, `gemini/gemini-2.5-flash` | Sends audio to a multimodal chat model and asks it to transcribe. |
|
||||
|
||||
@@ -142,7 +143,7 @@ If you are unsure which one to pick, choose Groq Whisper or ElevenLabs first.
|
||||
|
||||
1. **Preferred path**: resolve `voice.model_name` against `model_list`.
|
||||
2. If that resolved model is:
|
||||
- `elevenlabs/...`, PicoClaw uses the ElevenLabs transcriber.
|
||||
- an `elevenlabs` provider model, PicoClaw uses the ElevenLabs transcriber.
|
||||
- an OpenAI-compatible Whisper model, PicoClaw uses the Whisper transcriber.
|
||||
- an audio-capable chat model, PicoClaw uses `AudioModelTranscriber`.
|
||||
3. **Fallback path**: if `voice.model_name` is not set, PicoClaw performs a compatibility scan through `model_list` for legacy auto-detected ASR entries.
|
||||
|
||||
@@ -82,7 +82,8 @@ model_list:
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "elevenlabs-asr",
|
||||
"model": "elevenlabs/scribe_v1"
|
||||
"provider": "elevenlabs",
|
||||
"model": "scribe_v1"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -130,7 +131,7 @@ PicoClaw 目前主要支持三种 ASR 路径:
|
||||
|
||||
| 路径 | 示例模型 | 行为说明 |
|
||||
| --- | --- | --- |
|
||||
| ElevenLabs ASR | `elevenlabs/scribe_v1` | 使用 ElevenLabs 的语音转录接口。 |
|
||||
| ElevenLabs ASR | `provider: elevenlabs`,`model: scribe_v1` | 使用 ElevenLabs 的语音转录接口。 |
|
||||
| Whisper 接口模型 | `openai/whisper-1`、`groq/whisper-large-v3` | 使用 OpenAI 兼容的 `/audio/transcriptions` 接口。 |
|
||||
| 支持音频的聊天模型 **(重构中)** | `openai/gpt-4o-audio-preview`、`gemini/gemini-2.5-flash` | 把音频发给多模态聊天模型,并要求它返回转录结果。 |
|
||||
|
||||
@@ -142,7 +143,7 @@ PicoClaw 目前主要支持三种 ASR 路径:
|
||||
|
||||
1. **首选路径**:根据 `voice.model_name` 在 `model_list` 中找到对应模型。
|
||||
2. 如果找到的模型属于以下类型:
|
||||
- `elevenlabs/...`,则使用 ElevenLabs transcriber。
|
||||
- `provider=elevenlabs` 的模型,则使用 ElevenLabs transcriber。
|
||||
- OpenAI 兼容的 Whisper 模型,则使用 Whisper transcriber。
|
||||
- 支持音频输入的聊天模型,则使用 `AudioModelTranscriber`。
|
||||
3. **回退路径**:如果没有设置 `voice.model_name`,PicoClaw 会为了兼容旧配置,扫描 `model_list` 中可自动识别的 ASR 条目。
|
||||
|
||||
+21
-6
@@ -8,6 +8,12 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
const elevenLabsSupportedModelID = "scribe_v1"
|
||||
|
||||
func ElevenLabsSupportedModelID() string {
|
||||
return elevenLabsSupportedModelID
|
||||
}
|
||||
|
||||
type Transcriber interface {
|
||||
Name() string
|
||||
Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
|
||||
@@ -72,14 +78,23 @@ func whisperModelID(modelCfg *config.ModelConfig) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func isElevenLabsTranscriptionModel(modelCfg *config.ModelConfig) bool {
|
||||
if modelCfg == nil || modelCfg.APIKey() == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
protocol, _ := providers.ExtractProtocol(modelCfg)
|
||||
return protocol == "elevenlabs"
|
||||
}
|
||||
|
||||
func transcriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
|
||||
if modelCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
protocol, _ := providers.ExtractProtocol(modelCfg)
|
||||
if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
|
||||
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
|
||||
if isElevenLabsTranscriptionModel(modelCfg) {
|
||||
_, modelID := providers.ExtractProtocol(modelCfg)
|
||||
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase, modelID)
|
||||
}
|
||||
if modelID := whisperModelID(modelCfg); modelID != "" {
|
||||
return NewWhisperTranscriber(modelCfg)
|
||||
@@ -95,9 +110,9 @@ func fallbackTranscriberFromModelConfig(modelCfg *config.ModelConfig) Transcribe
|
||||
return nil
|
||||
}
|
||||
|
||||
protocol, _ := providers.ExtractProtocol(modelCfg)
|
||||
if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
|
||||
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
|
||||
if isElevenLabsTranscriptionModel(modelCfg) {
|
||||
_, modelID := providers.ExtractProtocol(modelCfg)
|
||||
return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase, modelID)
|
||||
}
|
||||
if modelID := whisperModelID(modelCfg); modelID != "" {
|
||||
return NewWhisperTranscriber(modelCfg)
|
||||
|
||||
@@ -46,6 +46,21 @@ func TestDetectTranscriber(t *testing.T) {
|
||||
},
|
||||
wantName: "elevenlabs",
|
||||
},
|
||||
{
|
||||
name: "explicit elevenlabs provider selects elevenlabs transcriber",
|
||||
cfg: &config.Config{
|
||||
Voice: config.VoiceConfig{ModelName: "my-asr-model"},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "my-asr-model",
|
||||
Provider: "elevenlabs",
|
||||
Model: "scribe_v1",
|
||||
APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantName: "elevenlabs",
|
||||
},
|
||||
{
|
||||
name: "voice model name alias selects whisper transcriber for groq",
|
||||
cfg: &config.Config{
|
||||
|
||||
@@ -20,19 +20,24 @@ import (
|
||||
type ElevenLabsTranscriber struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
modelID string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewElevenLabsTranscriber(apiKey, apiBase string) *ElevenLabsTranscriber {
|
||||
func NewElevenLabsTranscriber(apiKey, apiBase, modelID string) *ElevenLabsTranscriber {
|
||||
logger.DebugCF("voice", "Creating ElevenLabs transcriber", map[string]any{"has_api_key": apiKey != ""})
|
||||
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.elevenlabs.io"
|
||||
}
|
||||
if modelID == "" || modelID != ElevenLabsSupportedModelID() {
|
||||
modelID = ElevenLabsSupportedModelID()
|
||||
}
|
||||
|
||||
return &ElevenLabsTranscriber{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
modelID: modelID,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
@@ -74,7 +79,7 @@ func (t *ElevenLabsTranscriber) Transcribe(ctx context.Context, audioFilePath st
|
||||
return nil, fmt.Errorf("failed to copy file content: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.WriteField("model_id", "scribe_v1"); err != nil {
|
||||
if err = writer.WriteField("model_id", t.modelID); err != nil {
|
||||
return nil, fmt.Errorf("failed to write model_id field: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,10 +3,14 @@ package asr
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -14,7 +18,7 @@ import (
|
||||
var _ Transcriber = (*ElevenLabsTranscriber)(nil)
|
||||
|
||||
func TestElevenLabsTranscriberName(t *testing.T) {
|
||||
tr := NewElevenLabsTranscriber("sk_test", "")
|
||||
tr := NewElevenLabsTranscriber("sk_test", "", "scribe_v1")
|
||||
if got := tr.Name(); got != "elevenlabs" {
|
||||
t.Errorf("Name() = %q, want %q", got, "elevenlabs")
|
||||
}
|
||||
@@ -35,6 +39,35 @@ func TestElevenLabsTranscribe(t *testing.T) {
|
||||
if r.Header.Get("Xi-Api-Key") != "sk_test" {
|
||||
t.Errorf("unexpected xi-api-key header: %s", r.Header.Get("Xi-Api-Key"))
|
||||
}
|
||||
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMediaType() error = %v", err)
|
||||
}
|
||||
if mediaType != "multipart/form-data" {
|
||||
t.Fatalf("content-type = %q, want multipart/form-data", mediaType)
|
||||
}
|
||||
reader := multipart.NewReader(r.Body, params["boundary"])
|
||||
var gotModelID string
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NextPart() error = %v", err)
|
||||
}
|
||||
if part.FormName() != "model_id" {
|
||||
continue
|
||||
}
|
||||
body, err := io.ReadAll(part)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll(part) error = %v", err)
|
||||
}
|
||||
gotModelID = strings.TrimSpace(string(body))
|
||||
}
|
||||
if gotModelID != "scribe_v1" {
|
||||
t.Fatalf("model_id = %q, want %q", gotModelID, "scribe_v1")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TranscriptionResponse{
|
||||
Text: "hello from elevenlabs",
|
||||
@@ -43,7 +76,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewElevenLabsTranscriber("sk_test", "")
|
||||
tr := NewElevenLabsTranscriber("sk_test", "", "scribe_v1")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
resp, err := tr.Transcribe(context.Background(), audioPath)
|
||||
@@ -64,7 +97,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewElevenLabsTranscriber("sk_bad", "")
|
||||
tr := NewElevenLabsTranscriber("sk_bad", "", "scribe_v1")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
_, err := tr.Transcribe(context.Background(), audioPath)
|
||||
@@ -74,10 +107,54 @@ func TestElevenLabsTranscribe(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("missing file", func(t *testing.T) {
|
||||
tr := NewElevenLabsTranscriber("sk_test", "")
|
||||
tr := NewElevenLabsTranscriber("sk_test", "", "scribe_v1")
|
||||
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported model falls back to scribe_v1", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMediaType() error = %v", err)
|
||||
}
|
||||
if mediaType != "multipart/form-data" {
|
||||
t.Fatalf("content-type = %q, want multipart/form-data", mediaType)
|
||||
}
|
||||
reader := multipart.NewReader(r.Body, params["boundary"])
|
||||
var gotModelID string
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NextPart() error = %v", err)
|
||||
}
|
||||
if part.FormName() != "model_id" {
|
||||
continue
|
||||
}
|
||||
body, err := io.ReadAll(part)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll(part) error = %v", err)
|
||||
}
|
||||
gotModelID = strings.TrimSpace(string(body))
|
||||
}
|
||||
if gotModelID != "scribe_v1" {
|
||||
t.Fatalf("model_id = %q, want runtime fallback to %q", gotModelID, "scribe_v1")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TranscriptionResponse{Text: "ok"})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tr := NewElevenLabsTranscriber("sk_test", "", "unsupported-model")
|
||||
tr.apiBase = srv.URL
|
||||
|
||||
if _, err := tr.Transcribe(context.Background(), audioPath); err != nil {
|
||||
t.Fatalf("Transcribe() error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+44
-5
@@ -6,6 +6,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -48,6 +49,13 @@ type MessageBus struct {
|
||||
closed atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
streamDelegate atomic.Value // stores StreamDelegate
|
||||
eventPublisher atomic.Value // stores EventPublisher
|
||||
}
|
||||
|
||||
// EventPublisher is the minimal runtime event publisher used by MessageBus.
|
||||
type EventPublisher interface {
|
||||
Publish(ctx context.Context, evt runtimeevents.Event) runtimeevents.PublishResult
|
||||
PublishNonBlocking(evt runtimeevents.Event) runtimeevents.PublishResult
|
||||
}
|
||||
|
||||
func NewMessageBus() *MessageBus {
|
||||
@@ -92,9 +100,14 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
msg = NormalizeInboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
mb.publishFailure("inbound", runtimeScopeFromInboundContext(msg.Context), ErrMissingInboundContext)
|
||||
return ErrMissingInboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.inbound, msg)
|
||||
if err := publish(ctx, mb, mb.inbound, msg); err != nil {
|
||||
mb.publishFailure("inbound", runtimeScopeFromInboundContext(msg.Context), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
@@ -104,9 +117,14 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
|
||||
msg = NormalizeOutboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
mb.publishFailure("outbound", runtimeScopeFromInboundContext(msg.Context), ErrMissingOutboundContext)
|
||||
return ErrMissingOutboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outbound, msg)
|
||||
if err := publish(ctx, mb, mb.outbound, msg); err != nil {
|
||||
mb.publishFailure("outbound", runtimeScopeFromInboundContext(msg.Context), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
@@ -116,9 +134,14 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
|
||||
msg = NormalizeOutboundMediaMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
mb.publishFailure("outbound_media", runtimeScopeFromInboundContext(msg.Context), ErrMissingOutboundMediaContext)
|
||||
return ErrMissingOutboundMediaContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outboundMedia, msg)
|
||||
if err := publish(ctx, mb, mb.outboundMedia, msg); err != nil {
|
||||
mb.publishFailure("outbound_media", runtimeScopeFromInboundContext(msg.Context), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
|
||||
@@ -126,7 +149,11 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishAudioChunk(ctx context.Context, chunk AudioChunk) error {
|
||||
return publish(ctx, mb, mb.audioChunks, chunk)
|
||||
if err := publish(ctx, mb, mb.audioChunks, chunk); err != nil {
|
||||
mb.publishFailure("audio_chunk", runtimeScopeFromAudioChunk(chunk), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
|
||||
@@ -134,7 +161,11 @@ func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishVoiceControl(ctx context.Context, ctrl VoiceControl) error {
|
||||
return publish(ctx, mb, mb.voiceControls, ctrl)
|
||||
if err := publish(ctx, mb, mb.voiceControls, ctrl); err != nil {
|
||||
mb.publishFailure("voice_control", runtimeScopeFromVoiceControl(ctrl), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) VoiceControlsChan() <-chan VoiceControl {
|
||||
@@ -146,6 +177,11 @@ func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) {
|
||||
mb.streamDelegate.Store(d)
|
||||
}
|
||||
|
||||
// SetEventPublisher registers a runtime event publisher for bus errors and lifecycle events.
|
||||
func (mb *MessageBus) SetEventPublisher(p EventPublisher) {
|
||||
mb.eventPublisher.Store(p)
|
||||
}
|
||||
|
||||
// GetStreamer returns a Streamer for the given channel+chatID via the delegate.
|
||||
func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (Streamer, bool) {
|
||||
if d, ok := mb.streamDelegate.Load().(StreamDelegate); ok && d != nil {
|
||||
@@ -156,6 +192,7 @@ func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
mb.closeOnce.Do(func() {
|
||||
mb.publishCloseEvent(runtimeevents.KindBusCloseStarted, 0)
|
||||
// notify all blocked publishers to exit
|
||||
close(mb.done)
|
||||
|
||||
@@ -195,6 +232,8 @@ func (mb *MessageBus) Close() {
|
||||
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
|
||||
"count": drained,
|
||||
})
|
||||
mb.publishCloseEvent(runtimeevents.KindBusCloseDrained, drained)
|
||||
}
|
||||
mb.publishCloseEvent(runtimeevents.KindBusCloseCompleted, drained)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func TestPublishConsume(t *testing.T) {
|
||||
@@ -171,6 +173,86 @@ func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBusPublishesRuntimeFailureAndCloseEvents(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().OfKind(
|
||||
runtimeevents.KindBusPublishFailed,
|
||||
runtimeevents.KindBusCloseStarted,
|
||||
runtimeevents.KindBusCloseDrained,
|
||||
runtimeevents.KindBusCloseCompleted,
|
||||
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "bus-events", Buffer: 4})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
mb := NewMessageBus()
|
||||
mb.SetEventPublisher(eventBus)
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{}); err == nil {
|
||||
t.Fatal("expected PublishInbound to fail")
|
||||
}
|
||||
failed := receiveBusRuntimeEvent(t, eventsCh)
|
||||
if failed.Kind != runtimeevents.KindBusPublishFailed ||
|
||||
failed.Source.Name != "inbound" ||
|
||||
failed.Severity != runtimeevents.SeverityError {
|
||||
t.Fatalf("publish failed event = %+v", failed)
|
||||
}
|
||||
if failed.Attrs["stream"] != "inbound" || failed.Attrs["error"] == "" {
|
||||
t.Fatalf("publish failed attrs = %#v, want stream and error", failed.Attrs)
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), OutboundMessage{
|
||||
Context: NewOutboundContext("telegram", "chat-1", ""),
|
||||
Content: "queued",
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
mb.Close()
|
||||
|
||||
seen := map[runtimeevents.Kind]bool{}
|
||||
var drainedAttrs map[string]any
|
||||
for range 3 {
|
||||
evt := receiveBusRuntimeEvent(t, eventsCh)
|
||||
seen[evt.Kind] = true
|
||||
if evt.Kind == runtimeevents.KindBusCloseDrained {
|
||||
drainedAttrs = evt.Attrs
|
||||
}
|
||||
}
|
||||
for _, kind := range []runtimeevents.Kind{
|
||||
runtimeevents.KindBusCloseStarted,
|
||||
runtimeevents.KindBusCloseDrained,
|
||||
runtimeevents.KindBusCloseCompleted,
|
||||
} {
|
||||
if !seen[kind] {
|
||||
t.Fatalf("missing %s event, seen=%v", kind, seen)
|
||||
}
|
||||
}
|
||||
if drainedAttrs["drained"] != 1 {
|
||||
t.Fatalf("bus close drained attrs = %#v, want drained count", drainedAttrs)
|
||||
}
|
||||
}
|
||||
|
||||
func receiveBusRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event channel closed before expected event")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package bus
|
||||
|
||||
import (
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
type busPublishFailedPayload struct {
|
||||
Stream string `json:"stream"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type busClosePayload struct {
|
||||
Drained int `json:"drained,omitempty"`
|
||||
}
|
||||
|
||||
func (mb *MessageBus) publishFailure(stream string, scope runtimeevents.Scope, err error) {
|
||||
if mb == nil || err == nil {
|
||||
return
|
||||
}
|
||||
publisher, ok := mb.eventPublisher.Load().(EventPublisher)
|
||||
if !ok || publisher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
publisher.PublishNonBlocking(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindBusPublishFailed,
|
||||
Source: runtimeevents.Source{Component: "bus", Name: stream},
|
||||
Scope: scope,
|
||||
Severity: runtimeevents.SeverityError,
|
||||
Payload: busPublishFailedPayload{
|
||||
Stream: stream,
|
||||
Error: err.Error(),
|
||||
},
|
||||
Attrs: map[string]any{
|
||||
"stream": stream,
|
||||
"error": err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (mb *MessageBus) publishCloseEvent(kind runtimeevents.Kind, drained int) {
|
||||
if mb == nil {
|
||||
return
|
||||
}
|
||||
publisher, ok := mb.eventPublisher.Load().(EventPublisher)
|
||||
if !ok || publisher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
attrs := map[string]any{}
|
||||
if drained > 0 {
|
||||
attrs["drained"] = drained
|
||||
}
|
||||
publisher.PublishNonBlocking(runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "bus"},
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
Payload: busClosePayload{Drained: drained},
|
||||
Attrs: attrs,
|
||||
})
|
||||
}
|
||||
|
||||
func runtimeScopeFromInboundContext(ctx InboundContext) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
Channel: ctx.Channel,
|
||||
Account: ctx.Account,
|
||||
ChatID: ctx.ChatID,
|
||||
TopicID: ctx.TopicID,
|
||||
SpaceID: ctx.SpaceID,
|
||||
SpaceType: ctx.SpaceType,
|
||||
ChatType: ctx.ChatType,
|
||||
SenderID: ctx.SenderID,
|
||||
MessageID: ctx.MessageID,
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeScopeFromAudioChunk(chunk AudioChunk) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
Channel: chunk.Channel,
|
||||
ChatID: chunk.ChatID,
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeScopeFromVoiceControl(ctrl VoiceControl) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
ChatID: ctrl.ChatID,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func channelTypeForEvent(m *Manager, channelName string) string {
|
||||
if m == nil || m.config == nil {
|
||||
return channelName
|
||||
}
|
||||
if bc := m.config.Channels.Get(channelName); bc != nil && bc.Type != "" {
|
||||
return bc.Type
|
||||
}
|
||||
return channelName
|
||||
}
|
||||
|
||||
func (m *Manager) publishChannelEvent(
|
||||
kind runtimeevents.Kind,
|
||||
channelName string,
|
||||
scope runtimeevents.Scope,
|
||||
severity runtimeevents.Severity,
|
||||
payload any,
|
||||
) {
|
||||
if m == nil || m.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
if scope.Channel == "" {
|
||||
scope.Channel = channelName
|
||||
}
|
||||
m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "channel", Name: channelName},
|
||||
Scope: scope,
|
||||
Severity: severity,
|
||||
Payload: payload,
|
||||
Attrs: channelEventAttrs(payload),
|
||||
})
|
||||
}
|
||||
|
||||
func channelEventAttrs(payload any) map[string]any {
|
||||
switch payload := payload.(type) {
|
||||
case ChannelLifecyclePayload:
|
||||
attrs := map[string]any{}
|
||||
setAttrString(attrs, "type", payload.Type)
|
||||
setAttrString(attrs, "error", payload.Error)
|
||||
return attrs
|
||||
case ChannelOutboundPayload:
|
||||
attrs := map[string]any{}
|
||||
if payload.Media {
|
||||
attrs["media"] = payload.Media
|
||||
}
|
||||
if payload.ContentLen > 0 {
|
||||
attrs["content_len"] = payload.ContentLen
|
||||
}
|
||||
if len(payload.MessageIDs) > 0 {
|
||||
attrs["message_ids_count"] = len(payload.MessageIDs)
|
||||
}
|
||||
setAttrString(attrs, "reply_to_message_id", payload.ReplyToMessageID)
|
||||
setAttrString(attrs, "error", payload.Error)
|
||||
if payload.Retries > 0 {
|
||||
attrs["retries"] = payload.Retries
|
||||
}
|
||||
return attrs
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func setAttrString(attrs map[string]any, key, value string) {
|
||||
if value != "" {
|
||||
attrs[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundSent(
|
||||
channelName string,
|
||||
msg bus.OutboundMessage,
|
||||
messageIDs []string,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundSent,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
MessageIDs: append([]string(nil), messageIDs...),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundQueued(
|
||||
channelName string,
|
||||
msg bus.OutboundMessage,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundQueued,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundFailed(
|
||||
channelName string,
|
||||
msg bus.OutboundMessage,
|
||||
err error,
|
||||
media bool,
|
||||
) {
|
||||
payload := ChannelOutboundPayload{
|
||||
Media: media,
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
Retries: maxRetries,
|
||||
}
|
||||
if err != nil {
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundFailed,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityError,
|
||||
payload,
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundMediaSent(
|
||||
channelName string,
|
||||
msg bus.OutboundMediaMessage,
|
||||
messageIDs []string,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundSent,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{
|
||||
Media: true,
|
||||
MessageIDs: append([]string(nil), messageIDs...),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundMediaQueued(
|
||||
channelName string,
|
||||
msg bus.OutboundMediaMessage,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundQueued,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{Media: true},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundMediaFailed(
|
||||
channelName string,
|
||||
msg bus.OutboundMediaMessage,
|
||||
err error,
|
||||
) {
|
||||
payload := ChannelOutboundPayload{
|
||||
Media: true,
|
||||
Retries: maxRetries,
|
||||
}
|
||||
if err != nil {
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundFailed,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityError,
|
||||
payload,
|
||||
)
|
||||
}
|
||||
|
||||
func scopeFromOutboundContext(ctx bus.InboundContext) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
Channel: ctx.Channel,
|
||||
Account: ctx.Account,
|
||||
ChatID: ctx.ChatID,
|
||||
TopicID: ctx.TopicID,
|
||||
SpaceID: ctx.SpaceID,
|
||||
SpaceType: ctx.SpaceType,
|
||||
ChatType: ctx.ChatType,
|
||||
SenderID: ctx.SenderID,
|
||||
MessageID: ctx.MessageID,
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,62 @@ func extractJSONStringField(content, field string) string {
|
||||
// Format: {"image_key": "img_xxx"}
|
||||
func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") }
|
||||
|
||||
// extractPostImageKeys extracts all image_key values from a Feishu post (rich text)
|
||||
// message. Post messages have nested arrays of elements where images appear as
|
||||
// {"tag":"img","image_key":"img_xxx"}.
|
||||
func extractPostImageKeys(rawContent string) []string {
|
||||
if rawContent == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var post map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(rawContent), &post); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var keys []string
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
collectFromRows := func(contentRaw json.RawMessage) {
|
||||
var rows [][]map[string]any
|
||||
if err := json.Unmarshal(contentRaw, &rows); err != nil {
|
||||
return
|
||||
}
|
||||
for _, row := range rows {
|
||||
for _, elem := range row {
|
||||
if tag, _ := elem["tag"].(string); tag == "img" {
|
||||
if ik, _ := elem["image_key"].(string); ik != "" {
|
||||
if _, dup := seen[ik]; !dup {
|
||||
seen[ik] = struct{}{}
|
||||
keys = append(keys, ik)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flat format: {"title":"...", "content":[[...]]}
|
||||
if contentRaw, ok := post["content"]; ok {
|
||||
collectFromRows(contentRaw)
|
||||
}
|
||||
|
||||
// Localized format: {"zh_cn": {"title":"...", "content":[[...]]}, ...}
|
||||
for _, raw := range post {
|
||||
var locale map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &locale); err != nil {
|
||||
continue
|
||||
}
|
||||
contentRaw, ok := locale["content"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
collectFromRows(contentRaw)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// extractFileKey extracts the file_key from a Feishu file/audio message content JSON.
|
||||
// Format: {"file_key": "file_xxx", "file_name": "...", ...}
|
||||
func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") }
|
||||
|
||||
@@ -291,6 +291,100 @@ func TestStripMentionPlaceholders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPostImageKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
content: "not json",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "post with no images",
|
||||
content: `{"zh_cn":{"title":"Title","content":[[{"tag":"text","text":"hello"}]]}}`,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "post with one image",
|
||||
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_v3_001"}]]}}`,
|
||||
want: []string{"img_v3_001"},
|
||||
},
|
||||
{
|
||||
name: "post with multiple images",
|
||||
content: `{"zh_cn":{"title":"","content":[[{"tag":"text","text":"see"},{"tag":"img","image_key":"img_001"}],[{"tag":"img","image_key":"img_002"}]]}}`,
|
||||
want: []string{"img_001", "img_002"},
|
||||
},
|
||||
{
|
||||
name: "post with text and image mixed in row",
|
||||
content: `{"zh_cn":{"title":"","content":[[{"tag":"text","text":"hi"},{"tag":"img","image_key":"img_mix"}]]}}`,
|
||||
want: []string{"img_mix"},
|
||||
},
|
||||
{
|
||||
name: "en_us locale",
|
||||
content: `{"en_us":{"title":"","content":[[{"tag":"img","image_key":"img_en"}]]}}`,
|
||||
want: []string{"img_en"},
|
||||
},
|
||||
{
|
||||
name: "multiple locales with distinct images",
|
||||
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_zh"}]]},"en_us":{"title":"","content":[[{"tag":"img","image_key":"img_en"}]]}}`,
|
||||
want: []string{"img_zh", "img_en"},
|
||||
},
|
||||
{
|
||||
name: "duplicate image_key across locales is deduplicated",
|
||||
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_same"}]]},"en_us":{"title":"","content":[[{"tag":"img","image_key":"img_same"}]]}}`,
|
||||
want: []string{"img_same"},
|
||||
},
|
||||
{
|
||||
name: "image with empty image_key",
|
||||
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":""}]]}}`,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "flat format without locale wrapper",
|
||||
content: `{"title":"","content":[[{"tag":"img","image_key":"img_v3_flat","width":1826,"height":338}],[{"tag":"text","text":" check this image","style":[]}]]}`,
|
||||
want: []string{"img_v3_flat"},
|
||||
},
|
||||
{
|
||||
name: "flat format multiple images",
|
||||
content: `{"title":"","content":[[{"tag":"img","image_key":"img_flat_1"}],[{"tag":"img","image_key":"img_flat_2"},{"tag":"text","text":"desc"}]]}`,
|
||||
want: []string{"img_flat_1", "img_flat_2"},
|
||||
},
|
||||
{
|
||||
name: "flat format no images",
|
||||
content: `{"title":"Test","content":[[{"tag":"text","text":"just text"}]]}`,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractPostImageKeys(tt.content)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("extractPostImageKeys() = %v, want %v", got, tt.want)
|
||||
return
|
||||
}
|
||||
// Use set comparison to avoid map iteration order dependency
|
||||
gotSet := make(map[string]bool, len(got))
|
||||
for _, v := range got {
|
||||
gotSet[v] = true
|
||||
}
|
||||
for _, v := range tt.want {
|
||||
if !gotSet[v] {
|
||||
t.Errorf("extractPostImageKeys() missing expected key %q; got %v", v, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCardImageKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -803,6 +803,14 @@ func (c *FeishuChannel) downloadInboundMedia(
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
|
||||
case larkim.MsgTypePost:
|
||||
for _, imageKey := range extractPostImageKeys(rawContent) {
|
||||
ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
|
||||
case larkim.MsgTypeInteractive:
|
||||
// Extract and download images embedded in interactive cards
|
||||
feishuKeys, _ := extractCardImageKeys(rawContent)
|
||||
@@ -842,12 +850,41 @@ func (c *FeishuChannel) downloadInboundMedia(
|
||||
// downloadResource downloads a message resource (image/file) from Feishu,
|
||||
// writes it to the project media directory, and stores the reference in MediaStore.
|
||||
// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension.
|
||||
//
|
||||
// For image resources, if the primary MessageResource.Get API fails (which
|
||||
// requires im:message or im:message:readonly scope), a fallback to the
|
||||
// Image.Get API (which requires im:resource scope) is attempted. This ensures
|
||||
// image downloads succeed regardless of which permission the user has granted.
|
||||
func (c *FeishuChannel) downloadResource(
|
||||
ctx context.Context,
|
||||
messageID, fileKey, resourceType, fallbackExt string,
|
||||
store media.MediaStore,
|
||||
scope string,
|
||||
) string {
|
||||
file, filename := c.fetchResourceData(ctx, messageID, fileKey, resourceType)
|
||||
if file == nil {
|
||||
return ""
|
||||
}
|
||||
if closer, ok := file.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
|
||||
if filename == "" {
|
||||
filename = fileKey
|
||||
}
|
||||
if filepath.Ext(filename) == "" && fallbackExt != "" {
|
||||
filename += fallbackExt
|
||||
}
|
||||
|
||||
return c.storeResourceFile(ctx, messageID, fileKey, filename, file, store, scope)
|
||||
}
|
||||
|
||||
// fetchResourceData tries to download a resource from Feishu, first via
|
||||
// MessageResource.Get, then falling back to Image.Get for image resources.
|
||||
func (c *FeishuChannel) fetchResourceData(
|
||||
ctx context.Context,
|
||||
messageID, fileKey, resourceType string,
|
||||
) (io.Reader, string) {
|
||||
req := larkim.NewGetMessageResourceReqBuilder().
|
||||
MessageId(messageID).
|
||||
FileKey(fileKey).
|
||||
@@ -855,41 +892,80 @@ func (c *FeishuChannel) downloadResource(
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.MessageResource.Get(ctx, req)
|
||||
if err == nil && resp.Success() && resp.File != nil {
|
||||
return resp.File, resp.FileName
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to download resource", map[string]any{
|
||||
logger.WarnCF("feishu", "MessageResource.Get failed", map[string]any{
|
||||
"message_id": messageID,
|
||||
"file_key": fileKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
} else if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
logger.WarnCF("feishu", "MessageResource.Get api error", map[string]any{
|
||||
"message_id": messageID,
|
||||
"file_key": fileKey,
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
})
|
||||
} else {
|
||||
logger.WarnCF("feishu", "MessageResource.Get returned empty file body", map[string]any{
|
||||
"message_id": messageID,
|
||||
"file_key": fileKey,
|
||||
})
|
||||
}
|
||||
|
||||
if resourceType != "image" {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
return c.fetchImageDirect(ctx, fileKey)
|
||||
}
|
||||
|
||||
// fetchImageDirect downloads an image using the Image.Get API
|
||||
// (/open-apis/im/v1/images/:image_key), which requires the im:resource scope.
|
||||
func (c *FeishuChannel) fetchImageDirect(ctx context.Context, imageKey string) (io.Reader, string) {
|
||||
req := larkim.NewGetImageReqBuilder().
|
||||
ImageKey(imageKey).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Image.Get(ctx, req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Image.Get fallback failed", map[string]any{
|
||||
"image_key": imageKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil, ""
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
logger.ErrorCF("feishu", "Resource download api error", map[string]any{
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
logger.ErrorCF("feishu", "Image.Get fallback api error", map[string]any{
|
||||
"image_key": imageKey,
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
})
|
||||
return ""
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
if resp.File == nil {
|
||||
return ""
|
||||
}
|
||||
// Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body).
|
||||
if closer, ok := resp.File.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
filename := resp.FileName
|
||||
if filename == "" {
|
||||
filename = fileKey
|
||||
}
|
||||
// If filename still has no extension, append the fallback (like Telegram's ext parameter).
|
||||
if filepath.Ext(filename) == "" && fallbackExt != "" {
|
||||
filename += fallbackExt
|
||||
}
|
||||
logger.DebugCF("feishu", "Image downloaded via Image.Get fallback", map[string]any{
|
||||
"image_key": imageKey,
|
||||
})
|
||||
return resp.File, resp.FileName
|
||||
}
|
||||
|
||||
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
|
||||
// storeResourceFile writes downloaded resource data to disk and registers it in the MediaStore.
|
||||
func (c *FeishuChannel) storeResourceFile(
|
||||
ctx context.Context,
|
||||
messageID, fileKey, filename string,
|
||||
file io.Reader,
|
||||
store media.MediaStore,
|
||||
scope string,
|
||||
) string {
|
||||
mediaDir := media.TempDir()
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
|
||||
@@ -908,7 +984,7 @@ func (c *FeishuChannel) downloadResource(
|
||||
return ""
|
||||
}
|
||||
|
||||
if _, copyErr := io.Copy(out, resp.File); copyErr != nil {
|
||||
if _, copyErr := io.Copy(out, file); copyErr != nil {
|
||||
out.Close()
|
||||
os.Remove(localPath)
|
||||
logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{
|
||||
@@ -943,8 +1019,8 @@ func appendMediaTags(content, messageType string, mediaRefs []string) string {
|
||||
return content
|
||||
}
|
||||
|
||||
// Don't append tags to JSON content (interactive cards) - would produce invalid JSON
|
||||
if messageType == larkim.MsgTypeInteractive {
|
||||
// Don't append tags to JSON content - would produce invalid JSON
|
||||
if messageType == larkim.MsgTypeInteractive || messageType == larkim.MsgTypePost {
|
||||
return content
|
||||
}
|
||||
|
||||
|
||||
@@ -180,6 +180,13 @@ func TestAppendMediaTags(t *testing.T) {
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: `{"schema":"2.0","body":{"elements":[{"tag":"img","img_key":"img_123"}]}}`,
|
||||
},
|
||||
{
|
||||
name: "post message with images returns content unchanged",
|
||||
content: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_001"}]]}}`,
|
||||
messageType: "post",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: `{"zh_cn":{"title":"","content":[[{"tag":"img","image_key":"img_001"}]]}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
+27
-10
@@ -61,7 +61,10 @@ func NewLINEChannel(
|
||||
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
|
||||
}
|
||||
|
||||
client, err := messaging_api.NewMessagingApiAPI(cfg.ChannelAccessToken.String())
|
||||
client, err := messaging_api.NewMessagingApiAPI(
|
||||
cfg.ChannelAccessToken.String(),
|
||||
messaging_api.WithHTTPClient(&http.Client{Timeout: 30 * time.Second}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create LINE messaging client: %w", err)
|
||||
}
|
||||
@@ -456,7 +459,7 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok {
|
||||
tokenEntry := entry.(replyTokenEntry)
|
||||
if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge {
|
||||
_, err := c.client.WithContext(ctx).ReplyMessage(&messaging_api.ReplyMessageRequest{
|
||||
_, _, err := c.client.WithContext(ctx).ReplyMessageWithHttpInfo(&messaging_api.ReplyMessageRequest{
|
||||
ReplyToken: tokenEntry.token,
|
||||
Messages: []messaging_api.MessageInterface{&textMsg},
|
||||
})
|
||||
@@ -467,16 +470,18 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
})
|
||||
return nil, nil
|
||||
}
|
||||
logger.DebugC("line", "Reply API failed, falling back to Push API")
|
||||
logger.DebugCF("line", "Reply API failed, falling back to Push API", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to Push API
|
||||
_, err := c.client.WithContext(ctx).PushMessage(&messaging_api.PushMessageRequest{
|
||||
resp, _, err := c.client.WithContext(ctx).PushMessageWithHttpInfo(&messaging_api.PushMessageRequest{
|
||||
To: msg.ChatID,
|
||||
Messages: []messaging_api.MessageInterface{&textMsg},
|
||||
}, "")
|
||||
return nil, err
|
||||
return nil, classifySDKError(resp, err)
|
||||
}
|
||||
|
||||
// SendMedia implements the channels.MediaSender interface.
|
||||
@@ -502,11 +507,12 @@ func (c *LINEChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessag
|
||||
}
|
||||
|
||||
textMsg := messaging_api.TextMessage{Text: caption}
|
||||
if _, err := c.client.WithContext(ctx).PushMessage(&messaging_api.PushMessageRequest{
|
||||
resp, _, err := c.client.WithContext(ctx).PushMessageWithHttpInfo(&messaging_api.PushMessageRequest{
|
||||
To: msg.ChatID,
|
||||
Messages: []messaging_api.MessageInterface{&textMsg},
|
||||
}, ""); err != nil {
|
||||
return nil, err
|
||||
}, "")
|
||||
if err != nil {
|
||||
return nil, classifySDKError(resp, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -558,13 +564,24 @@ func (c *LINEChannel) StartTyping(ctx context.Context, chatID string) (func(), e
|
||||
return stop, nil
|
||||
}
|
||||
|
||||
// classifySDKError maps an SDK HTTP response to the project's sentinel errors.
|
||||
func classifySDKError(resp *http.Response, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if resp != nil {
|
||||
return channels.ClassifySendError(resp.StatusCode, err)
|
||||
}
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
// sendLoading sends a loading animation indicator to the chat.
|
||||
func (c *LINEChannel) sendLoading(ctx context.Context, chatID string) error {
|
||||
_, err := c.client.WithContext(ctx).ShowLoadingAnimation(&messaging_api.ShowLoadingAnimationRequest{
|
||||
resp, _, err := c.client.WithContext(ctx).ShowLoadingAnimationWithHttpInfo(&messaging_api.ShowLoadingAnimationRequest{
|
||||
ChatId: chatID,
|
||||
LoadingSeconds: 60,
|
||||
})
|
||||
return err
|
||||
return classifySDKError(resp, err)
|
||||
}
|
||||
|
||||
// downloadContent downloads media content from the LINE content API.
|
||||
|
||||
+140
-1
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
@@ -84,6 +85,7 @@ type Manager struct {
|
||||
channels map[string]Channel
|
||||
workers map[string]*channelWorker
|
||||
bus *bus.MessageBus
|
||||
runtimeEvents runtimeevents.Bus
|
||||
config *config.Config
|
||||
mediaStore media.MediaStore
|
||||
dispatchTask *asyncTask
|
||||
@@ -98,6 +100,32 @@ type Manager struct {
|
||||
channelHashes map[string]string // channel name → config hash
|
||||
}
|
||||
|
||||
// ManagerOption configures a channel Manager.
|
||||
type ManagerOption func(*Manager)
|
||||
|
||||
// WithRuntimeEvents injects the runtime event bus used for channel observations.
|
||||
func WithRuntimeEvents(eventBus runtimeevents.Bus) ManagerOption {
|
||||
return func(m *Manager) {
|
||||
m.runtimeEvents = eventBus
|
||||
}
|
||||
}
|
||||
|
||||
// ChannelLifecyclePayload describes channel lifecycle runtime events.
|
||||
type ChannelLifecyclePayload struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ChannelOutboundPayload describes channel outbound message runtime events.
|
||||
type ChannelOutboundPayload struct {
|
||||
Media bool `json:"media,omitempty"`
|
||||
ContentLen int `json:"content_len,omitempty"`
|
||||
MessageIDs []string `json:"message_ids,omitempty"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Retries int `json:"retries,omitempty"`
|
||||
}
|
||||
|
||||
type toolFeedbackMessageTracker interface {
|
||||
RecordToolFeedbackMessage(chatID, messageID, content string)
|
||||
ClearToolFeedbackMessage(chatID string)
|
||||
@@ -192,6 +220,21 @@ func clearTrackedToolFeedbackMessage(
|
||||
}
|
||||
}
|
||||
|
||||
// DismissToolFeedback clears any tracked tool feedback animation for the
|
||||
// given channel/chat. This is called when a turn ends without a final
|
||||
// response (e.g., ResponseHandled tools) to stop orphaned animation goroutines.
|
||||
// outboundCtx carries topic/thread info for channels that use scoped tracker
|
||||
// keys (e.g., Telegram forum topics); may be nil for non-topic channels.
|
||||
func (m *Manager) DismissToolFeedback(
|
||||
ctx context.Context, channelName, chatID string, outboundCtx *bus.InboundContext,
|
||||
) {
|
||||
ch, ok := m.GetChannel(channelName)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
dismissTrackedToolFeedbackMessage(ctx, ch, chatID, outboundCtx)
|
||||
}
|
||||
|
||||
func prepareToolFeedbackMessageContent(ch Channel, content string) string {
|
||||
prepared := strings.TrimSpace(content)
|
||||
if prepared == "" {
|
||||
@@ -409,7 +452,12 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
|
||||
}
|
||||
}
|
||||
|
||||
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
|
||||
func NewManager(
|
||||
cfg *config.Config,
|
||||
messageBus *bus.MessageBus,
|
||||
store media.MediaStore,
|
||||
opts ...ManagerOption,
|
||||
) (*Manager, error) {
|
||||
m := &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
@@ -418,6 +466,11 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi
|
||||
mediaStore: store,
|
||||
channelHashes: make(map[string]string),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(m)
|
||||
}
|
||||
}
|
||||
|
||||
// Register as streaming delegate so the agent loop can obtain streamers
|
||||
messageBus.SetStreamDelegate(m)
|
||||
@@ -542,6 +595,13 @@ func (m *Manager) initChannel(typeName, channelName string) {
|
||||
setter.SetOwner(ch)
|
||||
}
|
||||
m.channels[channelName] = ch
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleInitialized,
|
||||
channelName,
|
||||
runtimeevents.Scope{Channel: channelName},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: typeName},
|
||||
)
|
||||
logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
|
||||
"channel": channelName,
|
||||
"type": typeName,
|
||||
@@ -687,6 +747,13 @@ func (m *Manager) registerHTTPHandlersLocked() {
|
||||
func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
|
||||
if wh, ok := ch.(WebhookHandler); ok {
|
||||
m.mux.Handle(wh.WebhookPath(), wh)
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelWebhookRegistered,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
|
||||
)
|
||||
logger.InfoCF("channels", "Webhook handler registered", map[string]any{
|
||||
"channel": name,
|
||||
"path": wh.WebhookPath(),
|
||||
@@ -706,6 +773,13 @@ func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
|
||||
func (m *Manager) unregisterChannelHTTPHandler(name string, ch Channel) {
|
||||
if wh, ok := ch.(WebhookHandler); ok {
|
||||
m.mux.Unhandle(wh.WebhookPath())
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelWebhookUnregistered,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
|
||||
)
|
||||
logger.InfoCF("channels", "Webhook handler unregistered", map[string]any{
|
||||
"channel": name,
|
||||
"path": wh.WebhookPath(),
|
||||
@@ -744,6 +818,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStartFailed,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityError,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name), Error: err.Error()},
|
||||
)
|
||||
failedStarts = append(failedStarts, fmt.Errorf("channel %s: %w", name, err))
|
||||
failedNames = append(failedNames, name)
|
||||
continue
|
||||
@@ -759,6 +840,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStarted,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelType},
|
||||
)
|
||||
}
|
||||
|
||||
if len(m.channels) > 0 && len(m.workers) == 0 {
|
||||
@@ -895,7 +983,15 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStopped,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
|
||||
)
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "All channels stopped")
|
||||
@@ -1005,11 +1101,23 @@ func (m *Manager) sendWithRetry(
|
||||
// Rate limit: wait for token
|
||||
if err := w.limiter.Wait(ctx); err != nil {
|
||||
// ctx canceled, shutting down
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelRateLimited,
|
||||
name,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityWarn,
|
||||
ChannelOutboundPayload{
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
Error: err.Error(),
|
||||
},
|
||||
)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Pre-send: stop typing and try to edit placeholder
|
||||
if msgIDs, handled := m.preSend(ctx, name, msg, w.ch); handled {
|
||||
m.publishOutboundSent(name, msg, msgIDs)
|
||||
return msgIDs, true
|
||||
}
|
||||
|
||||
@@ -1018,6 +1126,7 @@ func (m *Manager) sendWithRetry(
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
msgIDs, lastErr = w.ch.Send(ctx, msg)
|
||||
if lastErr == nil {
|
||||
m.publishOutboundSent(name, msg, msgIDs)
|
||||
return msgIDs, true
|
||||
}
|
||||
|
||||
@@ -1057,6 +1166,7 @@ func (m *Manager) sendWithRetry(
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
m.publishOutboundFailed(name, msg, lastErr, false)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
@@ -1119,6 +1229,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
m.publishOutboundQueued(outboundMessageChannel(msg), msg)
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
@@ -1139,6 +1250,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
m.publishOutboundMediaQueued(outboundMediaChannel(msg), msg)
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
@@ -1188,6 +1300,16 @@ func (m *Manager) sendMediaWithRetry(
|
||||
|
||||
// Rate limit: wait for token
|
||||
if err := w.limiter.Wait(ctx); err != nil {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelRateLimited,
|
||||
name,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityWarn,
|
||||
ChannelOutboundPayload{
|
||||
Media: true,
|
||||
Error: err.Error(),
|
||||
},
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1199,6 +1321,7 @@ func (m *Manager) sendMediaWithRetry(
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
msgIDs, lastErr = ms.SendMedia(ctx, msg)
|
||||
if lastErr == nil {
|
||||
m.publishOutboundMediaSent(name, msg, msgIDs)
|
||||
return msgIDs, nil
|
||||
}
|
||||
|
||||
@@ -1238,6 +1361,7 @@ func (m *Manager) sendMediaWithRetry(
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
m.publishOutboundMediaFailed(name, msg, lastErr)
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
@@ -1375,6 +1499,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStartFailed,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityError,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name), Error: err.Error()},
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Lazily create worker only after channel starts successfully
|
||||
@@ -1388,6 +1519,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStarted,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelType},
|
||||
)
|
||||
deferFuncs = append(deferFuncs, func() {
|
||||
m.RegisterChannel(name, channel)
|
||||
})
|
||||
@@ -1510,6 +1648,7 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
m.publishOutboundQueued(channelName, msg)
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -242,6 +243,57 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAllPublishesLifecycleRuntimeEvents(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().SubscribeChan(
|
||||
t.Context(),
|
||||
runtimeevents.SubscribeOptions{Name: "channel-lifecycle", Buffer: 4},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
m := newTestManager()
|
||||
m.runtimeEvents = eventBus
|
||||
m.config = &config.Config{Channels: config.ChannelsConfig{}}
|
||||
m.channels["good"] = &mockChannel{}
|
||||
m.channels["bad"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errors.New("bad start") },
|
||||
}
|
||||
|
||||
if err := m.StartAll(t.Context()); err != nil {
|
||||
t.Fatalf("StartAll() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := m.StopAll(stopCtx); err != nil {
|
||||
t.Errorf("StopAll() error = %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
events := []runtimeevents.Event{
|
||||
receiveChannelRuntimeEvent(t, eventsCh),
|
||||
receiveChannelRuntimeEvent(t, eventsCh),
|
||||
}
|
||||
seen := map[runtimeevents.Kind]runtimeevents.Event{}
|
||||
for _, evt := range events {
|
||||
seen[evt.Kind] = evt
|
||||
}
|
||||
if evt, ok := seen[runtimeevents.KindChannelLifecycleStarted]; !ok || evt.Scope.Channel != "good" {
|
||||
t.Fatalf("missing started event for good channel: %+v", events)
|
||||
}
|
||||
if evt, ok := seen[runtimeevents.KindChannelLifecycleStartFailed]; !ok || evt.Scope.Channel != "bad" {
|
||||
t.Fatalf("missing failed event for bad channel: %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
|
||||
@@ -256,6 +308,21 @@ func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMes
|
||||
return bus.NormalizeOutboundMediaMessage(msg)
|
||||
}
|
||||
|
||||
func receiveChannelRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event channel closed before expected event")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_Success(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
@@ -280,6 +347,69 @@ func TestSendWithRetry_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetryPublishesOutboundRuntimeEvents(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().OfKind(
|
||||
runtimeevents.KindChannelMessageOutboundSent,
|
||||
runtimeevents.KindChannelMessageOutboundFailed,
|
||||
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "channel-outbound", Buffer: 2})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
m := newTestManager()
|
||||
m.runtimeEvents = eventBus
|
||||
|
||||
successWorker := &channelWorker{
|
||||
ch: &mockChannel{},
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.sendWithRetry(
|
||||
context.Background(),
|
||||
"test",
|
||||
successWorker,
|
||||
testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat-1", Content: "hello"}),
|
||||
)
|
||||
sent := receiveChannelRuntimeEvent(t, eventsCh)
|
||||
if sent.Kind != runtimeevents.KindChannelMessageOutboundSent || sent.Scope.ChatID != "chat-1" {
|
||||
t.Fatalf("sent event = %+v", sent)
|
||||
}
|
||||
if sent.Attrs["content_len"] != 5 {
|
||||
t.Fatalf("sent attrs = %#v, want content_len", sent.Attrs)
|
||||
}
|
||||
|
||||
failWorker := &channelWorker{
|
||||
ch: &mockChannel{
|
||||
sendFn: func(context.Context, bus.OutboundMessage) error {
|
||||
return fmt.Errorf("send failed: %w", ErrSendFailed)
|
||||
},
|
||||
},
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.sendWithRetry(
|
||||
context.Background(),
|
||||
"test",
|
||||
failWorker,
|
||||
testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat-2", Content: "hello"}),
|
||||
)
|
||||
failed := receiveChannelRuntimeEvent(t, eventsCh)
|
||||
if failed.Kind != runtimeevents.KindChannelMessageOutboundFailed || failed.Scope.ChatID != "chat-2" {
|
||||
t.Fatalf("failed event = %+v", failed)
|
||||
}
|
||||
if failed.Severity != runtimeevents.SeverityError {
|
||||
t.Fatalf("failed severity = %q", failed.Severity)
|
||||
}
|
||||
if failed.Attrs["error"] == "" || failed.Attrs["retries"] != maxRetries {
|
||||
t.Fatalf("failed attrs = %#v, want error and retries", failed.Attrs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
|
||||
@@ -333,18 +333,33 @@ func TestIsThoughtPayload(t *testing.T) {
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "explicit thought bool",
|
||||
name: "explicit thought kind",
|
||||
payload: map[string]any{PayloadKeyKind: MessageKindThought},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "thought kind ignores case and whitespace",
|
||||
payload: map[string]any{PayloadKeyKind: " ThOuGhT "},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "legacy thought bool remains supported for inbound compatibility",
|
||||
payload: map[string]any{PayloadKeyThought: true},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "thought false",
|
||||
name: "legacy thought false",
|
||||
payload: map[string]any{PayloadKeyThought: false},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "thought string ignored",
|
||||
payload: map[string]any{PayloadKeyThought: "true"},
|
||||
name: "tool calls kind",
|
||||
payload: map[string]any{PayloadKeyKind: MessageKindToolCalls},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-string kind ignored",
|
||||
payload: map[string]any{PayloadKeyKind: true},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
@@ -380,7 +395,7 @@ func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
|
||||
Type: TypeMessageCreate,
|
||||
Payload: map[string]any{
|
||||
PayloadKeyContent: "internal reasoning",
|
||||
PayloadKeyThought: true,
|
||||
PayloadKeyKind: MessageKindThought,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -390,3 +405,31 @@ func TestPicoClientChannel_HandleServerMessage_IgnoresThought(t *testing.T) {
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestPicoClientChannel_HandleServerMessage_IgnoresLegacyThoughtBool(t *testing.T) {
|
||||
mb := bus.NewMessageBus()
|
||||
bc := &config.Channel{Type: config.ChannelPicoClient, Enabled: true}
|
||||
ch, err := NewPicoClientChannel(bc, &config.PicoClientSettings{
|
||||
URL: "ws://localhost:8080/ws",
|
||||
}, mb)
|
||||
if err != nil {
|
||||
t.Fatalf("NewPicoClientChannel() error = %v", err)
|
||||
}
|
||||
|
||||
ch.ctx = context.Background()
|
||||
pc := &picoConn{sessionID: "sess-thought-legacy"}
|
||||
|
||||
ch.handleServerMessage(pc, PicoMessage{
|
||||
Type: TypeMessageCreate,
|
||||
Payload: map[string]any{
|
||||
PayloadKeyContent: "legacy internal reasoning",
|
||||
PayloadKeyThought: true,
|
||||
},
|
||||
})
|
||||
|
||||
select {
|
||||
case msg := <-mb.InboundChan():
|
||||
t.Fatalf("expected no inbound publish for legacy thought payload, got %+v", msg)
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,10 +323,18 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
|
||||
|
||||
payload := map[string]any{
|
||||
PayloadKeyContent: content,
|
||||
PayloadKeyThought: isThought,
|
||||
"message_id": msgID,
|
||||
}
|
||||
if isToolCalls {
|
||||
switch {
|
||||
case isThought:
|
||||
payload[PayloadKeyKind] = MessageKindThought
|
||||
|
||||
// This field is kept solely for compatibility with legacy pico clients that
|
||||
// do not yet support the newer "kind" field.
|
||||
// DO NOT use it for any purpose other than legacy client compatibility.
|
||||
payload[PayloadKeyThought] = true
|
||||
|
||||
case isToolCalls:
|
||||
payload[PayloadKeyKind] = MessageKindToolCalls
|
||||
if toolCalls, ok := picoToolCallsPayload(msg); ok {
|
||||
payload[PayloadKeyToolCalls] = toolCalls
|
||||
@@ -457,7 +465,6 @@ func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (strin
|
||||
msgID := uuid.New().String()
|
||||
outMsg := newMessage(TypeMessageCreate, map[string]any{
|
||||
PayloadKeyContent: text,
|
||||
PayloadKeyThought: false,
|
||||
"message_id": msgID,
|
||||
})
|
||||
|
||||
|
||||
@@ -131,8 +131,8 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
|
||||
if got := payload[PayloadKeyContent]; got != "thinking trace" {
|
||||
t.Fatalf("thought content = %#v, want %q", got, "thinking trace")
|
||||
}
|
||||
if got := payload[PayloadKeyThought]; got != true {
|
||||
t.Fatalf("thought flag = %#v, want true", got)
|
||||
if got := payload[PayloadKeyKind]; got != MessageKindThought {
|
||||
t.Fatalf("thought kind = %#v, want %q", got, MessageKindThought)
|
||||
}
|
||||
if got := payload["message_id"]; got == "msg-progress" || got == nil || got == "" {
|
||||
t.Fatalf("thought message_id = %#v, want new non-progress id", got)
|
||||
@@ -193,6 +193,47 @@ func TestSend_ThoughtMessageDoesNotFinalizeTrackedToolFeedback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendPlaceholder_EmitsNormalMessageWithoutKind(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
ch.bc.Placeholder.Enabled = true
|
||||
|
||||
if err := ch.Start(context.Background()); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
defer ch.Stop(context.Background())
|
||||
|
||||
clientConn, received, cleanup := newTestPicoWebSocket(t)
|
||||
defer cleanup()
|
||||
ch.addConnForTest(&picoConn{id: "conn-1", conn: clientConn, sessionID: "sess-1"})
|
||||
|
||||
msgID, err := ch.SendPlaceholder(context.Background(), "pico:sess-1")
|
||||
if err != nil {
|
||||
t.Fatalf("SendPlaceholder() error = %v", err)
|
||||
}
|
||||
if msgID == "" {
|
||||
t.Fatal("expected placeholder message id")
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-received:
|
||||
if msg.Type != TypeMessageCreate {
|
||||
t.Fatalf("placeholder message type = %q, want %q", msg.Type, TypeMessageCreate)
|
||||
}
|
||||
payload := msg.Payload
|
||||
if got := payload["message_id"]; got != msgID {
|
||||
t.Fatalf("placeholder message_id = %#v, want %q", got, msgID)
|
||||
}
|
||||
if got := payload[PayloadKeyContent]; got != "Thinking..." {
|
||||
t.Fatalf("placeholder content = %#v, want %q", got, "Thinking...")
|
||||
}
|
||||
if got, ok := payload[PayloadKeyKind]; ok {
|
||||
t.Fatalf("placeholder kind = %#v, want absent", got)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected placeholder message to be delivered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndAddConnection_RespectsMaxConnectionsConcurrently(t *testing.T) {
|
||||
ch := newTestPicoChannel(t)
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package pico
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Protocol message types.
|
||||
const (
|
||||
@@ -47,6 +50,13 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
|
||||
}
|
||||
|
||||
func isThoughtPayload(payload map[string]any) bool {
|
||||
kind, _ := payload[PayloadKeyKind].(string)
|
||||
if strings.EqualFold(strings.TrimSpace(kind), MessageKindThought) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Keep pico_client inbound-compatible with legacy servers that still send
|
||||
// the pre-kind boolean thought marker.
|
||||
thought, _ := payload[PayloadKeyThought].(bool)
|
||||
return thought
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ func BuiltinDefinitions() []Definition {
|
||||
return []Definition{
|
||||
startCommand(),
|
||||
helpCommand(),
|
||||
stopCommand(),
|
||||
showCommand(),
|
||||
listCommand(),
|
||||
useCommand(),
|
||||
|
||||
@@ -42,6 +42,9 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
|
||||
if !strings.Contains(reply, "/list [models|channels|agents|skills|mcp]") {
|
||||
t.Fatalf("/help reply missing /list usage, got %q", reply)
|
||||
}
|
||||
if !strings.Contains(reply, "/stop") {
|
||||
t.Fatalf("/help reply missing /stop usage, got %q", reply)
|
||||
}
|
||||
if !strings.Contains(reply, "/use <skill> <message>") {
|
||||
if !strings.Contains(reply, "/use <skill> [message]") {
|
||||
t.Fatalf("/help reply missing /use usage, got %q", reply)
|
||||
@@ -49,6 +52,59 @@ func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinStop_UsesRuntimeStopper(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
StopActiveTurn: func() (StopResult, error) {
|
||||
return StopResult{
|
||||
Stopped: true,
|
||||
TaskName: "sync the long running job",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/stop",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Task stopped. \"sync the long running job\" was canceled." {
|
||||
t.Fatalf("/stop reply=%q", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinStop_NoActiveTask(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
StopActiveTurn: func() (StopResult, error) {
|
||||
return StopResult{}, nil
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/stop",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/stop: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "No active task to stop." {
|
||||
t.Fatalf("/stop reply=%q, want no-active message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func stopCommand() Definition {
|
||||
return Definition{
|
||||
Name: "stop",
|
||||
Description: "Stop the current task",
|
||||
Usage: "/stop",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.StopActiveTurn == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
|
||||
result, err := rt.StopActiveTurn()
|
||||
if err != nil {
|
||||
return req.Reply("Failed to stop task: " + err.Error())
|
||||
}
|
||||
|
||||
return req.Reply(FormatStopReply(result))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// FormatStopReply renders a user-facing reply for a stop request.
|
||||
func FormatStopReply(result StopResult) string {
|
||||
if !result.Stopped {
|
||||
return "No active task to stop."
|
||||
}
|
||||
|
||||
taskName := compactStopTaskName(result.TaskName)
|
||||
if taskName == "" {
|
||||
return "Task stopped. Current task was canceled."
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Task stopped. %q was canceled.", taskName)
|
||||
}
|
||||
|
||||
func compactStopTaskName(taskName string) string {
|
||||
taskName = strings.Join(strings.Fields(strings.TrimSpace(taskName)), " ")
|
||||
if taskName == "" {
|
||||
return ""
|
||||
}
|
||||
if len(taskName) > 80 {
|
||||
return taskName[:77] + "..."
|
||||
}
|
||||
return taskName
|
||||
}
|
||||
@@ -36,6 +36,12 @@ type ContextStats struct {
|
||||
MessageCount int
|
||||
}
|
||||
|
||||
// StopResult describes the outcome of a stop request for the current session.
|
||||
type StopResult struct {
|
||||
Stopped bool
|
||||
TaskName string
|
||||
}
|
||||
|
||||
// Runtime provides runtime dependencies to command handlers. It is constructed
|
||||
// per-request by the agent loop so that per-request state (like session scope)
|
||||
// can coexist with long-lived callbacks (like GetModelInfo).
|
||||
@@ -55,4 +61,5 @@ type Runtime struct {
|
||||
SwitchChannel func(value string) error
|
||||
ClearHistory func() error
|
||||
ReloadConfig func() error
|
||||
StopActiveTurn func() (StopResult, error)
|
||||
}
|
||||
|
||||
+52
-39
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
providercommon "github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
)
|
||||
|
||||
// rrCounter is a global counter for round-robin load balancing across models.
|
||||
@@ -39,6 +40,7 @@ type Config struct {
|
||||
Channels ChannelsConfig `json:"channel_list" yaml:"channel_list"`
|
||||
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
|
||||
Gateway GatewayConfig `json:"gateway" yaml:"-"`
|
||||
Events EventsConfig `json:"events,omitempty" yaml:"-"`
|
||||
Hooks HooksConfig `json:"hooks,omitempty" yaml:"-"`
|
||||
Tools ToolsConfig `json:"tools" yaml:",inline"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat" yaml:"-"`
|
||||
@@ -276,6 +278,8 @@ type AgentDefaults struct {
|
||||
SplitOnMarker bool `json:"split_on_marker" env:"PICOCLAW_AGENTS_DEFAULTS_SPLIT_ON_MARKER"` // split messages on <|[SPLIT]|> marker
|
||||
ContextManager string `json:"context_manager,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_MANAGER"`
|
||||
ContextManagerConfig json.RawMessage `json:"context_manager_config,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_MANAGER_CONFIG"`
|
||||
MaxLLMRetries int `json:"max_llm_retries,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_LLM_RETRIES"`
|
||||
LLMRetryBackoffSecs int `json:"llm_retry_backoff_secs,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_LLM_RETRY_BACKOFF_SECS"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
@@ -553,12 +557,13 @@ type ModelConfig struct {
|
||||
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
||||
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ToolSchemaTransform string `json:"tool_schema_transform,omitempty"` // Optional tool schema compatibility transform (e.g. "simple")
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
|
||||
|
||||
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
|
||||
|
||||
@@ -595,6 +600,9 @@ func (c *ModelConfig) Validate() error {
|
||||
if c.Model == "" {
|
||||
return fmt.Errorf("model is required")
|
||||
}
|
||||
if _, err := providercommon.NormalizeToolSchemaTransform(c.ToolSchemaTransform); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -823,6 +831,7 @@ type ToolsConfig struct {
|
||||
ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"`
|
||||
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"`
|
||||
Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
@@ -1465,23 +1474,24 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
|
||||
// Create a copy for the additional key
|
||||
additionalEntry := &ModelConfig{
|
||||
ModelName: expandedName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
APIKeys: SimpleSecureStrings(keys[i]),
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
isVirtual: true,
|
||||
ModelName: expandedName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
APIKeys: SimpleSecureStrings(keys[i]),
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ToolSchemaTransform: m.ToolSchemaTransform,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
isVirtual: true,
|
||||
}
|
||||
expanded = append(expanded, additionalEntry)
|
||||
fallbackNames = append(fallbackNames, expandedName)
|
||||
@@ -1489,22 +1499,23 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
|
||||
// Create the primary entry with first key and fallbacks
|
||||
primaryEntry := &ModelConfig{
|
||||
ModelName: originalName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
APIKeys: SimpleSecureStrings(keys[0]),
|
||||
ModelName: originalName,
|
||||
Provider: m.Provider,
|
||||
Model: m.Model,
|
||||
APIBase: m.APIBase,
|
||||
Proxy: m.Proxy,
|
||||
AuthMethod: m.AuthMethod,
|
||||
ConnectMode: m.ConnectMode,
|
||||
Workspace: m.Workspace,
|
||||
RPM: m.RPM,
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ToolSchemaTransform: m.ToolSchemaTransform,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
UserAgent: m.UserAgent,
|
||||
APIKeys: SimpleSecureStrings(keys[0]),
|
||||
}
|
||||
|
||||
// Prepend new fallbacks to existing ones
|
||||
@@ -1548,6 +1559,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.Message.Enabled
|
||||
case "read_file":
|
||||
return t.ReadFile.Enabled
|
||||
case "serial":
|
||||
return t.Serial.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spawn_status":
|
||||
|
||||
@@ -1998,6 +1998,36 @@ func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_ToolSchemaTransformRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
cfg := &Config{
|
||||
Version: CurrentVersion,
|
||||
ModelList: []*ModelConfig{
|
||||
{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test",
|
||||
APIKeys: SimpleSecureStrings("sk-test"),
|
||||
ToolSchemaTransform: "simple",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
|
||||
if got := loaded.ModelList[0].ToolSchemaTransform; got != "simple" {
|
||||
t.Fatalf("ToolSchemaTransform = %q, want %q", got, "simple")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
|
||||
@@ -39,7 +39,9 @@ func DefaultConfig() *Config {
|
||||
MaxArgsLength: 300,
|
||||
SeparateMessages: false,
|
||||
},
|
||||
SplitOnMarker: false,
|
||||
SplitOnMarker: false,
|
||||
MaxLLMRetries: 2,
|
||||
LLMRetryBackoffSecs: 2,
|
||||
},
|
||||
},
|
||||
Session: SessionConfig{
|
||||
@@ -294,6 +296,9 @@ func DefaultConfig() *Config {
|
||||
HotReload: false,
|
||||
LogLevel: DefaultGatewayLogLevel,
|
||||
},
|
||||
Events: EventsConfig{
|
||||
Logging: defaultEventLoggingConfig(),
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
FilterSensitiveData: true,
|
||||
FilterMinLength: 8,
|
||||
@@ -435,6 +440,9 @@ func DefaultConfig() *Config {
|
||||
Mode: ReadFileModeBytes,
|
||||
MaxReadFileSize: 64 * 1024, // 64KB
|
||||
},
|
||||
Serial: ToolConfig{
|
||||
Enabled: false, // Hardware tool - requires host serial ports
|
||||
},
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package config
|
||||
|
||||
// EventsConfig groups runtime event configuration.
|
||||
type EventsConfig struct {
|
||||
Logging EventLoggingConfig `json:"logging,omitempty" envPrefix:"PICOCLAW_EVENTS_LOGGING_"`
|
||||
}
|
||||
|
||||
// EventLoggingConfig controls centralized runtime event logging.
|
||||
type EventLoggingConfig struct {
|
||||
// Enabled controls whether runtime events are printed by the built-in logger.
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
// Include contains exact event kinds or glob patterns such as "agent.*" or "*".
|
||||
Include []string `json:"include,omitempty" env:"INCLUDE"`
|
||||
// Exclude contains exact event kinds or glob patterns to suppress after Include matches.
|
||||
Exclude []string `json:"exclude,omitempty" env:"EXCLUDE"`
|
||||
// MinSeverity filters out events below the configured severity: debug, info, warn, or error.
|
||||
MinSeverity string `json:"min_severity,omitempty" env:"MIN_SEVERITY"`
|
||||
// IncludePayload adds the raw payload to logs. Leave disabled unless detailed diagnostics are needed.
|
||||
IncludePayload bool `json:"include_payload,omitempty" env:"INCLUDE_PAYLOAD"`
|
||||
}
|
||||
|
||||
// DefaultEventLoggingInclude keeps the pre-existing behavior where agent events
|
||||
// are printed, while non-agent runtime events are published for subscribers only.
|
||||
var DefaultEventLoggingInclude = []string{"agent.*"}
|
||||
|
||||
// EffectiveEventLoggingConfig returns a logging config with stable defaults.
|
||||
func EffectiveEventLoggingConfig(cfg *Config) EventLoggingConfig {
|
||||
if cfg == nil {
|
||||
return defaultEventLoggingConfig()
|
||||
}
|
||||
|
||||
out := cfg.Events.Logging
|
||||
if out.MinSeverity == "" {
|
||||
out.MinSeverity = "info"
|
||||
}
|
||||
if len(out.Include) == 0 {
|
||||
out.Include = append([]string(nil), DefaultEventLoggingInclude...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func defaultEventLoggingConfig() EventLoggingConfig {
|
||||
return EventLoggingConfig{
|
||||
Enabled: true,
|
||||
Include: append([]string(nil), DefaultEventLoggingInclude...),
|
||||
MinSeverity: "info",
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultEventLoggingConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
logCfg := EffectiveEventLoggingConfig(cfg)
|
||||
|
||||
if !logCfg.Enabled {
|
||||
t.Fatal("default event logging should be enabled")
|
||||
}
|
||||
if !reflect.DeepEqual(logCfg.Include, []string{"agent.*"}) {
|
||||
t.Fatalf("default include = %#v, want agent.*", logCfg.Include)
|
||||
}
|
||||
if logCfg.MinSeverity != "info" {
|
||||
t.Fatalf("default min severity = %q, want info", logCfg.MinSeverity)
|
||||
}
|
||||
if logCfg.IncludePayload {
|
||||
t.Fatal("default event logging should not include raw payloads")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigEventLoggingOverrides(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
data := []byte(`{
|
||||
"version": 3,
|
||||
"events": {
|
||||
"logging": {
|
||||
"enabled": false,
|
||||
"include": ["gateway.*"],
|
||||
"exclude": ["gateway.ready"],
|
||||
"min_severity": "warn",
|
||||
"include_payload": true
|
||||
}
|
||||
}
|
||||
}`)
|
||||
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
logCfg := EffectiveEventLoggingConfig(cfg)
|
||||
|
||||
if logCfg.Enabled {
|
||||
t.Fatal("loaded event logging enabled = true, want false")
|
||||
}
|
||||
if !reflect.DeepEqual(logCfg.Include, []string{"gateway.*"}) {
|
||||
t.Fatalf("loaded include = %#v, want gateway.*", logCfg.Include)
|
||||
}
|
||||
if !reflect.DeepEqual(logCfg.Exclude, []string{"gateway.ready"}) {
|
||||
t.Fatalf("loaded exclude = %#v, want gateway.ready", logCfg.Exclude)
|
||||
}
|
||||
if logCfg.MinSeverity != "warn" {
|
||||
t.Fatalf("loaded min severity = %q, want warn", logCfg.MinSeverity)
|
||||
}
|
||||
if !logCfg.IncludePayload {
|
||||
t.Fatal("loaded include_payload = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigEventLoggingEnvOverrides(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_EVENTS_LOGGING_ENABLED", "false")
|
||||
t.Setenv("PICOCLAW_EVENTS_LOGGING_INCLUDE", "gateway.*,channel.lifecycle.*")
|
||||
t.Setenv("PICOCLAW_EVENTS_LOGGING_EXCLUDE", "gateway.ready")
|
||||
t.Setenv("PICOCLAW_EVENTS_LOGGING_MIN_SEVERITY", "error")
|
||||
t.Setenv("PICOCLAW_EVENTS_LOGGING_INCLUDE_PAYLOAD", "true")
|
||||
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
data := []byte(`{"version": 3}`)
|
||||
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
logCfg := EffectiveEventLoggingConfig(cfg)
|
||||
|
||||
if logCfg.Enabled {
|
||||
t.Fatal("env enabled override = true, want false")
|
||||
}
|
||||
if !reflect.DeepEqual(logCfg.Include, []string{"gateway.*", "channel.lifecycle.*"}) {
|
||||
t.Fatalf("env include = %#v, want gateway/channel lifecycle", logCfg.Include)
|
||||
}
|
||||
if !reflect.DeepEqual(logCfg.Exclude, []string{"gateway.ready"}) {
|
||||
t.Fatalf("env exclude = %#v, want gateway.ready", logCfg.Exclude)
|
||||
}
|
||||
if logCfg.MinSeverity != "error" {
|
||||
t.Fatalf("env min severity = %q, want error", logCfg.MinSeverity)
|
||||
}
|
||||
if !logCfg.IncludePayload {
|
||||
t.Fatal("env include_payload = false, want true")
|
||||
}
|
||||
}
|
||||
@@ -158,6 +158,15 @@ func TestModelConfig_Validate(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid tool schema transform",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
Model: "openai/gpt-4o",
|
||||
ToolSchemaTransform: "simple",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing model_name",
|
||||
config: ModelConfig{
|
||||
@@ -177,6 +186,15 @@ func TestModelConfig_Validate(t *testing.T) {
|
||||
config: ModelConfig{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool schema transform",
|
||||
config: ModelConfig{
|
||||
ModelName: "test",
|
||||
Model: "openai/gpt-4o",
|
||||
ToolSchemaTransform: "invalid",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -187,15 +187,16 @@ func TestExpandMultiKeyModels_Deduplication(t *testing.T) {
|
||||
|
||||
func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
modelCfg := &ModelConfig{
|
||||
ModelName: "gpt-4",
|
||||
Provider: "openrouter",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: "https://api.example.com",
|
||||
Proxy: "http://proxy:8080",
|
||||
RPM: 60,
|
||||
MaxTokensField: "max_completion_tokens",
|
||||
RequestTimeout: 30,
|
||||
ThinkingLevel: "high",
|
||||
ModelName: "gpt-4",
|
||||
Provider: "openrouter",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: "https://api.example.com",
|
||||
Proxy: "http://proxy:8080",
|
||||
RPM: 60,
|
||||
MaxTokensField: "max_completion_tokens",
|
||||
RequestTimeout: 30,
|
||||
ThinkingLevel: "high",
|
||||
ToolSchemaTransform: "simple",
|
||||
}
|
||||
modelCfg.APIKeys = SimpleSecureStrings("key0", "key1") // Use internal field for multi-key testing
|
||||
models := []*ModelConfig{modelCfg}
|
||||
@@ -225,6 +226,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
if primary.ThinkingLevel != "high" {
|
||||
t.Errorf("expected thinking_level preserved, got %q", primary.ThinkingLevel)
|
||||
}
|
||||
if primary.ToolSchemaTransform != "simple" {
|
||||
t.Errorf("expected tool_schema_transform preserved, got %q", primary.ToolSchemaTransform)
|
||||
}
|
||||
|
||||
// Check additional entry also preserves fields
|
||||
additional := result[0]
|
||||
@@ -237,6 +241,9 @@ func TestExpandMultiKeyModels_PreservesOtherFields(t *testing.T) {
|
||||
if additional.RPM != 60 {
|
||||
t.Errorf("expected additional rpm preserved, got %d", additional.RPM)
|
||||
}
|
||||
if additional.ToolSchemaTransform != "simple" {
|
||||
t.Errorf("expected additional tool_schema_transform preserved, got %q", additional.ToolSchemaTransform)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandMultiKeyModels_IsVirtualFlag(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var globalEventSeq atomic.Uint64
|
||||
|
||||
// Bus publishes runtime events and creates filtered channels.
|
||||
type Bus interface {
|
||||
Publish(ctx context.Context, evt Event) PublishResult
|
||||
PublishNonBlocking(evt Event) PublishResult
|
||||
Channel() EventChannel
|
||||
Close() error
|
||||
Stats() Stats
|
||||
}
|
||||
|
||||
// PublishResult reports per-publish delivery outcomes.
|
||||
type PublishResult struct {
|
||||
Matched int
|
||||
Delivered int
|
||||
Dropped int
|
||||
Blocked int
|
||||
Closed bool
|
||||
}
|
||||
|
||||
// EventBus is an in-process runtime event broadcaster.
|
||||
type EventBus struct {
|
||||
mu sync.RWMutex
|
||||
subs map[uint64]*eventSubscription
|
||||
orderedSubs []*eventSubscription
|
||||
closed bool
|
||||
|
||||
nextSubID atomic.Uint64
|
||||
published atomic.Uint64
|
||||
matched atomic.Uint64
|
||||
delivered atomic.Uint64
|
||||
dropped atomic.Uint64
|
||||
blocked atomic.Uint64
|
||||
}
|
||||
|
||||
var _ Bus = (*EventBus)(nil)
|
||||
|
||||
// NewBus creates an in-process runtime event bus.
|
||||
func NewBus() *EventBus {
|
||||
return &EventBus{
|
||||
subs: make(map[uint64]*eventSubscription),
|
||||
}
|
||||
}
|
||||
|
||||
// Publish broadcasts evt to subscriptions whose filters match it.
|
||||
func (b *EventBus) Publish(ctx context.Context, evt Event) PublishResult {
|
||||
return b.publish(ctx, evt, false)
|
||||
}
|
||||
|
||||
// PublishNonBlocking broadcasts evt without waiting for subscriber queue capacity.
|
||||
func (b *EventBus) PublishNonBlocking(evt Event) PublishResult {
|
||||
return b.publish(context.Background(), evt, true)
|
||||
}
|
||||
|
||||
func (b *EventBus) publish(ctx context.Context, evt Event, nonBlocking bool) PublishResult {
|
||||
if b == nil {
|
||||
return PublishResult{Closed: true}
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if evt.Time.IsZero() {
|
||||
evt.Time = time.Now()
|
||||
}
|
||||
if evt.ID == "" {
|
||||
evt.ID = nextEventID()
|
||||
}
|
||||
|
||||
subs, closed := b.snapshotSubscribers()
|
||||
if closed {
|
||||
return PublishResult{Closed: true}
|
||||
}
|
||||
|
||||
b.published.Add(1)
|
||||
result := PublishResult{}
|
||||
|
||||
for _, sub := range subs {
|
||||
if !matchesFilters(sub.filters, evt) {
|
||||
continue
|
||||
}
|
||||
|
||||
result.Matched++
|
||||
b.matched.Add(1)
|
||||
|
||||
delivery := sub.enqueue(ctx, evt, nonBlocking)
|
||||
if delivery.closed {
|
||||
continue
|
||||
}
|
||||
result.Delivered += delivery.delivered
|
||||
result.Dropped += delivery.dropped
|
||||
result.Blocked += delivery.blocked
|
||||
b.delivered.Add(uint64(delivery.delivered))
|
||||
b.dropped.Add(uint64(delivery.dropped))
|
||||
b.blocked.Add(uint64(delivery.blocked))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Channel returns the root event channel for this bus.
|
||||
func (b *EventBus) Channel() EventChannel {
|
||||
return eventChannel{bus: b}
|
||||
}
|
||||
|
||||
// Close closes the bus and all active subscriptions.
|
||||
func (b *EventBus) Close() error {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
if b.closed {
|
||||
b.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
b.closed = true
|
||||
subs := b.orderedSubs
|
||||
b.subs = nil
|
||||
b.orderedSubs = nil
|
||||
b.mu.Unlock()
|
||||
|
||||
for _, sub := range subs {
|
||||
sub.closeInput()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns a snapshot of bus and subscription counters.
|
||||
func (b *EventBus) Stats() Stats {
|
||||
if b == nil {
|
||||
return Stats{Closed: true}
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
closed := b.closed
|
||||
subs := b.orderedSubs
|
||||
b.mu.RUnlock()
|
||||
|
||||
stats := Stats{
|
||||
Published: b.published.Load(),
|
||||
Matched: b.matched.Load(),
|
||||
Delivered: b.delivered.Load(),
|
||||
Dropped: b.dropped.Load(),
|
||||
Blocked: b.blocked.Load(),
|
||||
Closed: closed,
|
||||
Subscribers: len(subs),
|
||||
SubscriberStats: make([]SubscriberStats, 0, len(subs)),
|
||||
}
|
||||
for _, sub := range subs {
|
||||
stats.SubscriberStats = append(stats.SubscriberStats, sub.Stats())
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
func (b *EventBus) subscribe(
|
||||
ctx context.Context,
|
||||
filters []Filter,
|
||||
opts SubscribeOptions,
|
||||
handler Handler,
|
||||
once bool,
|
||||
) (Subscription, error) {
|
||||
if b == nil {
|
||||
return nil, ErrBusClosed
|
||||
}
|
||||
|
||||
id := b.nextSubID.Add(1)
|
||||
sub := newSubscription(b, id, filters, opts, handler, once)
|
||||
|
||||
b.mu.Lock()
|
||||
if b.closed {
|
||||
b.mu.Unlock()
|
||||
sub.closeInput()
|
||||
return nil, ErrBusClosed
|
||||
}
|
||||
b.subs[id] = sub
|
||||
b.rebuildOrderedSubscribersLocked()
|
||||
b.mu.Unlock()
|
||||
|
||||
if handler != nil {
|
||||
go sub.run(ctx)
|
||||
}
|
||||
sub.watchContext(ctx)
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (b *EventBus) unsubscribe(id uint64) {
|
||||
b.mu.Lock()
|
||||
sub, ok := b.subs[id]
|
||||
if ok {
|
||||
delete(b.subs, id)
|
||||
b.rebuildOrderedSubscribersLocked()
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
sub.closeInput()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *EventBus) snapshotSubscribers() ([]*eventSubscription, bool) {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if b.closed {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
return b.orderedSubs, false
|
||||
}
|
||||
|
||||
func (b *EventBus) rebuildOrderedSubscribersLocked() {
|
||||
subs := make([]*eventSubscription, 0, len(b.subs))
|
||||
for _, sub := range b.subs {
|
||||
subs = append(subs, sub)
|
||||
}
|
||||
sortSubscriptions(subs)
|
||||
b.orderedSubs = subs
|
||||
}
|
||||
|
||||
func sortSubscriptions(subs []*eventSubscription) {
|
||||
sort.Slice(subs, func(i, j int) bool {
|
||||
if subs[i].opts.Priority == subs[j].opts.Priority {
|
||||
return subs[i].id < subs[j].id
|
||||
}
|
||||
return subs[i].opts.Priority > subs[j].opts.Priority
|
||||
})
|
||||
}
|
||||
|
||||
func nextEventID() string {
|
||||
id := globalEventSeq.Add(1)
|
||||
return "evt-" + strconv.FormatUint(id, 10)
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package events
|
||||
|
||||
import "context"
|
||||
|
||||
// EventChannel is a filtered view over an EventBus.
|
||||
type EventChannel interface {
|
||||
Filter(filter Filter) EventChannel
|
||||
OfKind(kinds ...Kind) EventChannel
|
||||
KindPrefix(prefix string) EventChannel
|
||||
Source(component string, names ...string) EventChannel
|
||||
Scope(scope ScopeFilter) EventChannel
|
||||
|
||||
Subscribe(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error)
|
||||
SubscribeChan(ctx context.Context, opts SubscribeOptions) (Subscription, <-chan Event, error)
|
||||
SubscribeOnce(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error)
|
||||
}
|
||||
|
||||
type eventChannel struct {
|
||||
bus *EventBus
|
||||
filters []Filter
|
||||
}
|
||||
|
||||
// Filter returns a new EventChannel with filter appended.
|
||||
func (c eventChannel) Filter(filter Filter) EventChannel {
|
||||
filters := append([]Filter(nil), c.filters...)
|
||||
if filter != nil {
|
||||
filters = append(filters, filter)
|
||||
}
|
||||
return eventChannel{bus: c.bus, filters: filters}
|
||||
}
|
||||
|
||||
// OfKind returns a new EventChannel matching any of kinds.
|
||||
func (c eventChannel) OfKind(kinds ...Kind) EventChannel {
|
||||
return c.Filter(MatchKind(kinds...))
|
||||
}
|
||||
|
||||
// KindPrefix returns a new EventChannel matching events with the kind prefix.
|
||||
func (c eventChannel) KindPrefix(prefix string) EventChannel {
|
||||
return c.Filter(MatchKindPrefix(prefix))
|
||||
}
|
||||
|
||||
// Source returns a new EventChannel matching source component and optional names.
|
||||
func (c eventChannel) Source(component string, names ...string) EventChannel {
|
||||
return c.Filter(MatchSource(component, names...))
|
||||
}
|
||||
|
||||
// Scope returns a new EventChannel matching non-empty scope fields.
|
||||
func (c eventChannel) Scope(scope ScopeFilter) EventChannel {
|
||||
return c.Filter(MatchScope(scope))
|
||||
}
|
||||
|
||||
// Subscribe registers handler for events matching this channel.
|
||||
func (c eventChannel) Subscribe(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error) {
|
||||
if handler == nil {
|
||||
return nil, ErrNilHandler
|
||||
}
|
||||
return c.bus.subscribe(ctx, c.filters, opts, handler, false)
|
||||
}
|
||||
|
||||
// SubscribeChan registers a channel subscription for events matching this channel.
|
||||
func (c eventChannel) SubscribeChan(ctx context.Context, opts SubscribeOptions) (Subscription, <-chan Event, error) {
|
||||
sub, err := c.bus.subscribe(ctx, c.filters, opts, nil, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return sub, sub.(*eventSubscription).ch, nil
|
||||
}
|
||||
|
||||
// SubscribeOnce registers handler and closes the subscription after the first event.
|
||||
func (c eventChannel) SubscribeOnce(ctx context.Context, opts SubscribeOptions, handler Handler) (Subscription, error) {
|
||||
if handler == nil {
|
||||
return nil, ErrNilHandler
|
||||
}
|
||||
return c.bus.subscribe(ctx, c.filters, opts, handler, true)
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
// Package events provides the process-local runtime event bus used to observe
|
||||
// PicoClaw components without coupling them to agent-specific event envelopes.
|
||||
package events
|
||||
@@ -0,0 +1,254 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPublishDeliversToMatchingSubscriber(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
_, ch, err := bus.Channel().OfKind(KindAgentTurnStart).SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "turn-starts", Buffer: 1},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
unmatched := bus.Publish(context.Background(), Event{Kind: KindAgentTurnEnd})
|
||||
if unmatched.Matched != 0 || unmatched.Delivered != 0 {
|
||||
t.Fatalf("unmatched Publish = %+v, want no delivery", unmatched)
|
||||
}
|
||||
|
||||
result := bus.Publish(context.Background(), Event{Kind: KindAgentTurnStart})
|
||||
if result.Matched != 1 || result.Delivered != 1 || result.Dropped != 0 {
|
||||
t.Fatalf("Publish = %+v, want one delivered event", result)
|
||||
}
|
||||
|
||||
evt := receiveEvent(t, ch)
|
||||
if evt.Kind != KindAgentTurnStart {
|
||||
t.Fatalf("event kind = %q, want %q", evt.Kind, KindAgentTurnStart)
|
||||
}
|
||||
if evt.ID == "" {
|
||||
t.Fatal("event ID is empty")
|
||||
}
|
||||
if evt.Time.IsZero() {
|
||||
t.Fatal("event Time is zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDropNewestIncrementsStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
sub, _, err := bus.Channel().SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "drop-newest", Buffer: 1, Backpressure: DropNewest},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
first := bus.Publish(context.Background(), Event{Kind: KindAgentTurnStart})
|
||||
if first.Delivered != 1 || first.Dropped != 0 {
|
||||
t.Fatalf("first Publish = %+v, want one delivered event", first)
|
||||
}
|
||||
|
||||
second := bus.Publish(context.Background(), Event{Kind: KindAgentTurnEnd})
|
||||
if second.Delivered != 0 || second.Dropped != 1 {
|
||||
t.Fatalf("second Publish = %+v, want one dropped event", second)
|
||||
}
|
||||
|
||||
if got := sub.Stats().Dropped; got != 1 {
|
||||
t.Fatalf("subscription dropped = %d, want 1", got)
|
||||
}
|
||||
if got := bus.Stats().Dropped; got != 1 {
|
||||
t.Fatalf("bus dropped = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDropOldestKeepsNewestEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
sub, ch, err := bus.Channel().SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "drop-oldest", Buffer: 1, Backpressure: DropOldest},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
bus.Publish(context.Background(), Event{Kind: Kind("test.old"), Payload: "old"})
|
||||
result := bus.Publish(context.Background(), Event{Kind: Kind("test.new"), Payload: "new"})
|
||||
if result.Delivered != 1 || result.Dropped != 1 {
|
||||
t.Fatalf("Publish = %+v, want replacement delivery", result)
|
||||
}
|
||||
|
||||
evt := receiveEvent(t, ch)
|
||||
if evt.Payload != "new" {
|
||||
t.Fatalf("payload = %v, want new", evt.Payload)
|
||||
}
|
||||
if got := sub.Stats().Dropped; got != 1 {
|
||||
t.Fatalf("subscription dropped = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlockRespectsContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
_, _, err := bus.Channel().SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "block", Buffer: 1, Backpressure: Block},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
first := bus.Publish(context.Background(), Event{Kind: Kind("test.first")})
|
||||
if first.Delivered != 1 {
|
||||
t.Fatalf("first Publish = %+v, want one delivered event", first)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
second := bus.Publish(ctx, Event{Kind: Kind("test.second")})
|
||||
if second.Blocked != 1 || second.Dropped != 1 || second.Delivered != 0 {
|
||||
t.Fatalf("second Publish = %+v, want one blocked drop", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishNonBlockingDropsForFullBlockSubscriber(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
sub, _, err := bus.Channel().SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "block", Buffer: 1, Backpressure: Block},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
first := bus.PublishNonBlocking(Event{Kind: Kind("test.first")})
|
||||
if first.Delivered != 1 {
|
||||
t.Fatalf("first PublishNonBlocking = %+v, want one delivered event", first)
|
||||
}
|
||||
|
||||
resultCh := make(chan PublishResult, 1)
|
||||
go func() {
|
||||
resultCh <- bus.PublishNonBlocking(Event{Kind: Kind("test.second")})
|
||||
}()
|
||||
|
||||
select {
|
||||
case second := <-resultCh:
|
||||
if second.Matched != 1 || second.Delivered != 0 || second.Dropped != 1 || second.Blocked != 0 {
|
||||
t.Fatalf("second PublishNonBlocking = %+v, want non-blocking drop", second)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("PublishNonBlocking blocked on full Block subscriber")
|
||||
}
|
||||
|
||||
if got := sub.Stats().Dropped; got != 1 {
|
||||
t.Fatalf("subscription dropped = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatsSubscribersKeepPriorityOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
low, _, err := bus.Channel().SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "low", Priority: -1},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan low failed: %v", err)
|
||||
}
|
||||
high, _, err := bus.Channel().SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "high", Priority: 10},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan high failed: %v", err)
|
||||
}
|
||||
peer, _, err := bus.Channel().SubscribeChan(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "peer", Priority: 10},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan peer failed: %v", err)
|
||||
}
|
||||
|
||||
stats := bus.Stats()
|
||||
got := []string{
|
||||
stats.SubscriberStats[0].Name,
|
||||
stats.SubscriberStats[1].Name,
|
||||
stats.SubscriberStats[2].Name,
|
||||
}
|
||||
want := []string{"high", "peer", "low"}
|
||||
if got[0] != want[0] || got[1] != want[1] || got[2] != want[2] {
|
||||
t.Fatalf("subscriber order = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
if err := high.Close(); err != nil {
|
||||
t.Fatalf("Close high failed: %v", err)
|
||||
}
|
||||
|
||||
stats = bus.Stats()
|
||||
got = []string{
|
||||
stats.SubscriberStats[0].Name,
|
||||
stats.SubscriberStats[1].Name,
|
||||
}
|
||||
want = []string{"peer", "low"}
|
||||
if got[0] != want[0] || got[1] != want[1] {
|
||||
t.Fatalf("subscriber order after unsubscribe = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
if err := peer.Close(); err != nil {
|
||||
t.Fatalf("Close peer failed: %v", err)
|
||||
}
|
||||
if err := low.Close(); err != nil {
|
||||
t.Fatalf("Close low failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func receiveEvent(t *testing.T, ch <-chan Event) Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("event channel closed before receive")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for event")
|
||||
return Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func closeBus(t *testing.T, bus *EventBus) {
|
||||
t.Helper()
|
||||
|
||||
if err := bus.Close(); err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package events
|
||||
|
||||
import "strings"
|
||||
|
||||
// Filter decides whether an event should pass through an EventChannel.
|
||||
type Filter func(Event) bool
|
||||
|
||||
// ScopeFilter matches selected non-empty fields against Event.Scope.
|
||||
type ScopeFilter struct {
|
||||
AgentID string
|
||||
SessionKey string
|
||||
TurnID string
|
||||
Channel string
|
||||
ChatID string
|
||||
MessageID string
|
||||
}
|
||||
|
||||
// MatchKind matches events whose kind is in kinds. Empty kinds match all events.
|
||||
func MatchKind(kinds ...Kind) Filter {
|
||||
if len(kinds) == 0 {
|
||||
return matchAll
|
||||
}
|
||||
|
||||
allowed := make(map[Kind]struct{}, len(kinds))
|
||||
for _, kind := range kinds {
|
||||
allowed[kind] = struct{}{}
|
||||
}
|
||||
|
||||
return func(evt Event) bool {
|
||||
_, ok := allowed[evt.Kind]
|
||||
return ok
|
||||
}
|
||||
}
|
||||
|
||||
// MatchKindPrefix matches events whose kind starts with prefix.
|
||||
func MatchKindPrefix(prefix string) Filter {
|
||||
if prefix == "" {
|
||||
return matchAll
|
||||
}
|
||||
return func(evt Event) bool {
|
||||
return strings.HasPrefix(evt.Kind.String(), prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// MatchSource matches events emitted by component and, optionally, one of names.
|
||||
func MatchSource(component string, names ...string) Filter {
|
||||
if component == "" && len(names) == 0 {
|
||||
return matchAll
|
||||
}
|
||||
|
||||
allowedNames := make(map[string]struct{}, len(names))
|
||||
for _, name := range names {
|
||||
allowedNames[name] = struct{}{}
|
||||
}
|
||||
|
||||
return func(evt Event) bool {
|
||||
if component != "" && evt.Source.Component != component {
|
||||
return false
|
||||
}
|
||||
if len(allowedNames) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := allowedNames[evt.Source.Name]
|
||||
return ok
|
||||
}
|
||||
}
|
||||
|
||||
// MatchScope matches events whose Scope contains all non-empty filter fields.
|
||||
func MatchScope(scope ScopeFilter) Filter {
|
||||
if scope == (ScopeFilter{}) {
|
||||
return matchAll
|
||||
}
|
||||
|
||||
return func(evt Event) bool {
|
||||
return matchesString(scope.AgentID, evt.Scope.AgentID) &&
|
||||
matchesString(scope.SessionKey, evt.Scope.SessionKey) &&
|
||||
matchesString(scope.TurnID, evt.Scope.TurnID) &&
|
||||
matchesString(scope.Channel, evt.Scope.Channel) &&
|
||||
matchesString(scope.ChatID, evt.Scope.ChatID) &&
|
||||
matchesString(scope.MessageID, evt.Scope.MessageID)
|
||||
}
|
||||
}
|
||||
|
||||
// And combines filters and short-circuits on the first non-match.
|
||||
func And(filters ...Filter) Filter {
|
||||
if len(filters) == 0 {
|
||||
return matchAll
|
||||
}
|
||||
|
||||
return func(evt Event) bool {
|
||||
for _, filter := range filters {
|
||||
if filter != nil && !filter(evt) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Or combines filters and short-circuits on the first match.
|
||||
func Or(filters ...Filter) Filter {
|
||||
if len(filters) == 0 {
|
||||
return matchAll
|
||||
}
|
||||
|
||||
return func(evt Event) bool {
|
||||
for _, filter := range filters {
|
||||
if filter == nil || filter(evt) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func matchAll(Event) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func matchesString(want, got string) bool {
|
||||
return want == "" || want == got
|
||||
}
|
||||
|
||||
func matchesFilters(filters []Filter, evt Event) bool {
|
||||
for _, filter := range filters {
|
||||
if filter != nil && !filter(evt) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package events
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFilterKindPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
event Event
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "matches agent prefix",
|
||||
prefix: "agent.",
|
||||
event: Event{Kind: KindAgentTurnStart},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "rejects different prefix",
|
||||
prefix: "channel.",
|
||||
event: Event{Kind: KindAgentTurnStart},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty prefix matches all",
|
||||
prefix: "",
|
||||
event: Event{Kind: KindAgentTurnStart},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := MatchKindPrefix(tt.prefix)(tt.event); got != tt.want {
|
||||
t.Fatalf("MatchKindPrefix(%q) = %v, want %v", tt.prefix, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterScope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
evt := Event{
|
||||
Scope: Scope{
|
||||
AgentID: "agent-a",
|
||||
SessionKey: "session-1",
|
||||
TurnID: "turn-1",
|
||||
Channel: "telegram",
|
||||
ChatID: "chat-1",
|
||||
MessageID: "msg-1",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scope ScopeFilter
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty filter matches",
|
||||
scope: ScopeFilter{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "matches selected fields",
|
||||
scope: ScopeFilter{
|
||||
AgentID: "agent-a",
|
||||
ChatID: "chat-1",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "rejects mismatched field",
|
||||
scope: ScopeFilter{
|
||||
AgentID: "agent-a",
|
||||
MessageID: "msg-2",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := MatchScope(tt.scope)(evt); got != tt.want {
|
||||
t.Fatalf("MatchScope(%+v) = %v, want %v", tt.scope, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package events
|
||||
|
||||
const (
|
||||
// KindAgentTurnStart is emitted when an agent turn starts.
|
||||
KindAgentTurnStart Kind = "agent.turn.start"
|
||||
// KindAgentTurnEnd is emitted when an agent turn ends.
|
||||
KindAgentTurnEnd Kind = "agent.turn.end"
|
||||
|
||||
// KindAgentLLMRequest is emitted before an LLM request.
|
||||
KindAgentLLMRequest Kind = "agent.llm.request"
|
||||
// KindAgentLLMDelta is emitted for streaming LLM deltas.
|
||||
KindAgentLLMDelta Kind = "agent.llm.delta"
|
||||
// KindAgentLLMResponse is emitted after an LLM response.
|
||||
KindAgentLLMResponse Kind = "agent.llm.response"
|
||||
// KindAgentLLMRetry is emitted before retrying an LLM request.
|
||||
KindAgentLLMRetry Kind = "agent.llm.retry"
|
||||
|
||||
// KindAgentContextCompress is emitted when agent context is compressed.
|
||||
KindAgentContextCompress Kind = "agent.context.compress"
|
||||
// KindAgentSessionSummarize is emitted when session summarization completes.
|
||||
KindAgentSessionSummarize Kind = "agent.session.summarize"
|
||||
|
||||
// KindAgentToolExecStart is emitted before a tool executes.
|
||||
KindAgentToolExecStart Kind = "agent.tool.exec_start"
|
||||
// KindAgentToolExecEnd is emitted after a tool finishes.
|
||||
KindAgentToolExecEnd Kind = "agent.tool.exec_end"
|
||||
// KindAgentToolExecSkipped is emitted when a tool call is skipped.
|
||||
KindAgentToolExecSkipped Kind = "agent.tool.exec_skipped"
|
||||
|
||||
// KindAgentSteeringInjected is emitted when steering is injected into context.
|
||||
KindAgentSteeringInjected Kind = "agent.steering.injected"
|
||||
// KindAgentFollowUpQueued is emitted when async follow-up input is queued.
|
||||
KindAgentFollowUpQueued Kind = "agent.follow_up.queued"
|
||||
// KindAgentInterruptReceived is emitted when a turn interrupt is accepted.
|
||||
KindAgentInterruptReceived Kind = "agent.interrupt.received"
|
||||
|
||||
// KindAgentSubTurnSpawn is emitted when a sub-turn is spawned.
|
||||
KindAgentSubTurnSpawn Kind = "agent.subturn.spawn"
|
||||
// KindAgentSubTurnEnd is emitted when a sub-turn ends.
|
||||
KindAgentSubTurnEnd Kind = "agent.subturn.end"
|
||||
// KindAgentSubTurnResultDelivered is emitted when a sub-turn result is delivered.
|
||||
KindAgentSubTurnResultDelivered Kind = "agent.subturn.result_delivered"
|
||||
// KindAgentSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
|
||||
KindAgentSubTurnOrphan Kind = "agent.subturn.orphan"
|
||||
// KindAgentError is emitted when agent execution reports an error.
|
||||
KindAgentError Kind = "agent.error"
|
||||
|
||||
// KindChannelLifecycleStarted is emitted when a channel starts.
|
||||
KindChannelLifecycleStarted Kind = "channel.lifecycle.started"
|
||||
// KindChannelLifecycleInitialized is emitted when a channel is initialized.
|
||||
KindChannelLifecycleInitialized Kind = "channel.lifecycle.initialized"
|
||||
// KindChannelLifecycleStartFailed is emitted when a channel fails to start.
|
||||
KindChannelLifecycleStartFailed Kind = "channel.lifecycle.start_failed"
|
||||
// KindChannelLifecycleStopped is emitted when a channel stops.
|
||||
KindChannelLifecycleStopped Kind = "channel.lifecycle.stopped"
|
||||
// KindChannelWebhookRegistered is emitted when a channel webhook is registered.
|
||||
KindChannelWebhookRegistered Kind = "channel.webhook.registered"
|
||||
// KindChannelWebhookUnregistered is emitted when a channel webhook is unregistered.
|
||||
KindChannelWebhookUnregistered Kind = "channel.webhook.unregistered"
|
||||
// KindChannelMessageOutboundQueued is emitted when an outbound message is queued.
|
||||
KindChannelMessageOutboundQueued Kind = "channel.message.outbound_queued"
|
||||
// KindChannelMessageOutboundSent is emitted when an outbound channel message is sent.
|
||||
KindChannelMessageOutboundSent Kind = "channel.message.outbound_sent"
|
||||
// KindChannelMessageOutboundFailed is emitted when an outbound channel message fails.
|
||||
KindChannelMessageOutboundFailed Kind = "channel.message.outbound_failed"
|
||||
// KindChannelRateLimited is emitted when channel rate limiting blocks delivery.
|
||||
KindChannelRateLimited Kind = "channel.rate_limited"
|
||||
|
||||
// KindBusPublishFailed is emitted when message bus publish fails.
|
||||
KindBusPublishFailed Kind = "bus.publish.failed"
|
||||
// KindBusCloseStarted is emitted when message bus close starts.
|
||||
KindBusCloseStarted Kind = "bus.close.started"
|
||||
// KindBusCloseCompleted is emitted when message bus close completes.
|
||||
KindBusCloseCompleted Kind = "bus.close.completed"
|
||||
// KindBusCloseDrained is emitted when message bus close drains buffered messages.
|
||||
KindBusCloseDrained Kind = "bus.close.drained"
|
||||
|
||||
// KindGatewayStart is emitted when gateway startup reaches runtime bootstrap.
|
||||
KindGatewayStart Kind = "gateway.start"
|
||||
// KindGatewayReady is emitted when gateway services are started and ready.
|
||||
KindGatewayReady Kind = "gateway.ready"
|
||||
// KindGatewayShutdown is emitted when gateway shutdown starts.
|
||||
KindGatewayShutdown Kind = "gateway.shutdown"
|
||||
// KindGatewayReloadStarted is emitted when gateway reload starts.
|
||||
KindGatewayReloadStarted Kind = "gateway.reload.started"
|
||||
// KindGatewayReloadCompleted is emitted when gateway reload completes.
|
||||
KindGatewayReloadCompleted Kind = "gateway.reload.completed"
|
||||
// KindGatewayReloadFailed is emitted when gateway reload fails.
|
||||
KindGatewayReloadFailed Kind = "gateway.reload.failed"
|
||||
|
||||
// KindMCPServerConnected is emitted when an MCP server connects.
|
||||
KindMCPServerConnected Kind = "mcp.server.connected"
|
||||
// KindMCPServerConnecting is emitted before connecting to an MCP server.
|
||||
KindMCPServerConnecting Kind = "mcp.server.connecting"
|
||||
// KindMCPServerFailed is emitted when an MCP server fails.
|
||||
KindMCPServerFailed Kind = "mcp.server.failed"
|
||||
// KindMCPToolDiscovered is emitted when an MCP tool is discovered.
|
||||
KindMCPToolDiscovered Kind = "mcp.tool.discovered"
|
||||
// KindMCPToolCallStart is emitted when an MCP tool call starts.
|
||||
KindMCPToolCallStart Kind = "mcp.tool.call.start"
|
||||
// KindMCPToolCallEnd is emitted when an MCP tool call ends.
|
||||
KindMCPToolCallEnd Kind = "mcp.tool.call.end"
|
||||
)
|
||||
|
||||
var knownKinds = []Kind{
|
||||
KindAgentTurnStart,
|
||||
KindAgentTurnEnd,
|
||||
KindAgentLLMRequest,
|
||||
KindAgentLLMDelta,
|
||||
KindAgentLLMResponse,
|
||||
KindAgentLLMRetry,
|
||||
KindAgentContextCompress,
|
||||
KindAgentSessionSummarize,
|
||||
KindAgentToolExecStart,
|
||||
KindAgentToolExecEnd,
|
||||
KindAgentToolExecSkipped,
|
||||
KindAgentSteeringInjected,
|
||||
KindAgentFollowUpQueued,
|
||||
KindAgentInterruptReceived,
|
||||
KindAgentSubTurnSpawn,
|
||||
KindAgentSubTurnEnd,
|
||||
KindAgentSubTurnResultDelivered,
|
||||
KindAgentSubTurnOrphan,
|
||||
KindAgentError,
|
||||
KindChannelLifecycleStarted,
|
||||
KindChannelLifecycleInitialized,
|
||||
KindChannelLifecycleStartFailed,
|
||||
KindChannelLifecycleStopped,
|
||||
KindChannelWebhookRegistered,
|
||||
KindChannelWebhookUnregistered,
|
||||
KindChannelMessageOutboundQueued,
|
||||
KindChannelMessageOutboundSent,
|
||||
KindChannelMessageOutboundFailed,
|
||||
KindChannelRateLimited,
|
||||
KindBusPublishFailed,
|
||||
KindBusCloseStarted,
|
||||
KindBusCloseCompleted,
|
||||
KindBusCloseDrained,
|
||||
KindGatewayStart,
|
||||
KindGatewayReady,
|
||||
KindGatewayShutdown,
|
||||
KindGatewayReloadStarted,
|
||||
KindGatewayReloadCompleted,
|
||||
KindGatewayReloadFailed,
|
||||
KindMCPServerConnected,
|
||||
KindMCPServerConnecting,
|
||||
KindMCPServerFailed,
|
||||
KindMCPToolDiscovered,
|
||||
KindMCPToolCallStart,
|
||||
KindMCPToolCallEnd,
|
||||
}
|
||||
|
||||
// KnownKinds returns the runtime event kinds declared by this package.
|
||||
func KnownKinds() []Kind {
|
||||
return append([]Kind(nil), knownKinds...)
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package events
|
||||
|
||||
// Stats reports aggregate EventBus counters.
|
||||
type Stats struct {
|
||||
Published uint64
|
||||
Matched uint64
|
||||
Delivered uint64
|
||||
Dropped uint64
|
||||
Blocked uint64
|
||||
Closed bool
|
||||
Subscribers int
|
||||
|
||||
SubscriberStats []SubscriberStats
|
||||
}
|
||||
|
||||
// SubscriberStats reports counters for one subscription.
|
||||
type SubscriberStats struct {
|
||||
ID uint64
|
||||
Name string
|
||||
Received uint64
|
||||
Handled uint64
|
||||
Failed uint64
|
||||
Dropped uint64
|
||||
Panicked uint64
|
||||
TimedOut uint64
|
||||
}
|
||||
@@ -0,0 +1,459 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultSubscriberBuffer = 16
|
||||
|
||||
var (
|
||||
// ErrBusClosed is returned when subscribing to a closed event bus.
|
||||
ErrBusClosed = errors.New("events: bus is closed")
|
||||
// ErrNilHandler is returned when subscribing without a handler.
|
||||
ErrNilHandler = errors.New("events: handler is nil")
|
||||
)
|
||||
|
||||
// Handler processes a runtime event delivered to a subscription.
|
||||
type Handler func(context.Context, Event) error
|
||||
|
||||
// SubscribeOptions controls how a subscription receives events.
|
||||
type SubscribeOptions struct {
|
||||
Name string
|
||||
Buffer int
|
||||
Priority int
|
||||
Concurrency ConcurrencyKind
|
||||
Backpressure BackpressurePolicy
|
||||
// Timeout bounds how long the subscription worker waits for one handler call.
|
||||
// Handlers should still honor ctx cancellation; timed-out calls keep running
|
||||
// until their handler returns.
|
||||
Timeout time.Duration
|
||||
PanicPolicy PanicPolicy
|
||||
}
|
||||
|
||||
// ConcurrencyKind controls how handler subscriptions process queued events.
|
||||
type ConcurrencyKind string
|
||||
|
||||
const (
|
||||
// Concurrent processes each event in its own goroutine.
|
||||
Concurrent ConcurrencyKind = "concurrent"
|
||||
// Locked processes events sequentially in subscription order.
|
||||
Locked ConcurrencyKind = "locked"
|
||||
// Keyed is reserved for keyed sequential processing and currently behaves as Locked.
|
||||
Keyed ConcurrencyKind = "keyed"
|
||||
)
|
||||
|
||||
// BackpressurePolicy controls delivery when a subscription queue is full.
|
||||
type BackpressurePolicy string
|
||||
|
||||
const (
|
||||
// DropNewest drops the event being published when the queue is full.
|
||||
DropNewest BackpressurePolicy = "drop_newest"
|
||||
// DropOldest drops one queued event and enqueues the event being published.
|
||||
DropOldest BackpressurePolicy = "drop_oldest"
|
||||
// Block waits for queue capacity until Publish's context is canceled.
|
||||
Block BackpressurePolicy = "block"
|
||||
)
|
||||
|
||||
// PanicPolicy controls handler panic behavior.
|
||||
type PanicPolicy string
|
||||
|
||||
const (
|
||||
// RecoverAndLog recovers handler panics and records them in subscription stats.
|
||||
RecoverAndLog PanicPolicy = "recover_and_log"
|
||||
// Crash lets handler panics propagate from the worker goroutine.
|
||||
Crash PanicPolicy = "crash"
|
||||
)
|
||||
|
||||
// Subscription represents an active event subscription.
|
||||
type Subscription interface {
|
||||
ID() uint64
|
||||
Name() string
|
||||
Close() error
|
||||
Done() <-chan struct{}
|
||||
Stats() SubscriberStats
|
||||
}
|
||||
|
||||
type subscriberCounters struct {
|
||||
received atomic.Uint64
|
||||
handled atomic.Uint64
|
||||
failed atomic.Uint64
|
||||
dropped atomic.Uint64
|
||||
panicked atomic.Uint64
|
||||
timedOut atomic.Uint64
|
||||
}
|
||||
|
||||
type eventSubscription struct {
|
||||
bus *EventBus
|
||||
id uint64
|
||||
name string
|
||||
opts SubscribeOptions
|
||||
filters []Filter
|
||||
handler Handler
|
||||
once bool
|
||||
|
||||
ch chan Event
|
||||
done chan struct{}
|
||||
closing chan struct{}
|
||||
|
||||
closeOnce sync.Once
|
||||
doneOnce sync.Once
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
wg sync.WaitGroup
|
||||
blockWG sync.WaitGroup
|
||||
|
||||
counters subscriberCounters
|
||||
}
|
||||
|
||||
type handlerResult struct {
|
||||
err error
|
||||
panicked bool
|
||||
}
|
||||
|
||||
func normalizeSubscribeOptions(opts SubscribeOptions) SubscribeOptions {
|
||||
if opts.Buffer <= 0 {
|
||||
opts.Buffer = defaultSubscriberBuffer
|
||||
}
|
||||
if opts.Concurrency == "" {
|
||||
opts.Concurrency = Locked
|
||||
}
|
||||
if opts.Backpressure == "" {
|
||||
opts.Backpressure = DropNewest
|
||||
}
|
||||
if opts.PanicPolicy == "" {
|
||||
opts.PanicPolicy = RecoverAndLog
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func newSubscription(
|
||||
bus *EventBus,
|
||||
id uint64,
|
||||
filters []Filter,
|
||||
opts SubscribeOptions,
|
||||
handler Handler,
|
||||
once bool,
|
||||
) *eventSubscription {
|
||||
opts = normalizeSubscribeOptions(opts)
|
||||
return &eventSubscription{
|
||||
bus: bus,
|
||||
id: id,
|
||||
name: opts.Name,
|
||||
opts: opts,
|
||||
filters: append([]Filter(nil), filters...),
|
||||
handler: handler,
|
||||
once: once,
|
||||
ch: make(chan Event, opts.Buffer),
|
||||
done: make(chan struct{}),
|
||||
closing: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the subscription identifier.
|
||||
func (s *eventSubscription) ID() uint64 {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.id
|
||||
}
|
||||
|
||||
// Name returns the subscription name.
|
||||
func (s *eventSubscription) Name() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Close removes the subscription and closes its delivery channel.
|
||||
func (s *eventSubscription) Close() error {
|
||||
if s == nil || s.bus == nil {
|
||||
return nil
|
||||
}
|
||||
s.bus.unsubscribe(s.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel closed after the subscription has stopped processing.
|
||||
func (s *eventSubscription) Done() <-chan struct{} {
|
||||
if s == nil {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
return s.done
|
||||
}
|
||||
|
||||
// Stats returns a snapshot of the subscription counters.
|
||||
func (s *eventSubscription) Stats() SubscriberStats {
|
||||
if s == nil {
|
||||
return SubscriberStats{}
|
||||
}
|
||||
return SubscriberStats{
|
||||
ID: s.id,
|
||||
Name: s.name,
|
||||
Received: s.counters.received.Load(),
|
||||
Handled: s.counters.handled.Load(),
|
||||
Failed: s.counters.failed.Load(),
|
||||
Dropped: s.counters.dropped.Load(),
|
||||
Panicked: s.counters.panicked.Load(),
|
||||
TimedOut: s.counters.timedOut.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventSubscription) run(ctx context.Context) {
|
||||
defer func() {
|
||||
s.wg.Wait()
|
||||
s.closeDone()
|
||||
}()
|
||||
|
||||
for evt := range s.ch {
|
||||
s.dispatch(ctx, evt)
|
||||
if s.once {
|
||||
_ = s.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventSubscription) dispatch(ctx context.Context, evt Event) {
|
||||
switch s.opts.Concurrency {
|
||||
case Concurrent:
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handle(ctx, evt)
|
||||
}()
|
||||
case Keyed:
|
||||
// TODO: replace this with keyed executors when runtime events need
|
||||
// per-scope ordering with cross-scope concurrency.
|
||||
s.handle(ctx, evt)
|
||||
default:
|
||||
s.handle(ctx, evt)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventSubscription) handle(ctx context.Context, evt Event) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if s.opts.Timeout <= 0 {
|
||||
s.recordHandlerResult(ctx, s.invokeHandler(ctx, evt))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, s.opts.Timeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan handlerResult, 1)
|
||||
go func() {
|
||||
done <- s.invokeHandler(ctx, evt)
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-done:
|
||||
s.recordHandlerResult(ctx, result)
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
s.counters.timedOut.Add(1)
|
||||
}
|
||||
s.counters.failed.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventSubscription) invokeHandler(ctx context.Context, evt Event) (result handlerResult) {
|
||||
if s.opts.PanicPolicy != Crash {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
s.counters.panicked.Add(1)
|
||||
result.panicked = true
|
||||
log.Printf("events: subscriber %q recovered panic: %v", s.name, recovered)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
result.err = s.handler(ctx, evt)
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *eventSubscription) recordHandlerResult(ctx context.Context, result handlerResult) {
|
||||
if result.panicked {
|
||||
return
|
||||
}
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
s.counters.timedOut.Add(1)
|
||||
}
|
||||
if result.err != nil {
|
||||
s.counters.failed.Add(1)
|
||||
return
|
||||
}
|
||||
s.counters.handled.Add(1)
|
||||
}
|
||||
|
||||
func (s *eventSubscription) watchContext(ctx context.Context) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = s.Close()
|
||||
case <-s.done:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *eventSubscription) closeInput() {
|
||||
s.closeOnce.Do(func() {
|
||||
close(s.closing)
|
||||
s.mu.Lock()
|
||||
s.closed = true
|
||||
s.mu.Unlock()
|
||||
s.blockWG.Wait()
|
||||
s.mu.Lock()
|
||||
close(s.ch)
|
||||
s.mu.Unlock()
|
||||
if s.handler == nil {
|
||||
s.closeDone()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *eventSubscription) closeDone() {
|
||||
s.doneOnce.Do(func() {
|
||||
close(s.done)
|
||||
})
|
||||
}
|
||||
|
||||
type deliveryResult struct {
|
||||
delivered int
|
||||
dropped int
|
||||
blocked int
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (s *eventSubscription) enqueue(ctx context.Context, evt Event, nonBlocking bool) deliveryResult {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if nonBlocking {
|
||||
return s.enqueueNonBlocking(evt)
|
||||
}
|
||||
|
||||
if s.opts.Backpressure == Block {
|
||||
return s.enqueueBlocking(ctx, evt)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.closed {
|
||||
return deliveryResult{closed: true}
|
||||
}
|
||||
|
||||
s.counters.received.Add(1)
|
||||
|
||||
switch s.opts.Backpressure {
|
||||
case DropOldest:
|
||||
return s.enqueueDropOldest(evt)
|
||||
default:
|
||||
return s.enqueueDropNewest(evt)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventSubscription) enqueueBlocking(ctx context.Context, evt Event) deliveryResult {
|
||||
s.mu.Lock()
|
||||
if s.closed {
|
||||
s.mu.Unlock()
|
||||
return deliveryResult{closed: true}
|
||||
}
|
||||
s.blockWG.Add(1)
|
||||
s.counters.received.Add(1)
|
||||
s.mu.Unlock()
|
||||
|
||||
defer s.blockWG.Done()
|
||||
return s.enqueueBlock(ctx, evt)
|
||||
}
|
||||
|
||||
func (s *eventSubscription) enqueueNonBlocking(evt Event) deliveryResult {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.closed {
|
||||
return deliveryResult{closed: true}
|
||||
}
|
||||
|
||||
s.counters.received.Add(1)
|
||||
if s.opts.Backpressure == DropOldest {
|
||||
return s.enqueueDropOldest(evt)
|
||||
}
|
||||
return s.enqueueDropNewest(evt)
|
||||
}
|
||||
|
||||
func (s *eventSubscription) enqueueDropNewest(evt Event) deliveryResult {
|
||||
select {
|
||||
case <-s.closing:
|
||||
return deliveryResult{closed: true}
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case s.ch <- evt:
|
||||
return deliveryResult{delivered: 1}
|
||||
default:
|
||||
s.counters.dropped.Add(1)
|
||||
return deliveryResult{dropped: 1}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventSubscription) enqueueDropOldest(evt Event) deliveryResult {
|
||||
select {
|
||||
case <-s.closing:
|
||||
return deliveryResult{closed: true}
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case s.ch <- evt:
|
||||
return deliveryResult{delivered: 1}
|
||||
default:
|
||||
}
|
||||
|
||||
dropped := 0
|
||||
select {
|
||||
case <-s.ch:
|
||||
s.counters.dropped.Add(1)
|
||||
dropped = 1
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.closing:
|
||||
return deliveryResult{dropped: dropped, closed: true}
|
||||
case s.ch <- evt:
|
||||
return deliveryResult{delivered: 1, dropped: dropped}
|
||||
default:
|
||||
s.counters.dropped.Add(1)
|
||||
return deliveryResult{dropped: dropped + 1}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventSubscription) enqueueBlock(ctx context.Context, evt Event) deliveryResult {
|
||||
select {
|
||||
case <-s.closing:
|
||||
return deliveryResult{closed: true}
|
||||
case s.ch <- evt:
|
||||
return deliveryResult{delivered: 1}
|
||||
case <-ctx.Done():
|
||||
s.counters.dropped.Add(1)
|
||||
return deliveryResult{dropped: 1, blocked: 1}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSubscribeOnceClosesAfterFirstEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
var handled atomic.Uint64
|
||||
sub, err := bus.Channel().SubscribeOnce(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "once", Buffer: 2},
|
||||
func(context.Context, Event) error {
|
||||
handled.Add(1)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeOnce failed: %v", err)
|
||||
}
|
||||
|
||||
bus.Publish(context.Background(), Event{Kind: KindAgentTurnStart})
|
||||
waitForSubscriptionDone(t, sub)
|
||||
bus.Publish(context.Background(), Event{Kind: KindAgentTurnEnd})
|
||||
|
||||
if got := handled.Load(); got != 1 {
|
||||
t.Fatalf("handled = %d, want 1", got)
|
||||
}
|
||||
if got := sub.Stats().Handled; got != 1 {
|
||||
t.Fatalf("subscription handled = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribeClosesChannel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
sub, ch, err := bus.Channel().SubscribeChan(context.Background(), SubscribeOptions{Name: "chan"})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
if err := sub.Close(); err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if ok {
|
||||
t.Fatal("channel is open, want closed")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for channel close")
|
||||
}
|
||||
waitForSubscriptionDone(t, sub)
|
||||
}
|
||||
|
||||
func TestBlockBackpressureCloseUnblocksPublisher(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
sub, _, err := bus.Channel().SubscribeChan(context.Background(), SubscribeOptions{
|
||||
Name: "block-close",
|
||||
Buffer: 1,
|
||||
Backpressure: Block,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
first := bus.Publish(context.Background(), Event{Kind: Kind("test.first")})
|
||||
if first.Delivered != 1 {
|
||||
t.Fatalf("first Publish = %+v, want one delivered event", first)
|
||||
}
|
||||
|
||||
publishStarted := make(chan struct{})
|
||||
publishReturned := make(chan PublishResult, 1)
|
||||
go func() {
|
||||
close(publishStarted)
|
||||
publishReturned <- bus.Publish(context.Background(), Event{Kind: Kind("test.second")})
|
||||
}()
|
||||
|
||||
<-publishStarted
|
||||
waitForStat(t, func() uint64 {
|
||||
return sub.Stats().Received
|
||||
}, 2)
|
||||
select {
|
||||
case result := <-publishReturned:
|
||||
t.Fatalf("blocking Publish returned before close: %+v", result)
|
||||
default:
|
||||
}
|
||||
|
||||
closeReturned := make(chan error, 1)
|
||||
go func() {
|
||||
closeReturned <- sub.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-closeReturned:
|
||||
if err != nil {
|
||||
t.Fatalf("Close failed: %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for Close to unblock")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-publishReturned:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for blocking Publish to return after close")
|
||||
}
|
||||
waitForSubscriptionDone(t, sub)
|
||||
}
|
||||
|
||||
func TestHandlerPanicRecovered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
sub, err := bus.Channel().Subscribe(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "panic", Buffer: 1},
|
||||
func(context.Context, Event) error {
|
||||
panic("boom")
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
|
||||
bus.Publish(context.Background(), Event{Kind: KindAgentError})
|
||||
waitForStat(t, func() uint64 {
|
||||
return sub.Stats().Panicked
|
||||
}, 1)
|
||||
}
|
||||
|
||||
func TestLockedHandlerProcessesSequentially(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
var active atomic.Int64
|
||||
var maxActive atomic.Int64
|
||||
sub, err := bus.Channel().Subscribe(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "locked", Buffer: 8, Concurrency: Locked},
|
||||
func(context.Context, Event) error {
|
||||
current := active.Add(1)
|
||||
for {
|
||||
currentMax := maxActive.Load()
|
||||
if current <= currentMax || maxActive.CompareAndSwap(currentMax, current) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
active.Add(-1)
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
bus.Publish(context.Background(), Event{Kind: KindAgentLLMDelta})
|
||||
}
|
||||
waitForStat(t, func() uint64 {
|
||||
return sub.Stats().Handled
|
||||
}, 5)
|
||||
|
||||
if got := maxActive.Load(); got != 1 {
|
||||
t.Fatalf("max active handlers = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerTimeoutDoesNotWedgeLockedSubscription(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bus := NewBus()
|
||||
defer closeBus(t, bus)
|
||||
|
||||
releaseFirst := make(chan struct{})
|
||||
defer close(releaseFirst)
|
||||
|
||||
var calls atomic.Uint64
|
||||
sub, err := bus.Channel().Subscribe(
|
||||
context.Background(),
|
||||
SubscribeOptions{Name: "timeout", Buffer: 2, Concurrency: Locked, Timeout: 20 * time.Millisecond},
|
||||
func(context.Context, Event) error {
|
||||
if calls.Add(1) == 1 {
|
||||
<-releaseFirst
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Subscribe failed: %v", err)
|
||||
}
|
||||
|
||||
bus.Publish(context.Background(), Event{Kind: Kind("test.first")})
|
||||
waitForStat(t, func() uint64 {
|
||||
return sub.Stats().TimedOut
|
||||
}, 1)
|
||||
|
||||
bus.Publish(context.Background(), Event{Kind: Kind("test.second")})
|
||||
waitForStat(t, func() uint64 {
|
||||
return sub.Stats().Handled
|
||||
}, 1)
|
||||
|
||||
if got := sub.Stats().Failed; got != 1 {
|
||||
t.Fatalf("subscription failed = %d, want timeout failure", got)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForSubscriptionDone(t *testing.T, sub Subscription) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-sub.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for subscription to stop")
|
||||
}
|
||||
}
|
||||
|
||||
func waitForStat(t *testing.T, stat func() uint64, want uint64) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.After(time.Second)
|
||||
ticker := time.NewTicker(time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if got := stat(); got >= want {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-deadline:
|
||||
t.Fatalf("timed out waiting for stat >= %d", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package events
|
||||
|
||||
import "time"
|
||||
|
||||
// Kind identifies a runtime event category.
|
||||
type Kind string
|
||||
|
||||
// String returns the string representation of the event kind.
|
||||
func (k Kind) String() string {
|
||||
return string(k)
|
||||
}
|
||||
|
||||
// Event is the runtime event envelope shared across PicoClaw components.
|
||||
type Event struct {
|
||||
ID string `json:"id"`
|
||||
Kind Kind `json:"kind"`
|
||||
Time time.Time `json:"time"`
|
||||
Source Source `json:"source"`
|
||||
Scope Scope `json:"scope,omitempty"`
|
||||
Correlation Correlation `json:"correlation,omitempty"`
|
||||
Severity Severity `json:"severity,omitempty"`
|
||||
Payload any `json:"payload,omitempty"`
|
||||
Attrs map[string]any `json:"attrs,omitempty"`
|
||||
}
|
||||
|
||||
// Source identifies the component that emitted an event.
|
||||
type Source struct {
|
||||
Component string `json:"component"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// Scope identifies the runtime ownership of an event.
|
||||
//
|
||||
// Scope is intentionally limited to agent, session, turn, channel, chat,
|
||||
// message, and sender identity. Tool, provider, model, and MCP details belong
|
||||
// in Source, Payload, or Attrs.
|
||||
type Scope struct {
|
||||
RuntimeID string `json:"runtime_id,omitempty"`
|
||||
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
SessionKey string `json:"session_key,omitempty"`
|
||||
TurnID string `json:"turn_id,omitempty"`
|
||||
|
||||
Channel string `json:"channel,omitempty"`
|
||||
Account string `json:"account,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
TopicID string `json:"topic_id,omitempty"`
|
||||
|
||||
SpaceID string `json:"space_id,omitempty"`
|
||||
SpaceType string `json:"space_type,omitempty"`
|
||||
ChatType string `json:"chat_type,omitempty"`
|
||||
|
||||
SenderID string `json:"sender_id,omitempty"`
|
||||
MessageID string `json:"message_id,omitempty"`
|
||||
}
|
||||
|
||||
// Correlation carries cross-event tracing fields.
|
||||
type Correlation struct {
|
||||
TraceID string `json:"trace_id,omitempty"`
|
||||
ParentTurnID string `json:"parent_turn_id,omitempty"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
ReplyToID string `json:"reply_to_id,omitempty"`
|
||||
}
|
||||
|
||||
// Severity describes the operational severity of an event.
|
||||
type Severity string
|
||||
|
||||
const (
|
||||
// SeverityDebug is used for verbose diagnostic events.
|
||||
SeverityDebug Severity = "debug"
|
||||
// SeverityInfo is used for normal lifecycle and activity events.
|
||||
SeverityInfo Severity = "info"
|
||||
// SeverityWarn is used for recoverable abnormal events.
|
||||
SeverityWarn Severity = "warn"
|
||||
// SeverityError is used for failed operations and unrecoverable events.
|
||||
SeverityError Severity = "error"
|
||||
)
|
||||
@@ -0,0 +1,53 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
type gatewayEventPayload struct {
|
||||
DurationMS int64 `json:"duration_ms,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func publishGatewayEvent(
|
||||
al *agent.AgentLoop,
|
||||
kind runtimeevents.Kind,
|
||||
startedAt time.Time,
|
||||
err error,
|
||||
) {
|
||||
if al == nil || al.RuntimeEventBus() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
severity := runtimeevents.SeverityInfo
|
||||
payload := gatewayEventPayload{}
|
||||
if !startedAt.IsZero() {
|
||||
payload.DurationMS = time.Since(startedAt).Milliseconds()
|
||||
}
|
||||
if err != nil {
|
||||
severity = runtimeevents.SeverityError
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
|
||||
al.RuntimeEventBus().PublishNonBlocking(runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "gateway"},
|
||||
Severity: severity,
|
||||
Payload: payload,
|
||||
Attrs: gatewayEventAttrs(payload),
|
||||
})
|
||||
}
|
||||
|
||||
func gatewayEventAttrs(payload gatewayEventPayload) map[string]any {
|
||||
attrs := map[string]any{}
|
||||
if payload.DurationMS > 0 {
|
||||
attrs["duration_ms"] = payload.DurationMS
|
||||
}
|
||||
if payload.Error != "" {
|
||||
attrs["error"] = payload.Error
|
||||
}
|
||||
return attrs
|
||||
}
|
||||
+31
-4
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/devices"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -114,6 +115,7 @@ func (p *startupBlockedProvider) GetDefaultModel() string {
|
||||
|
||||
// Run starts the gateway runtime using the configuration loaded from configPath.
|
||||
func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runErr error) {
|
||||
startedAt := time.Now()
|
||||
panicPath := filepath.Join(homePath, logPath, panicFile)
|
||||
panicFunc, err := logger.InitPanic(panicPath)
|
||||
if err != nil {
|
||||
@@ -197,6 +199,8 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
msgBus.SetEventPublisher(agentLoop.RuntimeEventBus())
|
||||
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayStart, startedAt, nil)
|
||||
|
||||
fmt.Println("\n📦 Agent Status:")
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
@@ -216,6 +220,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReady, startedAt, nil)
|
||||
closeListeners = false
|
||||
|
||||
// Setup manual reload channel for /reload endpoint
|
||||
@@ -262,7 +267,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
select {
|
||||
case <-sigChan:
|
||||
logger.Info("Shutting down...")
|
||||
shutdownGateway(runningServices, agentLoop, provider, true)
|
||||
shutdownGateway(runningServices, agentLoop, provider, msgBus, true)
|
||||
return nil
|
||||
case newCfg := <-configReloadChan:
|
||||
if !runningServices.reloading.CompareAndSwap(false, true) {
|
||||
@@ -312,10 +317,20 @@ func executeReload(
|
||||
msgBus *bus.MessageBus,
|
||||
allowEmptyStartup bool,
|
||||
debug bool,
|
||||
) error {
|
||||
) (err error) {
|
||||
startedAt := time.Now()
|
||||
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReloadStarted, startedAt, nil)
|
||||
defer runningServices.reloading.Store(false)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReloadFailed, startedAt, err)
|
||||
return
|
||||
}
|
||||
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayReloadCompleted, startedAt, nil)
|
||||
}()
|
||||
|
||||
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
|
||||
err = handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
|
||||
return err
|
||||
}
|
||||
|
||||
func createStartupProvider(
|
||||
@@ -383,7 +398,12 @@ func setupAndStartServices(
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
runningServices.ChannelManager, err = channels.NewManager(
|
||||
cfg,
|
||||
msgBus,
|
||||
runningServices.MediaStore,
|
||||
channels.WithRuntimeEvents(agentLoop.RuntimeEventBus()),
|
||||
)
|
||||
if err != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
@@ -490,14 +510,21 @@ func shutdownGateway(
|
||||
runningServices *services,
|
||||
agentLoop *agent.AgentLoop,
|
||||
provider providers.LLMProvider,
|
||||
msgBus *bus.MessageBus,
|
||||
fullShutdown bool,
|
||||
) {
|
||||
publishGatewayEvent(agentLoop, runtimeevents.KindGatewayShutdown, time.Time{}, nil)
|
||||
|
||||
if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown {
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout, false)
|
||||
|
||||
if fullShutdown && msgBus != nil {
|
||||
msgBus.Close()
|
||||
}
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func TestRun_StartupFailuresReturnErrorAndEmitStructuredLog(t *testing.T) {
|
||||
@@ -106,3 +112,100 @@ func TestGatewayRunStartupFailureHelper(t *testing.T) {
|
||||
fmt.Fprintln(os.Stdout, err.Error())
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func TestPublishGatewayEvent(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
t.Cleanup(func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Fatalf("Close runtime event bus: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
sub, eventsCh, err := eventBus.Channel().OfKind(runtimeevents.KindGatewayStart).SubscribeChan(
|
||||
ctx,
|
||||
runtimeevents.SubscribeOptions{Name: "gateway-test", Buffer: 4},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := sub.Close(); err != nil {
|
||||
t.Fatalf("Close subscription: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
al := agent.NewAgentLoop(
|
||||
config.DefaultConfig(),
|
||||
bus.NewMessageBus(),
|
||||
&startupBlockedProvider{reason: "not used"},
|
||||
agent.WithRuntimeEvents(eventBus),
|
||||
)
|
||||
t.Cleanup(al.Close)
|
||||
|
||||
startedAt := time.Now().Add(-1500 * time.Millisecond)
|
||||
publishGatewayEvent(al, runtimeevents.KindGatewayStart, startedAt, nil)
|
||||
|
||||
evt := receiveGatewayRuntimeEvent(t, eventsCh)
|
||||
if evt.Kind != runtimeevents.KindGatewayStart ||
|
||||
evt.Source.Component != "gateway" ||
|
||||
evt.Severity != runtimeevents.SeverityInfo {
|
||||
t.Fatalf("gateway event = %+v", evt)
|
||||
}
|
||||
payload, ok := evt.Payload.(gatewayEventPayload)
|
||||
if !ok {
|
||||
t.Fatalf("payload type = %T, want gatewayEventPayload", evt.Payload)
|
||||
}
|
||||
if payload.DurationMS <= 0 {
|
||||
t.Fatalf("DurationMS = %d, want positive", payload.DurationMS)
|
||||
}
|
||||
if evt.Attrs["duration_ms"] == nil {
|
||||
t.Fatalf("gateway event attrs missing duration_ms: %#v", evt.Attrs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownGatewayClosesMessageBus(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := agent.NewAgentLoop(
|
||||
config.DefaultConfig(),
|
||||
msgBus,
|
||||
&startupBlockedProvider{reason: "not used"},
|
||||
)
|
||||
msgBus.SetEventPublisher(al.RuntimeEventBus())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
sub, eventsCh, err := al.RuntimeEventBus().Channel().OfKind(runtimeevents.KindBusCloseCompleted).SubscribeChan(
|
||||
ctx,
|
||||
runtimeevents.SubscribeOptions{Name: "bus-close-test", Buffer: 4},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = sub.Close()
|
||||
}()
|
||||
|
||||
shutdownGateway(&services{}, al, &startupBlockedProvider{reason: "not used"}, msgBus, true)
|
||||
|
||||
evt := receiveGatewayRuntimeEvent(t, eventsCh)
|
||||
if evt.Kind != runtimeevents.KindBusCloseCompleted {
|
||||
t.Fatalf("shutdown event kind = %q, want %q", evt.Kind, runtimeevents.KindBusCloseCompleted)
|
||||
}
|
||||
if err := msgBus.PublishVoiceControl(context.Background(), bus.VoiceControl{}); !errors.Is(err, bus.ErrBusClosed) {
|
||||
t.Fatalf("PublishVoiceControl after shutdown error = %v, want %v", err, bus.ErrBusClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func receiveGatewayRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt := <-ch:
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for gateway runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func (m *Manager) publishServerEvent(
|
||||
kind runtimeevents.Kind,
|
||||
serverName string,
|
||||
cfg config.MCPServerConfig,
|
||||
toolCount int,
|
||||
err error,
|
||||
) {
|
||||
if m == nil || m.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
|
||||
severity := runtimeevents.SeverityInfo
|
||||
if err != nil {
|
||||
severity = runtimeevents.SeverityError
|
||||
}
|
||||
payload := ServerEventPayload{
|
||||
Server: serverName,
|
||||
Type: mcpTransportType(cfg),
|
||||
URL: cfg.URL,
|
||||
Command: cfg.Command,
|
||||
ToolCount: toolCount,
|
||||
}
|
||||
if err != nil {
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
|
||||
m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "mcp", Name: serverName},
|
||||
Severity: severity,
|
||||
Payload: payload,
|
||||
Attrs: mcpServerEventAttrs(payload),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) publishToolDiscovered(serverName string, cfg config.MCPServerConfig, toolName string) {
|
||||
if m == nil || m.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
payload := ServerEventPayload{
|
||||
Server: serverName,
|
||||
Type: mcpTransportType(cfg),
|
||||
URL: cfg.URL,
|
||||
Command: cfg.Command,
|
||||
Tool: toolName,
|
||||
}
|
||||
m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{
|
||||
Kind: runtimeevents.KindMCPToolDiscovered,
|
||||
Source: runtimeevents.Source{Component: "mcp", Name: serverName},
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
Payload: payload,
|
||||
Attrs: mcpServerEventAttrs(payload),
|
||||
})
|
||||
}
|
||||
|
||||
func mcpServerEventAttrs(payload ServerEventPayload) map[string]any {
|
||||
attrs := map[string]any{}
|
||||
setMCPAttrString(attrs, "server", payload.Server)
|
||||
setMCPAttrString(attrs, "type", payload.Type)
|
||||
setMCPAttrString(attrs, "tool", payload.Tool)
|
||||
if payload.ToolCount > 0 {
|
||||
attrs["tool_count"] = payload.ToolCount
|
||||
}
|
||||
setMCPAttrString(attrs, "error", payload.Error)
|
||||
return attrs
|
||||
}
|
||||
|
||||
func setMCPAttrString(attrs map[string]any, key, value string) {
|
||||
if value != "" {
|
||||
attrs[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func mcpTransportType(cfg config.MCPServerConfig) string {
|
||||
if cfg.Type != "" {
|
||||
return cfg.Type
|
||||
}
|
||||
if cfg.URL != "" {
|
||||
return "sse"
|
||||
}
|
||||
if cfg.Command != "" {
|
||||
return "stdio"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
+46
-6
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -127,19 +128,47 @@ type ServerConnection struct {
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
type Manager struct {
|
||||
servers map[string]*ServerConnection
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
|
||||
wg sync.WaitGroup // tracks in-flight CallTool calls
|
||||
servers map[string]*ServerConnection
|
||||
runtimeEvents runtimeevents.Bus
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
|
||||
wg sync.WaitGroup // tracks in-flight CallTool calls
|
||||
}
|
||||
|
||||
var connectServerFunc = connectServer
|
||||
|
||||
// ManagerOption configures an MCP manager.
|
||||
type ManagerOption func(*Manager)
|
||||
|
||||
// WithRuntimeEvents injects the runtime event bus used for MCP observations.
|
||||
func WithRuntimeEvents(eventBus runtimeevents.Bus) ManagerOption {
|
||||
return func(m *Manager) {
|
||||
m.runtimeEvents = eventBus
|
||||
}
|
||||
}
|
||||
|
||||
// ServerEventPayload describes MCP server connection events.
|
||||
type ServerEventPayload struct {
|
||||
Server string `json:"server"`
|
||||
Type string `json:"type,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Tool string `json:"tool,omitempty"`
|
||||
ToolCount int `json:"tool_count,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
func NewManager(opts ...ManagerOption) *Manager {
|
||||
m := &Manager{
|
||||
servers: make(map[string]*ServerConnection),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(m)
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// LoadFromConfig loads MCP servers from configuration
|
||||
@@ -264,8 +293,10 @@ func (m *Manager) ConnectServer(
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) error {
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerConnecting, name, cfg, 0, nil)
|
||||
conn, err := connectServerFunc(ctx, name, cfg)
|
||||
if err != nil {
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -274,10 +305,19 @@ func (m *Manager) ConnectServer(
|
||||
|
||||
if m.closed.Load() {
|
||||
_ = conn.Session.Close()
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, fmt.Errorf("manager is closed"))
|
||||
return fmt.Errorf("manager is closed")
|
||||
}
|
||||
|
||||
m.servers[name] = conn
|
||||
for _, tool := range conn.Tools {
|
||||
toolName := ""
|
||||
if tool != nil {
|
||||
toolName = tool.Name
|
||||
}
|
||||
m.publishToolDiscovered(name, cfg, toolName)
|
||||
}
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerConnected, name, cfg, len(conn.Tools), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,11 +10,13 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
@@ -248,6 +250,95 @@ func TestNewManager_InitialState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectServerPublishesRuntimeEvents(t *testing.T) {
|
||||
originalConnectServerFunc := connectServerFunc
|
||||
t.Cleanup(func() {
|
||||
connectServerFunc = originalConnectServerFunc
|
||||
})
|
||||
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().OfKind(
|
||||
runtimeevents.KindMCPServerConnected,
|
||||
runtimeevents.KindMCPServerFailed,
|
||||
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "mcp-events", Buffer: 2})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
connectServerFunc = func(
|
||||
_ context.Context,
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) (*ServerConnection, error) {
|
||||
if name == "bad" {
|
||||
return nil, fmt.Errorf("connect failed")
|
||||
}
|
||||
return &ServerConnection{
|
||||
Name: name,
|
||||
Config: cfg,
|
||||
Tools: []*sdkmcp.Tool{{Name: "echo"}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mgr := NewManager(WithRuntimeEvents(eventBus))
|
||||
err = mgr.ConnectServer(context.Background(), "good", config.MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: "echo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConnectServer(good) error = %v", err)
|
||||
}
|
||||
connected := receiveMCPRuntimeEvent(t, eventsCh)
|
||||
if connected.Kind != runtimeevents.KindMCPServerConnected ||
|
||||
connected.Source.Name != "good" ||
|
||||
connected.Severity != runtimeevents.SeverityInfo {
|
||||
t.Fatalf("connected event = %+v", connected)
|
||||
}
|
||||
if connected.Attrs["server"] != "good" ||
|
||||
connected.Attrs["type"] != "stdio" ||
|
||||
connected.Attrs["tool_count"] != 1 {
|
||||
t.Fatalf("connected attrs = %#v", connected.Attrs)
|
||||
}
|
||||
|
||||
err = mgr.ConnectServer(context.Background(), "bad", config.MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: "echo",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected ConnectServer(bad) to fail")
|
||||
}
|
||||
failed := receiveMCPRuntimeEvent(t, eventsCh)
|
||||
if failed.Kind != runtimeevents.KindMCPServerFailed ||
|
||||
failed.Source.Name != "bad" ||
|
||||
failed.Severity != runtimeevents.SeverityError {
|
||||
t.Fatalf("failed event = %+v", failed)
|
||||
}
|
||||
if failed.Attrs["server"] != "bad" || failed.Attrs["error"] != "connect failed" {
|
||||
t.Fatalf("failed attrs = %#v", failed.Attrs)
|
||||
}
|
||||
}
|
||||
|
||||
func receiveMCPRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event channel closed before expected event")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
|
||||
@@ -0,0 +1,642 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxGeminiSchemaDepth = 64
|
||||
|
||||
var geminiSupportedTypes = map[string]bool{
|
||||
"array": true,
|
||||
"boolean": true,
|
||||
"integer": true,
|
||||
"number": true,
|
||||
"object": true,
|
||||
"string": true,
|
||||
}
|
||||
|
||||
// SanitizeSchemaForGoogle reduces a JSON Schema to the conservative subset
|
||||
// accepted by Google/Gemini-style function declarations. It resolves local
|
||||
// refs, collapses composition keywords like anyOf/oneOf/allOf, and strips
|
||||
// advanced keywords that Gemini-compatible backends often reject.
|
||||
func SanitizeSchemaForGoogle(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sanitizer := geminiSchemaSanitizer{root: schema}
|
||||
result := sanitizer.sanitizeNode(schema, nil, 0)
|
||||
if len(result) == 0 {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
if _, hasProps := result["properties"]; hasProps {
|
||||
result["type"] = "object"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SanitizeSchemaForGemini is kept as a compatibility alias for the original
|
||||
// Google/Gemini sanitizer name.
|
||||
func SanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
return SanitizeSchemaForGoogle(schema)
|
||||
}
|
||||
|
||||
type geminiSchemaSanitizer struct {
|
||||
root map[string]any
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) sanitizeNode(
|
||||
node map[string]any,
|
||||
refTrail map[string]struct{},
|
||||
depth int,
|
||||
) map[string]any {
|
||||
if node == nil || depth > maxGeminiSchemaDepth {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
normalized := s.normalizeNode(node, refTrail, depth)
|
||||
if len(normalized) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
if desc, ok := normalized["description"].(string); ok && strings.TrimSpace(desc) != "" {
|
||||
result["description"] = desc
|
||||
}
|
||||
|
||||
if schemaType := sanitizeGeminiSchemaType(normalized["type"]); schemaType != "" {
|
||||
result["type"] = schemaType
|
||||
}
|
||||
|
||||
if enumValues := sanitizeGeminiEnum(normalized["enum"]); len(enumValues) > 0 {
|
||||
result["enum"] = enumValues
|
||||
}
|
||||
|
||||
if propsRaw, ok := normalized["properties"].(map[string]any); ok {
|
||||
props := make(map[string]any, len(propsRaw))
|
||||
for name, rawProp := range propsRaw {
|
||||
propSchema, ok := rawProp.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
sanitizedProp := s.sanitizeNode(propSchema, refTrail, depth+1)
|
||||
if len(sanitizedProp) == 0 {
|
||||
sanitizedProp = map[string]any{}
|
||||
}
|
||||
props[name] = sanitizedProp
|
||||
}
|
||||
result["properties"] = props
|
||||
result["type"] = "object"
|
||||
if required := sanitizeGeminiRequired(normalized["required"], props); len(required) > 0 {
|
||||
result["required"] = required
|
||||
}
|
||||
}
|
||||
|
||||
if itemsRaw, ok := normalized["items"].(map[string]any); ok {
|
||||
items := s.sanitizeNode(itemsRaw, refTrail, depth+1)
|
||||
if len(items) == 0 {
|
||||
items = map[string]any{}
|
||||
}
|
||||
result["items"] = items
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "array"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) normalizeNode(
|
||||
node map[string]any,
|
||||
refTrail map[string]struct{},
|
||||
depth int,
|
||||
) map[string]any {
|
||||
if node == nil || depth > maxGeminiSchemaDepth {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
normalized := cloneGeminiSchemaMap(node)
|
||||
|
||||
if ref, ok := normalized["$ref"].(string); ok {
|
||||
delete(normalized, "$ref")
|
||||
if _, seen := refTrail[ref]; !seen {
|
||||
if target, ok := s.resolveLocalSchemaRef(ref); ok {
|
||||
nextTrail := cloneRefTrail(refTrail)
|
||||
nextTrail[ref] = struct{}{}
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
s.normalizeNode(target, nextTrail, depth+1),
|
||||
normalized,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rawAllOf, ok := normalized["allOf"]; ok {
|
||||
delete(normalized, "allOf")
|
||||
for _, part := range schemaSlice(rawAllOf) {
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
normalized,
|
||||
s.normalizeNode(part, refTrail, depth+1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if rawAnyOf, ok := normalized["anyOf"]; ok {
|
||||
delete(normalized, "anyOf")
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
s.mergeUnionBranches(schemaSlice(rawAnyOf), refTrail, depth+1),
|
||||
normalized,
|
||||
)
|
||||
}
|
||||
|
||||
if rawOneOf, ok := normalized["oneOf"]; ok {
|
||||
delete(normalized, "oneOf")
|
||||
normalized = mergeGeminiSchemaMaps(
|
||||
s.mergeUnionBranches(schemaSlice(rawOneOf), refTrail, depth+1),
|
||||
normalized,
|
||||
)
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) mergeUnionBranches(
|
||||
branches []map[string]any,
|
||||
refTrail map[string]struct{},
|
||||
depth int,
|
||||
) map[string]any {
|
||||
if len(branches) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
objectBranches := make([]map[string]any, 0, len(branches))
|
||||
arrayBranches := make([]map[string]any, 0, len(branches))
|
||||
nonNullBranches := make([]map[string]any, 0, len(branches))
|
||||
sameType := ""
|
||||
sameTypeConsistent := true
|
||||
|
||||
for _, branch := range branches {
|
||||
normalized := s.normalizeNode(branch, refTrail, depth+1)
|
||||
if len(normalized) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
branchType := geminiSchemaBranchType(normalized["type"])
|
||||
if branchType == "null" {
|
||||
continue
|
||||
}
|
||||
nonNullBranches = append(nonNullBranches, normalized)
|
||||
|
||||
if sameType == "" {
|
||||
sameType = branchType
|
||||
} else if branchType != "" && branchType != sameType {
|
||||
sameTypeConsistent = false
|
||||
}
|
||||
|
||||
if branchType == "object" || hasSchemaProperties(normalized) {
|
||||
objectBranches = append(objectBranches, normalized)
|
||||
continue
|
||||
}
|
||||
if branchType == "array" || hasSchemaItems(normalized) {
|
||||
arrayBranches = append(arrayBranches, normalized)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonNullBranches) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
if len(objectBranches) > 0 {
|
||||
return mergeUnionObjectSchemas(objectBranches)
|
||||
}
|
||||
if len(arrayBranches) == len(nonNullBranches) && len(arrayBranches) > 0 {
|
||||
return mergeUnionArraySchemas(arrayBranches)
|
||||
}
|
||||
if sameTypeConsistent && sameType != "" {
|
||||
merged := map[string]any{}
|
||||
for _, branch := range nonNullBranches {
|
||||
merged = mergeGeminiSchemaMaps(merged, branch)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
best := nonNullBranches[0]
|
||||
bestScore := geminiUnionBranchScore(best)
|
||||
for _, branch := range nonNullBranches[1:] {
|
||||
if score := geminiUnionBranchScore(branch); score > bestScore {
|
||||
best = branch
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
return cloneGeminiSchemaMap(best)
|
||||
}
|
||||
|
||||
func (s geminiSchemaSanitizer) resolveLocalSchemaRef(ref string) (map[string]any, bool) {
|
||||
if ref == "#" {
|
||||
return s.root, true
|
||||
}
|
||||
if !strings.HasPrefix(ref, "#/") {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var current any = s.root
|
||||
for _, rawToken := range strings.Split(strings.TrimPrefix(ref, "#/"), "/") {
|
||||
token := strings.ReplaceAll(strings.ReplaceAll(rawToken, "~1", "/"), "~0", "~")
|
||||
switch value := current.(type) {
|
||||
case map[string]any:
|
||||
next, ok := value[token]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
current = next
|
||||
case []any:
|
||||
index, err := strconv.Atoi(token)
|
||||
if err != nil || index < 0 || index >= len(value) {
|
||||
return nil, false
|
||||
}
|
||||
current = value[index]
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
resolved, ok := current.(map[string]any)
|
||||
return resolved, ok
|
||||
}
|
||||
|
||||
func mergeUnionObjectSchemas(branches []map[string]any) map[string]any {
|
||||
merged := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
|
||||
var commonRequired map[string]struct{}
|
||||
var requiredOrder []string
|
||||
|
||||
for i, branch := range branches {
|
||||
merged = mergeGeminiSchemaMaps(merged, branch)
|
||||
|
||||
required := requiredStrings(branch["required"])
|
||||
if i == 0 {
|
||||
commonRequired = make(map[string]struct{}, len(required))
|
||||
requiredOrder = append(requiredOrder, required...)
|
||||
for _, name := range required {
|
||||
commonRequired[name] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
current := make(map[string]struct{}, len(required))
|
||||
for _, name := range required {
|
||||
current[name] = struct{}{}
|
||||
}
|
||||
for name := range commonRequired {
|
||||
if _, ok := current[name]; !ok {
|
||||
delete(commonRequired, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(commonRequired) > 0 {
|
||||
filtered := make([]string, 0, len(commonRequired))
|
||||
for _, name := range requiredOrder {
|
||||
if _, ok := commonRequired[name]; ok {
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
}
|
||||
if len(filtered) > 0 {
|
||||
merged["required"] = filtered
|
||||
}
|
||||
} else {
|
||||
delete(merged, "required")
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func mergeUnionArraySchemas(branches []map[string]any) map[string]any {
|
||||
merged := map[string]any{
|
||||
"type": "array",
|
||||
}
|
||||
for _, branch := range branches {
|
||||
merged = mergeGeminiSchemaMaps(merged, branch)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func mergeGeminiSchemaMaps(base map[string]any, overlay map[string]any) map[string]any {
|
||||
if len(base) == 0 {
|
||||
return cloneGeminiSchemaMap(overlay)
|
||||
}
|
||||
if len(overlay) == 0 {
|
||||
return cloneGeminiSchemaMap(base)
|
||||
}
|
||||
|
||||
result := cloneGeminiSchemaMap(base)
|
||||
for key, value := range overlay {
|
||||
switch key {
|
||||
case "properties":
|
||||
overlayProps, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
existing, _ := result["properties"].(map[string]any)
|
||||
mergedProps := cloneGeminiSchemaMap(existing)
|
||||
if mergedProps == nil {
|
||||
mergedProps = make(map[string]any, len(overlayProps))
|
||||
}
|
||||
for name, rawProp := range overlayProps {
|
||||
propSchema, ok := rawProp.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if existingProp, ok := mergedProps[name].(map[string]any); ok {
|
||||
mergedProps[name] = mergeGeminiSchemaMaps(existingProp, propSchema)
|
||||
} else {
|
||||
mergedProps[name] = cloneGeminiSchemaMap(propSchema)
|
||||
}
|
||||
}
|
||||
result["properties"] = mergedProps
|
||||
case "items":
|
||||
overlayItems, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if existingItems, ok := result["items"].(map[string]any); ok {
|
||||
result["items"] = mergeGeminiSchemaMaps(existingItems, overlayItems)
|
||||
} else {
|
||||
result["items"] = cloneGeminiSchemaMap(overlayItems)
|
||||
}
|
||||
case "required":
|
||||
if merged := mergeRequiredLists(result["required"], value); len(merged) > 0 {
|
||||
result["required"] = merged
|
||||
}
|
||||
case "type":
|
||||
if mergedType := mergeGeminiSchemaTypes(result["type"], value); mergedType != "" {
|
||||
result["type"] = mergedType
|
||||
} else {
|
||||
delete(result, "type")
|
||||
}
|
||||
case "description":
|
||||
desc, ok := value.(string)
|
||||
if ok && strings.TrimSpace(desc) != "" {
|
||||
result["description"] = desc
|
||||
}
|
||||
default:
|
||||
result[key] = cloneGeminiSchemaValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func mergeGeminiSchemaTypes(left any, right any) string {
|
||||
leftType := geminiSchemaBranchType(left)
|
||||
rightType := geminiSchemaBranchType(right)
|
||||
|
||||
switch {
|
||||
case leftType == "":
|
||||
return rightType
|
||||
case rightType == "":
|
||||
return leftType
|
||||
case leftType == rightType:
|
||||
return leftType
|
||||
case leftType == "null":
|
||||
return rightType
|
||||
case rightType == "null":
|
||||
return leftType
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeGeminiSchemaType(raw any) string {
|
||||
typeName := geminiSchemaBranchType(raw)
|
||||
if typeName == "null" {
|
||||
return ""
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
|
||||
func geminiSchemaBranchType(raw any) string {
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
if value == "null" {
|
||||
return value
|
||||
}
|
||||
if geminiSupportedTypes[value] {
|
||||
return value
|
||||
}
|
||||
return ""
|
||||
case []string:
|
||||
return geminiSchemaBranchType(stringSliceToAny(value))
|
||||
case []any:
|
||||
candidate := ""
|
||||
sawNull := false
|
||||
for _, item := range value {
|
||||
typeName, ok := item.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if typeName == "null" {
|
||||
sawNull = true
|
||||
continue
|
||||
}
|
||||
if !geminiSupportedTypes[typeName] {
|
||||
continue
|
||||
}
|
||||
if candidate == "" {
|
||||
candidate = typeName
|
||||
continue
|
||||
}
|
||||
if candidate != typeName {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if candidate == "" && sawNull {
|
||||
return "null"
|
||||
}
|
||||
return candidate
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeGeminiEnum(raw any) []any {
|
||||
values, ok := raw.([]any)
|
||||
if !ok {
|
||||
if stringValues, ok := raw.([]string); ok {
|
||||
return stringSliceToAny(stringValues)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
sanitized := make([]any, 0, len(values))
|
||||
for _, value := range values {
|
||||
switch value.(type) {
|
||||
case string, bool, float64, int, int32, int64:
|
||||
sanitized = append(sanitized, value)
|
||||
}
|
||||
}
|
||||
if len(sanitized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func sanitizeGeminiRequired(raw any, properties map[string]any) []string {
|
||||
required := requiredStrings(raw)
|
||||
if len(required) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, len(required))
|
||||
seen := make(map[string]struct{}, len(required))
|
||||
for _, name := range required {
|
||||
if _, ok := properties[name]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func requiredStrings(raw any) []string {
|
||||
switch value := raw.(type) {
|
||||
case []string:
|
||||
return append([]string(nil), value...)
|
||||
case []any:
|
||||
required := make([]string, 0, len(value))
|
||||
for _, item := range value {
|
||||
name, ok := item.(string)
|
||||
if ok {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
return required
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func mergeRequiredLists(left any, right any) []string {
|
||||
merged := make([]string, 0)
|
||||
seen := map[string]struct{}{}
|
||||
|
||||
for _, name := range append(requiredStrings(left), requiredStrings(right)...) {
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
func geminiUnionBranchScore(schema map[string]any) int {
|
||||
score := 0
|
||||
if hasSchemaProperties(schema) {
|
||||
score += 20
|
||||
}
|
||||
if hasSchemaItems(schema) {
|
||||
score += 10
|
||||
}
|
||||
if _, ok := schema["enum"]; ok {
|
||||
score += 5
|
||||
}
|
||||
if _, ok := schema["description"]; ok {
|
||||
score += 2
|
||||
}
|
||||
score += len(schema)
|
||||
return score
|
||||
}
|
||||
|
||||
func hasSchemaProperties(schema map[string]any) bool {
|
||||
props, ok := schema["properties"].(map[string]any)
|
||||
return ok && len(props) > 0
|
||||
}
|
||||
|
||||
func hasSchemaItems(schema map[string]any) bool {
|
||||
_, ok := schema["items"].(map[string]any)
|
||||
return ok
|
||||
}
|
||||
|
||||
func schemaSlice(raw any) []map[string]any {
|
||||
switch value := raw.(type) {
|
||||
case []map[string]any:
|
||||
return append([]map[string]any(nil), value...)
|
||||
case []any:
|
||||
schemas := make([]map[string]any, 0, len(value))
|
||||
for _, item := range value {
|
||||
schema, ok := item.(map[string]any)
|
||||
if ok {
|
||||
schemas = append(schemas, schema)
|
||||
}
|
||||
}
|
||||
return schemas
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func cloneGeminiSchemaMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for key, value := range in {
|
||||
out[key] = cloneGeminiSchemaValue(value)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneGeminiSchemaValue(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
return cloneGeminiSchemaMap(typed)
|
||||
case []any:
|
||||
out := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = cloneGeminiSchemaValue(item)
|
||||
}
|
||||
return out
|
||||
case []string:
|
||||
return append([]string(nil), typed...)
|
||||
default:
|
||||
return typed
|
||||
}
|
||||
}
|
||||
|
||||
func cloneRefTrail(in map[string]struct{}) map[string]struct{} {
|
||||
if len(in) == 0 {
|
||||
return make(map[string]struct{})
|
||||
}
|
||||
out := make(map[string]struct{}, len(in))
|
||||
for key := range in {
|
||||
out[key] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stringSliceToAny(values []string) []any {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]any, len(values))
|
||||
for i, value := range values {
|
||||
result[i] = value
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package common
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeSchemaForGemini_DereferencesRefsAndFlattensUnions(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"parent": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/pageParent"},
|
||||
map[string]any{"$ref": "#/$defs/databaseParent"},
|
||||
},
|
||||
},
|
||||
"icon": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/emoji"},
|
||||
map[string]any{"type": "null"},
|
||||
},
|
||||
},
|
||||
"data": map[string]any{
|
||||
"$ref": "#/$defs/dataPayload",
|
||||
},
|
||||
},
|
||||
"required": []any{"parent", "icon", "missing"},
|
||||
"$defs": map[string]any{
|
||||
"pageParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"page_id": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": []any{"page_id"},
|
||||
},
|
||||
"databaseParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"database_id": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": []any{"database_id"},
|
||||
},
|
||||
"emoji": map[string]any{
|
||||
"type": "string",
|
||||
"pattern": "^:[a-z_]+:$",
|
||||
},
|
||||
"dataPayload": map[string]any{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
},
|
||||
"count": map[string]any{
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": []any{"name"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := SanitizeSchemaForGemini(schema)
|
||||
assertSchemaKeyAbsent(t, got, "$defs")
|
||||
assertSchemaKeyAbsent(t, got, "$ref")
|
||||
assertSchemaKeyAbsent(t, got, "anyOf")
|
||||
assertSchemaKeyAbsent(t, got, "oneOf")
|
||||
assertSchemaKeyAbsent(t, got, "allOf")
|
||||
assertSchemaKeyAbsent(t, got, "additionalProperties")
|
||||
assertSchemaKeyAbsent(t, got, "pattern")
|
||||
assertSchemaKeyAbsent(t, got, "minLength")
|
||||
assertSchemaKeyAbsent(t, got, "minimum")
|
||||
|
||||
if got["type"] != "object" {
|
||||
t.Fatalf("top-level type = %#v, want object", got["type"])
|
||||
}
|
||||
|
||||
props, ok := got["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("properties = %#v, want map", got["properties"])
|
||||
}
|
||||
|
||||
parent, ok := props["parent"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parent schema = %#v, want map", props["parent"])
|
||||
}
|
||||
if parent["type"] != "object" {
|
||||
t.Fatalf("parent.type = %#v, want object", parent["type"])
|
||||
}
|
||||
parentProps, ok := parent["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parent.properties = %#v, want map", parent["properties"])
|
||||
}
|
||||
if _, found := parentProps["page_id"]; !found {
|
||||
t.Fatalf("parent.properties missing page_id: %#v", parentProps)
|
||||
}
|
||||
if _, found := parentProps["database_id"]; !found {
|
||||
t.Fatalf("parent.properties missing database_id: %#v", parentProps)
|
||||
}
|
||||
if _, hasRequired := parent["required"]; hasRequired {
|
||||
t.Fatalf("parent.required = %#v, want omitted for merged anyOf branches", parent["required"])
|
||||
}
|
||||
|
||||
icon, ok := props["icon"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("icon schema = %#v, want map", props["icon"])
|
||||
}
|
||||
if icon["type"] != "string" {
|
||||
t.Fatalf("icon.type = %#v, want string", icon["type"])
|
||||
}
|
||||
|
||||
data, ok := props["data"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("data schema = %#v, want map", props["data"])
|
||||
}
|
||||
if data["type"] != "object" {
|
||||
t.Fatalf("data.type = %#v, want object", data["type"])
|
||||
}
|
||||
dataProps, ok := data["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("data.properties = %#v, want map", data["properties"])
|
||||
}
|
||||
if _, found := dataProps["name"]; !found {
|
||||
t.Fatalf("data.properties missing name: %#v", dataProps)
|
||||
}
|
||||
if _, found := dataProps["count"]; !found {
|
||||
t.Fatalf("data.properties missing count: %#v", dataProps)
|
||||
}
|
||||
|
||||
required, ok := got["required"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("required = %#v, want []string", got["required"])
|
||||
}
|
||||
if len(required) != 2 || required[0] != "parent" || required[1] != "icon" {
|
||||
t.Fatalf("required = %#v, want [parent icon]", required)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSchemaForGemini_MergesAllOfAndFiltersRequired(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"payload": map[string]any{
|
||||
"allOf": []any{
|
||||
map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"id": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": []any{"id"},
|
||||
},
|
||||
map[string]any{
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
"count": map[string]any{
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": []any{"name", "missing"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := SanitizeSchemaForGemini(schema)
|
||||
props := got["properties"].(map[string]any)
|
||||
payload := props["payload"].(map[string]any)
|
||||
|
||||
if payload["type"] != "object" {
|
||||
t.Fatalf("payload.type = %#v, want object", payload["type"])
|
||||
}
|
||||
payloadProps, ok := payload["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("payload.properties = %#v, want map", payload["properties"])
|
||||
}
|
||||
for _, key := range []string{"id", "name", "count"} {
|
||||
if _, found := payloadProps[key]; !found {
|
||||
t.Fatalf("payload.properties missing %q: %#v", key, payloadProps)
|
||||
}
|
||||
}
|
||||
|
||||
required, ok := payload["required"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("payload.required = %#v, want []string", payload["required"])
|
||||
}
|
||||
if len(required) != 2 || required[0] != "id" || required[1] != "name" {
|
||||
t.Fatalf("payload.required = %#v, want [id name]", required)
|
||||
}
|
||||
|
||||
assertSchemaKeyAbsent(t, payload, "allOf")
|
||||
assertSchemaKeyAbsent(t, payload, "minimum")
|
||||
}
|
||||
|
||||
func TestSanitizeSchemaForGemini_HandlesRecursiveRefs(t *testing.T) {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"tree": map[string]any{
|
||||
"$ref": "#/$defs/node",
|
||||
},
|
||||
},
|
||||
"$defs": map[string]any{
|
||||
"node": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
"child": map[string]any{
|
||||
"$ref": "#/$defs/node",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := SanitizeSchemaForGemini(schema)
|
||||
props := got["properties"].(map[string]any)
|
||||
tree := props["tree"].(map[string]any)
|
||||
if tree["type"] != "object" {
|
||||
t.Fatalf("tree.type = %#v, want object", tree["type"])
|
||||
}
|
||||
assertSchemaKeyAbsent(t, tree, "$ref")
|
||||
}
|
||||
|
||||
func assertSchemaKeyAbsent(t *testing.T, value any, key string) {
|
||||
t.Helper()
|
||||
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
if _, found := typed[key]; found {
|
||||
t.Fatalf("schema still contains key %q: %#v", key, typed)
|
||||
}
|
||||
for _, nested := range typed {
|
||||
assertSchemaKeyAbsent(t, nested, key)
|
||||
}
|
||||
case []any:
|
||||
for _, nested := range typed {
|
||||
assertSchemaKeyAbsent(t, nested, key)
|
||||
}
|
||||
case []string:
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolSchemaTransformOff = ""
|
||||
ToolSchemaTransformSimple = "simple"
|
||||
)
|
||||
|
||||
// NormalizeToolSchemaTransform resolves user-facing aliases to a canonical
|
||||
// transform mode. Empty values and explicit "off"-style values disable schema
|
||||
// transformation.
|
||||
func NormalizeToolSchemaTransform(raw string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "", "off", "none", "native":
|
||||
return ToolSchemaTransformOff, nil
|
||||
case "simple", "basic", "strict", "flat":
|
||||
return ToolSchemaTransformSimple, nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported tool_schema_transform %q (supported: off, simple)", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TransformToolDefinitions clones tool definitions and applies the configured
|
||||
// schema transform to function parameter schemas. When the transform is off, the
|
||||
// original slice is returned unchanged.
|
||||
func TransformToolDefinitions(tools []ToolDefinition, transform string) ([]ToolDefinition, error) {
|
||||
transform, err := NormalizeToolSchemaTransform(transform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if transform == ToolSchemaTransformOff || len(tools) == 0 {
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
out := make([]ToolDefinition, len(tools))
|
||||
for i, tool := range tools {
|
||||
out[i] = tool
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
out[i].Function = tool.Function
|
||||
out[i].Function.Parameters = transformToolSchema(tool.Function.Parameters, transform)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func transformToolSchema(schema map[string]any, transform string) map[string]any {
|
||||
switch transform {
|
||||
case ToolSchemaTransformSimple:
|
||||
return SanitizeSchemaForGoogle(schema)
|
||||
default:
|
||||
return cloneGeminiSchemaMap(schema)
|
||||
}
|
||||
}
|
||||
@@ -110,19 +110,7 @@ func ExtractProtocol(cfg *config.ModelConfig) (protocol, modelID string) {
|
||||
if provider := strings.TrimSpace(cfg.Provider); provider != "" {
|
||||
return NormalizeProvider(provider), model
|
||||
}
|
||||
if model == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
protocol, rest, found := strings.Cut(model, "/")
|
||||
if !found {
|
||||
return "openai", model
|
||||
}
|
||||
protocol = strings.TrimSpace(protocol)
|
||||
if protocol == "" {
|
||||
return "", strings.TrimSpace(rest)
|
||||
}
|
||||
return NormalizeProvider(protocol), strings.TrimSpace(rest)
|
||||
return SplitModelProviderAndID(model, "openai")
|
||||
}
|
||||
|
||||
// ResolveAPIBase returns the configured API base, or the protocol default when
|
||||
@@ -154,6 +142,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
}
|
||||
|
||||
protocol, modelID := ExtractProtocol(cfg)
|
||||
authMethod := strings.ToLower(strings.TrimSpace(cfg.AuthMethod))
|
||||
|
||||
userAgent := cfg.UserAgent
|
||||
if userAgent == "" {
|
||||
@@ -163,12 +152,12 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
switch protocol {
|
||||
case "openai":
|
||||
// OpenAI with OAuth/token auth (Codex-style)
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
if authMethod == "oauth" || authMethod == "token" {
|
||||
provider, err := createCodexAuthProvider()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
}
|
||||
// OpenAI with API key
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
@@ -189,7 +178,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "azure", "azure-openai":
|
||||
// Azure OpenAI uses deployment-based URLs, api-key header auth,
|
||||
@@ -202,13 +191,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
|
||||
)
|
||||
}
|
||||
return azure.NewProviderWithTimeout(
|
||||
return finalizeProviderFromConfig(azure.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "bedrock":
|
||||
// AWS Bedrock uses AWS SDK credentials (env vars, profiles, IAM roles, etc.)
|
||||
@@ -244,7 +233,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating bedrock provider: %w", err)
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "nvidia", "venice",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
@@ -270,7 +259,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "gemini":
|
||||
if cfg.APIKey() == "" && cfg.APIBase == "" {
|
||||
@@ -280,7 +269,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if apiBase == "" {
|
||||
apiBase = getDefaultAPIBase(protocol)
|
||||
}
|
||||
return NewGeminiProvider(
|
||||
return finalizeProviderFromConfig(NewGeminiProvider(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
cfg.Proxy,
|
||||
@@ -288,7 +277,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "minimax":
|
||||
// Minimax requires reasoning_split: true in the request body
|
||||
@@ -317,16 +306,16 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "anthropic":
|
||||
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
|
||||
if authMethod == "oauth" || authMethod == "token" {
|
||||
// Use OAuth credentials from auth store
|
||||
provider, err := createClaudeAuthProvider()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
}
|
||||
// Use API key with HTTP API
|
||||
apiBase := cfg.APIBase
|
||||
@@ -347,7 +336,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.CustomHeaders,
|
||||
)
|
||||
provider.SetProviderName(protocol)
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
case "anthropic-messages":
|
||||
// Anthropic Messages API with native format (HTTP-based, no SDK)
|
||||
@@ -358,12 +347,12 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for anthropic-messages protocol (model: %s)", cfg.Model)
|
||||
}
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "coding-plan-anthropic", "alibaba-coding-anthropic":
|
||||
// Alibaba Coding Plan with Anthropic-compatible API
|
||||
@@ -374,29 +363,29 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if cfg.APIKey() == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for %q protocol (model: %s)", protocol, cfg.Model)
|
||||
}
|
||||
return anthropicmessages.NewProviderWithTimeout(
|
||||
return finalizeProviderFromConfig(anthropicmessages.NewProviderWithTimeout(
|
||||
cfg.APIKey(),
|
||||
apiBase,
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
), modelID, cfg)
|
||||
|
||||
case "antigravity":
|
||||
return NewAntigravityProvider(), modelID, nil
|
||||
return finalizeProviderFromConfig(NewAntigravityProvider(), modelID, cfg)
|
||||
|
||||
case "claude-cli", "claudecli":
|
||||
workspace := cfg.Workspace
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), modelID, nil
|
||||
return finalizeProviderFromConfig(NewClaudeCliProvider(workspace), modelID, cfg)
|
||||
|
||||
case "codex-cli", "codexcli":
|
||||
workspace := cfg.Workspace
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), modelID, nil
|
||||
return finalizeProviderFromConfig(NewCodexCliProvider(workspace), modelID, cfg)
|
||||
|
||||
case "github-copilot", "copilot":
|
||||
apiBase := cfg.APIBase
|
||||
@@ -411,15 +400,27 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return provider, modelID, nil
|
||||
return finalizeProviderFromConfig(provider, modelID, cfg)
|
||||
|
||||
default:
|
||||
return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func finalizeProviderFromConfig(
|
||||
provider LLMProvider,
|
||||
modelID string,
|
||||
cfg *config.ModelConfig,
|
||||
) (LLMProvider, string, error) {
|
||||
wrapped, err := wrapProviderWithToolSchemaTransform(provider, cfg.ToolSchemaTransform)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return wrapped, modelID, nil
|
||||
}
|
||||
|
||||
func isEmptyAPIKeyAllowed(protocol string) bool {
|
||||
meta, ok := protocolMetaByName[protocol]
|
||||
meta, ok := protocolMetaForName(protocol)
|
||||
return ok && meta.emptyAPIKeyAllowed
|
||||
}
|
||||
|
||||
@@ -439,9 +440,19 @@ func DefaultAPIBaseForProtocol(protocol string) string {
|
||||
|
||||
// getDefaultAPIBase returns the default API base URL for a given protocol.
|
||||
func getDefaultAPIBase(protocol string) string {
|
||||
meta, ok := protocolMetaByName[protocol]
|
||||
meta, ok := protocolMetaForName(protocol)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return meta.defaultAPIBase
|
||||
}
|
||||
|
||||
func protocolMetaForName(protocol string) (protocolMeta, bool) {
|
||||
if meta, ok := protocolMetaByName[protocol]; ok {
|
||||
return meta, true
|
||||
}
|
||||
if meta, ok := attachedModelProviderMetaByName[protocol]; ok {
|
||||
return meta.protocolMeta, true
|
||||
}
|
||||
return protocolMeta{}, false
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
@@ -101,6 +102,12 @@ func TestExtractProtocol(t *testing.T) {
|
||||
wantProtocol: "",
|
||||
wantModelID: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "unknown prefix falls back to openai",
|
||||
config: &config.ModelConfig{Model: "meta-llama/Llama-3.1-8B-Instruct"},
|
||||
wantProtocol: "openai",
|
||||
wantModelID: "meta-llama/Llama-3.1-8B-Instruct",
|
||||
},
|
||||
{
|
||||
name: "nil config",
|
||||
wantProtocol: "",
|
||||
@@ -605,6 +612,41 @@ func TestCreateProviderFromConfig_CodexCLI(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_OpenAIMixedCaseAuthMethodUsesOAuthBranch(t *testing.T) {
|
||||
origGetCredential := getCredential
|
||||
getCredential = func(provider string) (*auth.AuthCredential, error) {
|
||||
if provider != "openai" {
|
||||
t.Fatalf("provider = %q, want %q", provider, "openai")
|
||||
}
|
||||
return &auth.AuthCredential{
|
||||
AccessToken: "test-token",
|
||||
AccountID: "acct-test",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
getCredential = origGetCredential
|
||||
})
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-openai-oauth",
|
||||
Model: "openai/gpt-5.4",
|
||||
AuthMethod: "OAuth",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "gpt-5.4" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "gpt-5.4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-no-key",
|
||||
@@ -619,8 +661,9 @@ func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) {
|
||||
|
||||
func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-unknown",
|
||||
Model: "unknown-protocol/model",
|
||||
ModelName: "test-unknown-provider",
|
||||
Provider: "unknown-protocol",
|
||||
Model: "model",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
@@ -630,6 +673,26 @@ func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_UnknownModelPrefixDefaultsToOpenAI(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-unknown-model-prefix",
|
||||
Model: "meta-llama/Llama-3.1-8B-Instruct",
|
||||
APIBase: "https://api.example.com/v1",
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "meta-llama/Llama-3.1-8B-Instruct" {
|
||||
t.Fatalf("modelID = %q, want full model ID", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_NilConfig(t *testing.T) {
|
||||
_, _, err := CreateProviderFromConfig(nil)
|
||||
if err == nil {
|
||||
@@ -889,6 +952,71 @@ func TestGetDefaultAPIBase_QwenUSAliases(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelProviderOptions(t *testing.T) {
|
||||
options := ModelProviderOptions()
|
||||
if len(options) == 0 {
|
||||
t.Fatal("ModelProviderOptions() returned no options")
|
||||
}
|
||||
|
||||
seen := make(map[string]ModelProviderOption, len(options))
|
||||
for _, option := range options {
|
||||
seen[option.ID] = option
|
||||
}
|
||||
|
||||
if _, ok := seen["openai"]; !ok {
|
||||
t.Fatal("openai option missing")
|
||||
}
|
||||
if option, ok := seen["openai"]; ok && !option.CreateAllowed {
|
||||
t.Fatal("openai should be creatable")
|
||||
}
|
||||
if option, ok := seen["lmstudio"]; !ok {
|
||||
t.Fatal("lmstudio option missing")
|
||||
} else if !option.EmptyAPIKeyAllowed {
|
||||
t.Fatal("lmstudio should allow empty API keys")
|
||||
}
|
||||
if option, ok := seen["anthropic"]; !ok {
|
||||
t.Fatal("anthropic option missing")
|
||||
} else if option.DefaultAPIBase != "https://api.anthropic.com/v1" {
|
||||
t.Fatalf("anthropic default_api_base = %q, want %q", option.DefaultAPIBase, "https://api.anthropic.com/v1")
|
||||
}
|
||||
if _, ok := seen["azure"]; !ok {
|
||||
t.Fatal("azure option missing")
|
||||
}
|
||||
if option, ok := seen["bedrock"]; !ok {
|
||||
t.Fatal("bedrock option missing")
|
||||
} else if !option.CreateAllowed {
|
||||
t.Fatal("bedrock should be creatable and defer credential/build errors to runtime")
|
||||
}
|
||||
if option, ok := seen["elevenlabs"]; !ok {
|
||||
t.Fatal("elevenlabs option missing")
|
||||
} else {
|
||||
if option.DefaultAPIBase != "https://api.elevenlabs.io" {
|
||||
t.Fatalf("elevenlabs default_api_base = %q, want %q", option.DefaultAPIBase, "https://api.elevenlabs.io")
|
||||
}
|
||||
if option.DefaultModelAllowed {
|
||||
t.Fatal("elevenlabs should be ASR-only and therefore not allowed as a default chat model")
|
||||
}
|
||||
}
|
||||
if option, ok := seen["antigravity"]; !ok {
|
||||
t.Fatal("antigravity option missing")
|
||||
} else {
|
||||
if !option.CreateAllowed {
|
||||
t.Fatal("antigravity should be creatable")
|
||||
}
|
||||
if option.DefaultAuthMethod != "oauth" {
|
||||
t.Fatalf("antigravity default_auth_method = %q, want %q", option.DefaultAuthMethod, "oauth")
|
||||
}
|
||||
if !option.AuthMethodLocked {
|
||||
t.Fatal("antigravity auth method should be locked")
|
||||
}
|
||||
}
|
||||
if option, ok := seen["github-copilot"]; !ok {
|
||||
t.Fatal("github-copilot option missing")
|
||||
} else if option.DefaultAPIBase != "localhost:4321" {
|
||||
t.Fatalf("github-copilot default_api_base = %q, want %q", option.DefaultAPIBase, "localhost:4321")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_MinimaxInjectsReasoningSplit(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
@@ -1202,3 +1330,42 @@ func TestCreateProviderFromConfig_BedrockWithEndpointURL(t *testing.T) {
|
||||
// Unexpected error - fail the test
|
||||
t.Errorf("unexpected error from bedrock provider: %v", err)
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_ToolSchemaTransformWrapsProvider(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "claude-cli-test",
|
||||
Provider: "claude-cli",
|
||||
Model: "claude-sonnet-4.6",
|
||||
Workspace: t.TempDir(),
|
||||
ToolSchemaTransform: "simple",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if modelID != "claude-sonnet-4.6" {
|
||||
t.Fatalf("modelID = %q, want %q", modelID, "claude-sonnet-4.6")
|
||||
}
|
||||
if _, ok := provider.(*toolSchemaTransformProvider); !ok {
|
||||
t.Fatalf("provider = %T, want *toolSchemaTransformProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_InvalidToolSchemaTransform(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "claude-cli-test",
|
||||
Provider: "claude-cli",
|
||||
Model: "claude-sonnet-4.6",
|
||||
Workspace: t.TempDir(),
|
||||
ToolSchemaTransform: "invalid",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for invalid tool_schema_transform")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "tool_schema_transform") {
|
||||
t.Fatalf("error = %v, want mention tool_schema_transform", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,66 +12,6 @@ func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake
|
||||
return ""
|
||||
}
|
||||
|
||||
var geminiUnsupportedKeywords = map[string]bool{
|
||||
"patternProperties": true,
|
||||
"additionalProperties": true,
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
"examples": true,
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"multipleOf": true,
|
||||
"pattern": true,
|
||||
"format": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"uniqueItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
}
|
||||
|
||||
func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
for k, v := range schema {
|
||||
if geminiUnsupportedKeywords[k] {
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
result[k] = sanitizeSchemaForGemini(val)
|
||||
case []any:
|
||||
sanitized := make([]any, len(val))
|
||||
for i, item := range val {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
sanitized[i] = sanitizeSchemaForGemini(m)
|
||||
} else {
|
||||
sanitized[i] = item
|
||||
}
|
||||
}
|
||||
result[k] = sanitized
|
||||
default:
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if _, hasProps := result["properties"]; hasProps {
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "object"
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func extractProtocol(model string) (protocol, modelID string) {
|
||||
model = strings.TrimSpace(model)
|
||||
protocol, modelID, found := strings.Cut(model, "/")
|
||||
|
||||
@@ -264,7 +264,7 @@ func (p *GeminiProvider) buildRequestBody(
|
||||
funcDecls = append(funcDecls, geminiFunctionDeclaration{
|
||||
Name: t.Function.Name,
|
||||
Description: t.Function.Description,
|
||||
Parameters: sanitizeSchemaForGemini(t.Function.Parameters),
|
||||
Parameters: t.Function.Parameters,
|
||||
})
|
||||
}
|
||||
if len(funcDecls) > 0 {
|
||||
|
||||
@@ -259,6 +259,64 @@ func TestGeminiProvider_ChatStreamSkipsEmptyDataFrames(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_BuildRequestBody_PreservesComplexToolSchemasByDefault(t *testing.T) {
|
||||
provider := NewGeminiProvider("test-key", "https://example.com/v1beta", "", "", 0, nil, nil)
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"parent": map[string]any{
|
||||
"anyOf": []any{
|
||||
map[string]any{"$ref": "#/$defs/pageParent"},
|
||||
map[string]any{"$ref": "#/$defs/databaseParent"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"$defs": map[string]any{
|
||||
"pageParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"page_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"page_id"},
|
||||
},
|
||||
"databaseParent": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"database_id": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"database_id"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body := provider.buildRequestBody(
|
||||
[]Message{{Role: "user", Content: "hello"}},
|
||||
[]ToolDefinition{{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "mcp_notion_create",
|
||||
Description: "Create a Notion object",
|
||||
Parameters: schema,
|
||||
},
|
||||
}},
|
||||
"gemini-3-flash-preview",
|
||||
nil,
|
||||
)
|
||||
|
||||
tools, ok := body["tools"].([]geminiTool)
|
||||
if !ok || len(tools) != 1 {
|
||||
t.Fatalf("tools = %#v, want one geminiTool", body["tools"])
|
||||
}
|
||||
got, ok := tools[0].FunctionDeclarations[0].Parameters.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters = %#v, want map", tools[0].FunctionDeclarations[0].Parameters)
|
||||
}
|
||||
|
||||
if got["$defs"] == nil {
|
||||
t.Fatalf("parameters = %#v, want raw schema with $defs preserved by default", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiProvider_ChatStreamReturnsErrorOnInvalidDataFrame(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
+15
-10
@@ -17,18 +17,13 @@ func ParseModelRef(raw string, defaultProvider string) *ModelRef {
|
||||
return nil
|
||||
}
|
||||
|
||||
if idx := strings.Index(raw, "/"); idx > 0 {
|
||||
provider := NormalizeProvider(raw[:idx])
|
||||
model := strings.TrimSpace(raw[idx+1:])
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
return &ModelRef{Provider: provider, Model: model}
|
||||
provider, model := SplitModelProviderAndID(raw, defaultProvider)
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &ModelRef{
|
||||
Provider: NormalizeProvider(defaultProvider),
|
||||
Model: raw,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +48,8 @@ func NormalizeProvider(provider string) string {
|
||||
return "zhipu"
|
||||
case "google":
|
||||
return "gemini"
|
||||
case "google-antigravity":
|
||||
return "antigravity"
|
||||
case "alibaba-coding", "qwen-coding":
|
||||
return "coding-plan"
|
||||
case "alibaba-coding-anthropic":
|
||||
@@ -61,6 +58,14 @@ func NormalizeProvider(provider string) string {
|
||||
return "qwen-intl"
|
||||
case "dashscope-us":
|
||||
return "qwen-us"
|
||||
case "azure-openai":
|
||||
return "azure"
|
||||
case "claudecli":
|
||||
return "claude-cli"
|
||||
case "codexcli":
|
||||
return "codex-cli"
|
||||
case "copilot":
|
||||
return "github-copilot"
|
||||
}
|
||||
|
||||
return p
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user