feat(agent): support btw side questions (#2532)

This commit is contained in:
lxowalle
2026-04-16 10:53:09 +08:00
committed by GitHub
parent a8d0b03515
commit e22b4e1eee
23 changed files with 1737 additions and 70 deletions
+488 -8
View File
@@ -405,7 +405,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
done := make(chan struct{})
go func() {
al.drainBusToSteering(ctx, activeScope, activeAgentID)
al.drainBusToSteering(ctx, ctx, activeScope, activeAgentID)
close(done)
}()
@@ -566,12 +566,14 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
}
type blockingDirectProvider struct {
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
firstResp string
finalResp string
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
secondStarted chan struct{}
releaseSecond chan struct{}
firstResp string
finalResp string
}
func (p *blockingDirectProvider) Chat(
@@ -586,11 +588,15 @@ func (p *blockingDirectProvider) Chat(
call := p.calls
firstStarted := p.firstStarted
releaseFirst := p.releaseFirst
secondStarted := p.secondStarted
releaseSecond := p.releaseSecond
firstResp := p.firstResp
finalResp := p.finalResp
if call == 1 && p.firstStarted != nil {
close(p.firstStarted)
p.firstStarted = nil
}
if call == 2 && p.secondStarted != nil {
close(p.secondStarted)
}
p.mu.Unlock()
@@ -604,6 +610,14 @@ func (p *blockingDirectProvider) Chat(
}
_ = firstStarted
_ = secondStarted
if call == 2 && releaseSecond != nil {
select {
case <-releaseSecond:
case <-ctx.Done():
return nil, ctx.Err()
}
}
return &providers.LLMResponse{Content: finalResp}, nil
}
@@ -611,6 +625,73 @@ func (p *blockingDirectProvider) GetDefaultModel() string {
return "blocking-direct-mock"
}
type blockedBtwWithFollowupProvider struct {
mu sync.Mutex
calls int
firstStarted chan struct{}
releaseFirst chan struct{}
secondStarted chan struct{}
releaseSecond chan struct{}
thirdStarted chan struct{}
thirdMessages []providers.Message
}
func (p *blockedBtwWithFollowupProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.calls++
call := p.calls
firstStarted := p.firstStarted
releaseFirst := p.releaseFirst
secondStarted := p.secondStarted
releaseSecond := p.releaseSecond
thirdStarted := p.thirdStarted
if call == 1 && p.firstStarted != nil {
close(p.firstStarted)
}
if call == 2 && p.secondStarted != nil {
close(p.secondStarted)
}
if call == 3 {
p.thirdMessages = append([]providers.Message(nil), messages...)
if p.thirdStarted != nil {
close(p.thirdStarted)
}
}
p.mu.Unlock()
switch call {
case 1:
_ = firstStarted
select {
case <-releaseFirst:
case <-ctx.Done():
return nil, ctx.Err()
}
return &providers.LLMResponse{Content: "long turn finished"}, nil
case 2:
_ = secondStarted
select {
case <-releaseSecond:
case <-ctx.Done():
return nil, ctx.Err()
}
return &providers.LLMResponse{Content: "btw delayed reply"}, nil
default:
_ = thirdStarted
return &providers.LLMResponse{Content: "continued after follow-up"}, nil
}
}
func (p *blockedBtwWithFollowupProvider) GetDefaultModel() string {
return "blocked-btw-followup-mock"
}
type interruptibleTool struct {
name string
started chan struct{}
@@ -1010,6 +1091,405 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
}
}
func TestAgentLoop_Steering_BtwCommandBypassesQueuedTurn(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
firstResp: "long turn finished",
finalResp: "btw immediate reply",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute sleep 60, then send OK",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw what is the current progress?",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
messageTool, ok := al.GetRegistry().GetDefaultAgent().Tools.Get("message")
var mt *tools.MessageTool
if !ok {
mt = tools.NewMessageTool()
al.RegisterTool(mt)
} else {
var typeOK bool
mt, typeOK = messageTool.(*tools.MessageTool)
if !typeOK {
t.Fatal("expected message tool type")
}
}
mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
return nil
})
if result := mt.Execute(context.Background(), map[string]any{
"channel": "test",
"chat_id": "chat1",
"content": "already sent from busy turn",
}); result == nil || result.IsError {
t.Fatalf("message tool setup result = %+v, want successful send", result)
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw immediate reply" {
t.Fatalf("expected /btw reply before long turn completion, got %q", outbound.Content)
}
if outbound.AgentID != routing.DefaultAgentID {
t.Fatalf("expected /btw outbound agent_id %q, got %q", routing.DefaultAgentID, outbound.AgentID)
}
route, _, err := al.resolveMessageRoute(btw)
if err != nil {
t.Fatalf("resolveMessageRoute(/btw) error = %v", err)
}
expectedSessionKey := resolveScopeKey(al.allocateRouteSession(route, btw).SessionKey, btw.SessionKey)
if outbound.SessionKey != expectedSessionKey {
t.Fatalf("expected /btw outbound session_key %q, got %q", expectedSessionKey, outbound.SessionKey)
}
if outbound.Scope == nil ||
outbound.Scope.AgentID != routing.DefaultAgentID ||
outbound.Scope.Channel != "test" {
t.Fatalf(
"expected /btw outbound scope for agent %q on test channel, got %+v",
routing.DefaultAgentID,
outbound.Scope,
)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw outbound response")
}
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
t.Fatalf("expected /btw to bypass steering queue, got %v", msgs)
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
t.Fatalf("expected busy turn final response to stay suppressed, got %q", outbound.Content)
case <-time.After(2 * time.Second):
}
provider.mu.Lock()
callCount := provider.calls
provider.mu.Unlock()
if callCount != 2 {
t.Fatalf("provider call count = %d, want 2", callCount)
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_Steering_BtwCommandSurvivesActiveTurnCompletion(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
secondStarted: make(chan struct{}),
releaseSecond: make(chan struct{}),
firstResp: "long turn finished",
finalResp: "btw delayed reply",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute a long turn",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw can you still answer?",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case <-provider.secondStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw LLM call to start")
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "long turn finished" {
t.Fatalf("expected first outbound to be long turn response, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for long turn response")
}
close(provider.releaseSecond)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw delayed reply" {
t.Fatalf("expected /btw response after drain cancellation, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for delayed /btw response")
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_Steering_BlockedBtwDoesNotBlockFollowupContinuation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
ModelName: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &blockedBtwWithFollowupProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
secondStarted: make(chan struct{}),
releaseSecond: make(chan struct{}),
thirdStarted: make(chan struct{}),
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
useTestSideQuestionProvider(al, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "execute a long turn",
}
btw := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "/btw this side question blocks",
}
followup := bus.InboundMessage{
Context: bus.InboundContext{
Channel: "test",
ChatID: "chat1",
ChatType: "direct",
SenderID: "user1",
},
Content: "normal follow-up while btw is blocked",
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
t.Fatalf("publish /btw inbound: %v", err)
}
select {
case <-provider.secondStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for /btw LLM call to start")
}
if err := msgBus.PublishInbound(pubCtx, followup); err != nil {
t.Fatalf("publish follow-up inbound: %v", err)
}
close(provider.releaseFirst)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "continued after follow-up" {
t.Fatalf("expected continuation response before /btw release, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for follow-up continuation response")
}
provider.mu.Lock()
thirdMessages := append([]providers.Message(nil), provider.thirdMessages...)
provider.mu.Unlock()
foundFollowup := false
for _, msg := range thirdMessages {
if msg.Role == "user" && msg.Content == followup.Content {
foundFollowup = true
break
}
}
if !foundFollowup {
t.Fatalf("continuation messages did not include follow-up: %+v", thirdMessages)
}
close(provider.releaseSecond)
select {
case outbound := <-msgBus.OutboundChan():
if outbound.Content != "btw delayed reply" {
t.Fatalf("expected delayed /btw response, got %q", outbound.Content)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for delayed /btw response")
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
}
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {