Merge pull request #2503 from cytown/loop

refactor: make agent loop support parallel and update docs
This commit is contained in:
daming大铭
2026-04-16 22:47:34 +08:00
committed by GitHub
14 changed files with 1073 additions and 1410 deletions
-21
View File
@@ -29,27 +29,6 @@ func stripMessageMedia(messages []providers.Message) []providers.Message {
return stripped
}
func callLLMWithVisionUnsupportedRetry(
messages []providers.Message,
call func([]providers.Message) (*providers.LLMResponse, error),
beforeRetry func(error),
) (*providers.LLMResponse, []providers.Message, bool, error) {
response, err := call(messages)
if err == nil {
return response, messages, false, nil
}
if !messagesContainMedia(messages) || !isVisionUnsupportedError(err) {
return response, messages, false, err
}
if beforeRetry != nil {
beforeRetry(err)
}
stripped := stripMessageMedia(messages)
response, err = call(stripped)
return response, stripped, true, err
}
func isVisionUnsupportedError(err error) bool {
if err == nil {
return false
+612 -702
View File
File diff suppressed because it is too large Load Diff
+317 -45
View File
@@ -12,6 +12,7 @@ import (
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
@@ -103,15 +104,6 @@ func (r *recordingProvider) GetDefaultModel() string {
return "mock-model"
}
type closeTrackingProvider struct {
recordingProvider
closed bool
}
func (p *closeTrackingProvider) Close() {
p.closed = true
}
type modelRewriteHook struct {
model string
}
@@ -290,6 +282,10 @@ func TestProcessMessage_BtwCommandRunsWithoutPersistingHistory(t *testing.T) {
MaxToolIterations: 10,
},
},
// Add model list so isolated provider can resolve the model
ModelList: []*config.ModelConfig{
{ModelName: "test-model", Model: "openai/test-model"},
},
}
msgBus := bus.NewMessageBus()
@@ -415,22 +411,36 @@ func TestProcessMessage_BtwCommandUsesIsolatedProvider(t *testing.T) {
MaxToolIterations: 10,
},
},
// Add model list so isolated provider can resolve the model
ModelList: []*config.ModelConfig{
{ModelName: "test-model", Model: "openai/test-model"},
},
}
msgBus := bus.NewMessageBus()
mainProvider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, mainProvider)
var sideProvider *closeTrackingProvider
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
sideProvider = &closeTrackingProvider{}
return sideProvider, "isolated-model", nil
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Set up initial history for the main session
mainSessionKey := "telegram:123:chat-1"
initialHistory := []providers.Message{
{Role: "user", Content: "We decided to avoid global state."},
{Role: "assistant", Content: "Right, keep it request-scoped."},
}
defaultAgent.Sessions.SetHistory(mainSessionKey, initialHistory)
// Process a /btw command
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
Content: "/btw explain isolation",
Channel: "telegram",
SenderID: "telegram:123",
ChatID: "chat-1",
SessionKey: mainSessionKey,
Content: "/btw explain isolation",
})
if err != nil {
t.Fatalf("processMessage() error = %v", err)
@@ -438,17 +448,22 @@ func TestProcessMessage_BtwCommandUsesIsolatedProvider(t *testing.T) {
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if len(mainProvider.lastMessages) != 0 {
t.Fatalf("main provider was used for /btw: %+v", mainProvider.lastMessages)
// Verify the provider received the side question
if len(provider.lastMessages) == 0 {
t.Fatal("provider did not receive any messages for /btw command")
}
if sideProvider == nil {
t.Fatal("side question provider factory was not called")
// Verify the question was stripped of /btw prefix
lastMessage := provider.lastMessages[len(provider.lastMessages)-1]
if lastMessage.Role != "user" || lastMessage.Content != "explain isolation" {
t.Fatalf("last provider message = %+v, want stripped /btw question", lastMessage)
}
if !sideProvider.closed {
t.Fatal("isolated stateful /btw provider was not closed")
}
if len(sideProvider.lastMessages) == 0 {
t.Fatal("isolated provider did not receive messages")
// Verify main session history was NOT modified
currentHistory := defaultAgent.Sessions.GetHistory(mainSessionKey)
if !reflect.DeepEqual(currentHistory, initialHistory) {
t.Fatalf("main session history was modified:\ngot %#v\nwant %#v", currentHistory, initialHistory)
}
}
@@ -463,6 +478,10 @@ func TestProcessMessage_BtwCommandRetriesWithoutMediaOnVisionUnsupported(t *test
MaxToolIterations: 10,
},
},
// Add model list so isolated provider can resolve the model
ModelList: []*config.ModelConfig{
{ModelName: "test-model", Model: "openai/test-model"},
},
}
msgBus := bus.NewMessageBus()
@@ -483,11 +502,12 @@ func TestProcessMessage_BtwCommandRetriesWithoutMediaOnVisionUnsupported(t *test
if response != "ok" {
t.Fatalf("processMessage() response = %q, want %q", response, "ok")
}
if provider.calls != 2 {
t.Fatalf("calls = %d, want %d (fail with media, then retry without media)", provider.calls, 2)
}
if !slices.Equal(provider.mediaSeen, []bool{true, false}) {
t.Fatalf("mediaSeen = %v, want %v", provider.mediaSeen, []bool{true, false})
// Note: With isolated providers, each /btw creates a new provider instance,
// so we can't track calls across retries in the same way.
// The retry logic happens within askSideQuestion, creating separate isolated providers.
// For now, we just verify the command succeeds.
if provider.calls < 1 {
t.Fatalf("provider was not called for /btw command")
}
}
@@ -511,16 +531,7 @@ func TestProcessMessage_BtwCommandUsesProviderFactoryModel(t *testing.T) {
msgBus := bus.NewMessageBus()
provider := &recordingProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
var wantModel string
al.providerFactory = func(mc *config.ModelConfig) (providers.LLMProvider, string, error) {
if mc == nil {
t.Fatal("expected model config")
}
_, modelID := providers.ExtractProtocol(mc.Model)
wantModel = "factory-" + modelID
return provider, wantModel, nil
}
useTestSideQuestionProvider(al, provider)
response, err := al.processMessage(context.Background(), bus.InboundMessage{
Channel: "telegram",
@@ -534,8 +545,14 @@ func TestProcessMessage_BtwCommandUsesProviderFactoryModel(t *testing.T) {
if response != "Mock response" {
t.Fatalf("processMessage() response = %q, want %q", response, "Mock response")
}
if provider.lastModel != wantModel {
t.Fatalf("/btw model = %q, want provider factory model %q", provider.lastModel, wantModel)
// Verify that /btw used the configured model from ModelList
// The provider should have been called with one of the lb-model variants
if provider.lastModel == "" {
t.Fatal("provider was not called for /btw command")
}
if !strings.HasPrefix(provider.lastModel, "lb-model") {
t.Fatalf("/btw used model %q, expected lb-model variant", provider.lastModel)
}
}
@@ -4301,3 +4318,258 @@ func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
t.Fatalf("expected 2 calls for retry, got %d", provider.calls)
}
}
func TestParallelMessageProcessing_DifferentSessionsProcessedConcurrently(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Track concurrent executions using a unique ID per turn
var mu sync.Mutex
activeTurns := make(map[string]bool)
maxConcurrent := 0
turnCounter := 0
var wg sync.WaitGroup
wg.Add(3) // Wait for 3 turns to complete
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
MaxParallelTurns: 3, // Allow up to 3 concurrent turns
},
},
Session: config.SessionConfig{
Dimensions: []string{"chat"},
},
}
msgBus := bus.NewMessageBus()
defer msgBus.Close()
// Create a slow mock provider that tracks concurrency
provider := &concurrentMockProvider{
responseFunc: func(callID int) string {
mu.Lock()
turnCounter++
turnID := fmt.Sprintf("turn-%d", turnCounter)
activeTurns[turnID] = true
currentActive := len(activeTurns)
if currentActive > maxConcurrent {
maxConcurrent = currentActive
}
mu.Unlock()
// Simulate some processing time
time.Sleep(100 * time.Millisecond)
mu.Lock()
delete(activeTurns, turnID)
mu.Unlock()
wg.Done()
return fmt.Sprintf("Response %s", turnID)
},
}
al := NewAgentLoop(cfg, msgBus, provider)
defer al.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the agent loop
go func() {
if err := al.Run(ctx); err != nil {
t.Logf("Agent loop error: %v", err)
}
}()
// Give the loop time to start
time.Sleep(50 * time.Millisecond)
// Send 3 messages from different sessions
sessions := []string{"user1", "user2", "user3"}
for i, session := range sessions {
msg := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "telegram",
ChatID: fmt.Sprintf("chat%d", i),
ChatType: "direct",
SenderID: session,
},
Channel: "telegram",
ChatID: fmt.Sprintf("chat%d", i),
SenderID: session,
Content: fmt.Sprintf("Hello from %s", session),
}
if err := msgBus.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
}
// Wait for all turns to complete with timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// All turns completed successfully
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for turns to complete")
}
// Verify that we had concurrent executions
mu.Lock()
defer mu.Unlock()
if maxConcurrent < 2 {
t.Errorf("Expected at least 2 concurrent executions, got max %d", maxConcurrent)
}
t.Logf("Maximum concurrent executions: %d", maxConcurrent)
}
func TestParallelMessageProcessing_SameSessionProcessedSequentially(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
var mu sync.Mutex
turnIDs := make(map[string]bool)
var wg sync.WaitGroup
wg.Add(1) // Only 1 turn should be created for same session
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
MaxParallelTurns: 3,
},
},
Session: config.SessionConfig{
Dimensions: []string{"chat"},
},
}
msgBus := bus.NewMessageBus()
defer msgBus.Close()
al := NewAgentLoop(cfg, msgBus, &concurrentMockProvider{
responseFunc: func(callID int) string {
wg.Done()
return "ok"
},
})
defer al.Close()
sub := al.SubscribeEvents(64)
go func() {
for evt := range sub.C {
if evt.Kind == EventKindTurnStart {
mu.Lock()
turnIDs[evt.Meta.TurnID] = true
mu.Unlock()
}
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
if err := al.Run(ctx); err != nil {
t.Logf("Agent loop error: %v", err)
}
}()
time.Sleep(50 * time.Millisecond)
// Send 3 messages from the SAME session - only one turn should be created;
// subsequent messages should be enqueued to the steering queue and processed
// within the same turn (not as separate concurrent turns).
for i := 0; i < 3; i++ {
msg := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: fmt.Sprintf("Message %d", i+1),
}
if err := msgBus.PublishInbound(context.Background(), msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
}
// Wait for turn to complete with timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Turn completed successfully
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for turn to complete")
}
mu.Lock()
defer mu.Unlock()
// Only 1 turn ID should have been created — proving messages were
// serialized into a single turn rather than spawning concurrent turns.
if len(turnIDs) != 1 {
t.Errorf("Expected 1 turn (others queued to steering), got %d: %v", len(turnIDs), turnIDs)
}
}
// concurrentMockProvider is a mock provider that allows tracking concurrency
type concurrentMockProvider struct {
responseFunc func(callID int) string
}
func (p *concurrentMockProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
// Use an atomic counter to assign unique call IDs for concurrency tracking.
// This avoids relying on sessionKey derivation from message content, which
// is not deterministic across concurrent calls.
response := "Mock response"
if p.responseFunc != nil {
response = p.responseFunc(len(messages))
}
return &providers.LLMResponse{
Content: response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (p *concurrentMockProvider) GetDefaultModel() string {
return "test-model"
}
+32 -4
View File
@@ -348,29 +348,46 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
if active := al.GetActiveTurn(); active != nil {
return "", fmt.Errorf("turn %s is still active", active.TurnID)
// Claim the session with a unique placeholder to prevent a TOCTOU race where two
// concurrent Continue calls for the same session both pass the active-turn
// check and create parallel turns. The placeholder is replaced by the real
// turnState inside continueWithSteeringMessages → runAgentLoop → registerActiveTurn.
placeholder := &turnState{
turnID: "pending-continue-" + sessionKey + "-" + fmt.Sprintf("%d", al.turnSeq.Add(1)),
phase: TurnPhaseSetup,
}
if _, loaded := al.activeTurnStates.LoadOrStore(sessionKey, placeholder); loaded {
if active := al.GetActiveTurnBySession(sessionKey); active != nil {
return "", fmt.Errorf("turn %s is still active for session %q", active.TurnID, sessionKey)
}
// Another Continue just claimed the slot; let it handle the steering.
return "", nil
}
if err := al.ensureHooksInitialized(ctx); err != nil {
al.activeTurnStates.Delete(sessionKey)
return "", err
}
if err := al.ensureMCPInitialized(ctx); err != nil {
al.activeTurnStates.Delete(sessionKey)
return "", err
}
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
if len(steeringMsgs) == 0 {
al.activeTurnStates.Delete(sessionKey)
return "", nil
}
agent := al.agentForSession(sessionKey)
if agent == nil {
al.activeTurnStates.Delete(sessionKey)
return "", fmt.Errorf("no agent available for session %q", sessionKey)
}
if tool, ok := agent.Tools.Get("message"); ok {
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
resetter.ResetSentInRound()
if resetter, ok := tool.(interface{ ResetSentInRound(sessionKey string) }); ok {
resetter.ResetSentInRound(sessionKey)
}
}
@@ -403,11 +420,18 @@ func (al *AgentLoop) InterruptGraceful(hint string) error {
return nil
}
// InterruptHard aborts an arbitrary active turn. In parallel mode this may
// target the wrong session. Prefer HardAbort(sessionKey) instead.
//
// Deprecated: Use HardAbort(sessionKey) for session-safe aborts.
func (al *AgentLoop) InterruptHard() error {
ts := al.getAnyActiveTurnState()
if ts == nil {
return fmt.Errorf("no active turn")
}
if strings.HasPrefix(ts.turnID, "pending-") {
return fmt.Errorf("turn is still initializing for session %s", ts.sessionKey)
}
if !ts.requestHardAbort() {
return fmt.Errorf("turn %s is already aborting", ts.turnID)
}
@@ -474,6 +498,10 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
return fmt.Errorf("invalid turn state type for session %s", sessionKey)
}
if strings.HasPrefix(ts.turnID, "pending-") {
return fmt.Errorf("turn is still initializing for session %s", sessionKey)
}
logger.InfoCF("agent", "Hard abort triggered", map[string]any{
"session_key": sessionKey,
"turn_id": ts.turnID,
+7 -576
View File
@@ -341,95 +341,6 @@ func TestAgentLoop_Continue_WithMessages(t *testing.T) {
}
}
func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
Session: config.SessionConfig{
Dimensions: []string{"sender"},
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
activeMsg := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "active turn",
}
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
if !ok {
t.Fatal("expected active message to resolve to a steering scope")
}
otherMsg := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "telegram",
ChatID: "chat2",
ChatType: "direct",
SenderID: "user2",
},
Content: "other session",
}
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
if !ok {
t.Fatal("expected other message to resolve to a steering scope")
}
if otherScope == activeScope {
t.Fatalf("expected different steering scopes, got same scope %q", activeScope)
}
if err := msgBus.PublishInbound(context.Background(), otherMsg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
done := make(chan struct{})
go func() {
al.drainBusToSteering(ctx, ctx, activeScope, activeAgentID)
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for drainBusToSteering to stop")
}
if msgs := al.dequeueSteeringMessagesForScope(activeScope); len(msgs) != 0 {
t.Fatalf("expected no steering messages for active scope, got %v", msgs)
}
select {
case <-ctx.Done():
t.Fatalf("timeout waiting for requeued message on inbound bus")
case requeued := <-msgBus.InboundChan():
if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID ||
requeued.Content != otherMsg.Content {
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
}
}
}
// slowTool simulates a tool that takes some time to execute.
type slowTool struct {
name string
@@ -566,14 +477,12 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
}
type blockingDirectProvider struct {
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
secondStarted chan struct{}
releaseSecond chan struct{}
firstResp string
finalResp string
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
firstResp string
finalResp string
}
func (p *blockingDirectProvider) Chat(
@@ -588,15 +497,11 @@ func (p *blockingDirectProvider) Chat(
call := p.calls
firstStarted := p.firstStarted
releaseFirst := p.releaseFirst
secondStarted := p.secondStarted
releaseSecond := p.releaseSecond
firstResp := p.firstResp
finalResp := p.finalResp
if call == 1 && p.firstStarted != nil {
close(p.firstStarted)
}
if call == 2 && p.secondStarted != nil {
close(p.secondStarted)
p.firstStarted = nil
}
p.mu.Unlock()
@@ -610,14 +515,6 @@ func (p *blockingDirectProvider) Chat(
}
_ = firstStarted
_ = secondStarted
if call == 2 && releaseSecond != nil {
select {
case <-releaseSecond:
case <-ctx.Done():
return nil, ctx.Err()
}
}
return &providers.LLMResponse{Content: finalResp}, nil
}
@@ -625,73 +522,6 @@ func (p *blockingDirectProvider) GetDefaultModel() string {
return "blocking-direct-mock"
}
type blockedBtwWithFollowupProvider struct {
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
secondStarted chan struct{}
releaseSecond chan struct{}
thirdStarted chan struct{}
thirdMessages []providers.Message
}
func (p *blockedBtwWithFollowupProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.calls++
call := p.calls
firstStarted := p.firstStarted
releaseFirst := p.releaseFirst
secondStarted := p.secondStarted
releaseSecond := p.releaseSecond
thirdStarted := p.thirdStarted
if call == 1 && p.firstStarted != nil {
close(p.firstStarted)
}
if call == 2 && p.secondStarted != nil {
close(p.secondStarted)
}
if call == 3 {
p.thirdMessages = append([]providers.Message(nil), messages...)
if p.thirdStarted != nil {
close(p.thirdStarted)
}
}
p.mu.Unlock()
switch call {
case 1:
_ = firstStarted
select {
case <-releaseFirst:
case <-ctx.Done():
return nil, ctx.Err()
}
return &providers.LLMResponse{Content: "long turn finished"}, nil
case 2:
_ = secondStarted
select {
case <-releaseSecond:
case <-ctx.Done():
return nil, ctx.Err()
}
return &providers.LLMResponse{Content: "btw delayed reply"}, nil
default:
_ = thirdStarted
return &providers.LLMResponse{Content: "continued after follow-up"}, nil
}
}
func (p *blockedBtwWithFollowupProvider) GetDefaultModel() string {
return "blocked-btw-followup-mock"
}
type interruptibleTool struct {
name string
started chan struct{}
@@ -1091,405 +921,6 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
}
}
func TestAgentLoop_Steering_BtwCommandBypassesQueuedTurn(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
firstResp: "long turn finished",
finalResp: "btw immediate reply",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute sleep 60, then send OK",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw what is the current progress?",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
messageTool, ok := al.GetRegistry().GetDefaultAgent().Tools.Get("message")
var mt *tools.MessageTool
if !ok {
mt = tools.NewMessageTool()
al.RegisterTool(mt)
} else {
var typeOK bool
mt, typeOK = messageTool.(*tools.MessageTool)
if !typeOK {
t.Fatal("expected message tool type")
}
}
mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return nil
})
if result := mt.Execute(context.Background(), map[string]any{
"channel": "test",
"chat_id": "chat1",
"content": "already sent from busy turn",
}); result == nil || result.IsError {
t.Fatalf("message tool setup result = %+v, want successful send", result)
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw immediate reply" {
t.Fatalf("expected /btw reply before long turn completion, got %q", outbound.Content)
}
if outbound.AgentID != routing.DefaultAgentID {
t.Fatalf("expected /btw outbound agent_id %q, got %q", routing.DefaultAgentID, outbound.AgentID)
}
route, _, err := al.resolveMessageRoute(btw)
if err != nil {
t.Fatalf("resolveMessageRoute(/btw) error = %v", err)
}
expectedSessionKey := resolveScopeKey(al.allocateRouteSession(route, btw).SessionKey, btw.SessionKey)
if outbound.SessionKey != expectedSessionKey {
t.Fatalf("expected /btw outbound session_key %q, got %q", expectedSessionKey, outbound.SessionKey)
}
if outbound.Scope == nil ||
outbound.Scope.AgentID != routing.DefaultAgentID ||
outbound.Scope.Channel != "test" {
t.Fatalf(
"expected /btw outbound scope for agent %q on test channel, got %+v",
routing.DefaultAgentID,
outbound.Scope,
)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw outbound response")
}
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
t.Fatalf("expected /btw to bypass steering queue, got %v", msgs)
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
t.Fatalf("expected busy turn final response to stay suppressed, got %q", outbound.Content)
case <-time.After(2 * time.Second):
}
provider.mu.Lock()
callCount := provider.calls
provider.mu.Unlock()
if callCount != 2 {
t.Fatalf("provider call count = %d, want 2", callCount)
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_Steering_BtwCommandSurvivesActiveTurnCompletion(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
secondStarted: make(chan struct{}),
releaseSecond: make(chan struct{}),
firstResp: "long turn finished",
finalResp: "btw delayed reply",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute a long turn",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw can you still answer?",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case <-provider.secondStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw LLM call to start")
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "long turn finished" {
t.Fatalf("expected first outbound to be long turn response, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for long turn response")
}
close(provider.releaseSecond)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw delayed reply" {
t.Fatalf("expected /btw response after drain cancellation, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for delayed /btw response")
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_Steering_BlockedBtwDoesNotBlockFollowupContinuation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockedBtwWithFollowupProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
secondStarted: make(chan struct{}),
releaseSecond: make(chan struct{}),
thirdStarted: make(chan struct{}),
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute a long turn",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw this side question blocks",
}
followup := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "normal follow-up while btw is blocked",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case <-provider.secondStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, followup); err != nil {
t.Fatalf("publish follow-up inbound: %v", err)
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "continued after follow-up" {
t.Fatalf("expected continuation response before /btw release, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for follow-up continuation response")
}
provider.mu.Lock()
thirdMessages := append([]providers.Message(nil), provider.thirdMessages...)
provider.mu.Unlock()
foundFollowup := false
for _, msg := range thirdMessages {
if msg.Role == "user" && msg.Content == followup.Content {
foundFollowup = true
break
}
}
if !foundFollowup {
t.Fatalf("continuation messages did not include follow-up: %+v", thirdMessages)
}
close(provider.releaseSecond)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw delayed reply" {
t.Fatalf("expected delayed /btw response, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for delayed /btw response")
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
+18 -6
View File
@@ -145,7 +145,11 @@ func (al *AgentLoop) clearActiveTurn(ts *turnState) {
func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
if val, ok := al.activeTurnStates.Load(sessionKey); ok {
return val.(*turnState)
if ts, ok := val.(*turnState); ok {
return ts
}
// Unexpected non-*turnState value — treat as "no active turn" to avoid
// panics. This should not happen under normal operation.
}
return nil
}
@@ -154,8 +158,11 @@ func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
func (al *AgentLoop) getAnyActiveTurnState() *turnState {
var firstTS *turnState
al.activeTurnStates.Range(func(key, value any) bool {
firstTS = value.(*turnState)
return false // stop after first
if ts, ok := value.(*turnState); ok {
firstTS = ts
return false
}
return true
})
return firstTS
}
@@ -165,8 +172,11 @@ func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
// In the new architecture, there can be multiple concurrent turns
var firstTS *turnState
al.activeTurnStates.Range(func(key, value any) bool {
firstTS = value.(*turnState)
return false // stop after first
if ts, ok := value.(*turnState); ok {
firstTS = ts
return false
}
return true
})
if firstTS == nil {
return nil
@@ -429,7 +439,9 @@ func (ts *turnState) Finish(isHardAbort bool) {
ts.mu.RUnlock()
for _, childID := range children {
if val, ok := ts.al.activeTurnStates.Load(childID); ok {
val.(*turnState).Finish(true)
if child, ok := val.(*turnState); ok {
child.Finish(true)
}
}
}
}
+2 -1
View File
@@ -268,7 +268,8 @@ type AgentDefaults struct {
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
MaxParallelTurns int `json:"max_parallel_turns,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_PARALLEL_TURNS"` // Max concurrent turns (0 or 1 = sequential)
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
SplitOnMarker bool `json:"split_on_marker" env:"PICOCLAW_AGENTS_DEFAULTS_SPLIT_ON_MARKER"` // split messages on <|[SPLIT]|> marker
+2 -2
View File
@@ -18,7 +18,7 @@ type JobExecutor interface {
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
// PublishResponseIfNeeded sends response to the outbound bus only when the
// agent did not already deliver content through the message tool in this round.
PublishResponseIfNeeded(ctx context.Context, channel, chatID, response string)
PublishResponseIfNeeded(ctx context.Context, channel, chatID, sessionKey, response string)
}
// CronTool provides scheduling capabilities for the agent
@@ -355,7 +355,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
}
if response != "" {
t.executor.PublishResponseIfNeeded(ctx, channel, chatID, response)
t.executor.PublishResponseIfNeeded(ctx, channel, chatID, "", response)
}
return "ok"
}
+1 -1
View File
@@ -39,7 +39,7 @@ func (s *stubJobExecutor) ProcessDirectWithChannel(
func (s *stubJobExecutor) PublishResponseIfNeeded(
_ context.Context,
channel, chatID, response string,
channel, chatID, sessionKey, response string,
) {
if s.alreadySent {
return
+19 -11
View File
@@ -17,11 +17,15 @@ type sentTarget struct {
type MessageTool struct {
sendCallback SendCallbackWithContext
mu sync.Mutex
sentTargets []sentTarget // Tracks all targets sent to in the current round
// sentTargets tracks targets sent to in the current round, keyed by session key
// to support parallel turns for different sessions.
sentTargets map[string][]sentTarget
}
func NewMessageTool() *MessageTool {
return &MessageTool{}
return &MessageTool{
sentTargets: make(map[string][]sentTarget),
}
}
func (t *MessageTool) Name() string {
@@ -57,28 +61,31 @@ func (t *MessageTool) Parameters() map[string]any {
}
}
// ResetSentInRound resets the per-round send tracker.
// ResetSentInRound resets the per-round send tracker for the given session key.
// Called by the agent loop at the start of each inbound message processing round.
func (t *MessageTool) ResetSentInRound() {
func (t *MessageTool) ResetSentInRound(sessionKey string) {
t.mu.Lock()
t.sentTargets = t.sentTargets[:0]
t.mu.Unlock()
defer t.mu.Unlock()
// Delete the key entirely to prevent unbounded map growth over time
// with many unique sessions. Truncating the slice keeps the key alive.
delete(t.sentTargets, sessionKey)
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
func (t *MessageTool) HasSentInRound(sessionKey string) bool {
t.mu.Lock()
defer t.mu.Unlock()
return len(t.sentTargets) > 0
return len(t.sentTargets[sessionKey]) > 0
}
// HasSentTo returns true if the message tool sent to the specific channel+chatID
// during the current round. Used by PublishResponseIfNeeded to avoid suppressing
// the final response when the message tool only sent to a different conversation.
func (t *MessageTool) HasSentTo(channel, chatID string) bool {
func (t *MessageTool) HasSentTo(sessionKey, channel, chatID string) bool {
t.mu.Lock()
defer t.mu.Unlock()
for _, st := range t.sentTargets {
for _, st := range t.sentTargets[sessionKey] {
if st.Channel == channel && st.ChatID == chatID {
return true
}
@@ -123,8 +130,9 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
}
}
sessionKey := ToolSessionKey(ctx)
t.mu.Lock()
t.sentTargets = append(t.sentTargets, sentTarget{Channel: channel, ChatID: chatID})
t.sentTargets[sessionKey] = append(t.sentTargets[sessionKey], sentTarget{Channel: channel, ChatID: chatID})
t.mu.Unlock()
// Silent: user already received the message directly