mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): support btw side questions (#2532)
This commit is contained in:
+488
-8
@@ -405,7 +405,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
al.drainBusToSteering(ctx, activeScope, activeAgentID)
|
||||
al.drainBusToSteering(ctx, ctx, activeScope, activeAgentID)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -566,12 +566,14 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
|
||||
}
|
||||
|
||||
type blockingDirectProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
firstResp string
|
||||
finalResp string
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
secondStarted chan struct{}
|
||||
releaseSecond chan struct{}
|
||||
firstResp string
|
||||
finalResp string
|
||||
}
|
||||
|
||||
func (p *blockingDirectProvider) Chat(
|
||||
@@ -586,11 +588,15 @@ func (p *blockingDirectProvider) Chat(
|
||||
call := p.calls
|
||||
firstStarted := p.firstStarted
|
||||
releaseFirst := p.releaseFirst
|
||||
secondStarted := p.secondStarted
|
||||
releaseSecond := p.releaseSecond
|
||||
firstResp := p.firstResp
|
||||
finalResp := p.finalResp
|
||||
if call == 1 && p.firstStarted != nil {
|
||||
close(p.firstStarted)
|
||||
p.firstStarted = nil
|
||||
}
|
||||
if call == 2 && p.secondStarted != nil {
|
||||
close(p.secondStarted)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
@@ -604,6 +610,14 @@ func (p *blockingDirectProvider) Chat(
|
||||
}
|
||||
|
||||
_ = firstStarted
|
||||
_ = secondStarted
|
||||
if call == 2 && releaseSecond != nil {
|
||||
select {
|
||||
case <-releaseSecond:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &providers.LLMResponse{Content: finalResp}, nil
|
||||
}
|
||||
|
||||
@@ -611,6 +625,73 @@ func (p *blockingDirectProvider) GetDefaultModel() string {
|
||||
return "blocking-direct-mock"
|
||||
}
|
||||
|
||||
type blockedBtwWithFollowupProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
secondStarted chan struct{}
|
||||
releaseSecond chan struct{}
|
||||
thirdStarted chan struct{}
|
||||
thirdMessages []providers.Message
|
||||
}
|
||||
|
||||
func (p *blockedBtwWithFollowupProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
call := p.calls
|
||||
firstStarted := p.firstStarted
|
||||
releaseFirst := p.releaseFirst
|
||||
secondStarted := p.secondStarted
|
||||
releaseSecond := p.releaseSecond
|
||||
thirdStarted := p.thirdStarted
|
||||
if call == 1 && p.firstStarted != nil {
|
||||
close(p.firstStarted)
|
||||
}
|
||||
if call == 2 && p.secondStarted != nil {
|
||||
close(p.secondStarted)
|
||||
}
|
||||
if call == 3 {
|
||||
p.thirdMessages = append([]providers.Message(nil), messages...)
|
||||
if p.thirdStarted != nil {
|
||||
close(p.thirdStarted)
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
switch call {
|
||||
case 1:
|
||||
_ = firstStarted
|
||||
select {
|
||||
case <-releaseFirst:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return &providers.LLMResponse{Content: "long turn finished"}, nil
|
||||
case 2:
|
||||
_ = secondStarted
|
||||
select {
|
||||
case <-releaseSecond:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return &providers.LLMResponse{Content: "btw delayed reply"}, nil
|
||||
default:
|
||||
_ = thirdStarted
|
||||
return &providers.LLMResponse{Content: "continued after follow-up"}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *blockedBtwWithFollowupProvider) GetDefaultModel() string {
|
||||
return "blocked-btw-followup-mock"
|
||||
}
|
||||
|
||||
type interruptibleTool struct {
|
||||
name string
|
||||
started chan struct{}
|
||||
@@ -1010,6 +1091,405 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_BtwCommandBypassesQueuedTurn(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
firstResp: "long turn finished",
|
||||
finalResp: "btw immediate reply",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "execute sleep 60, then send OK",
|
||||
}
|
||||
btw := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "/btw what is the current progress?",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
messageTool, ok := al.GetRegistry().GetDefaultAgent().Tools.Get("message")
|
||||
var mt *tools.MessageTool
|
||||
if !ok {
|
||||
mt = tools.NewMessageTool()
|
||||
al.RegisterTool(mt)
|
||||
} else {
|
||||
var typeOK bool
|
||||
mt, typeOK = messageTool.(*tools.MessageTool)
|
||||
if !typeOK {
|
||||
t.Fatal("expected message tool type")
|
||||
}
|
||||
}
|
||||
mt.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error {
|
||||
return nil
|
||||
})
|
||||
if result := mt.Execute(context.Background(), map[string]any{
|
||||
"channel": "test",
|
||||
"chat_id": "chat1",
|
||||
"content": "already sent from busy turn",
|
||||
}); result == nil || result.IsError {
|
||||
t.Fatalf("message tool setup result = %+v, want successful send", result)
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
|
||||
t.Fatalf("publish /btw inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "btw immediate reply" {
|
||||
t.Fatalf("expected /btw reply before long turn completion, got %q", outbound.Content)
|
||||
}
|
||||
if outbound.AgentID != routing.DefaultAgentID {
|
||||
t.Fatalf("expected /btw outbound agent_id %q, got %q", routing.DefaultAgentID, outbound.AgentID)
|
||||
}
|
||||
route, _, err := al.resolveMessageRoute(btw)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute(/btw) error = %v", err)
|
||||
}
|
||||
expectedSessionKey := resolveScopeKey(al.allocateRouteSession(route, btw).SessionKey, btw.SessionKey)
|
||||
if outbound.SessionKey != expectedSessionKey {
|
||||
t.Fatalf("expected /btw outbound session_key %q, got %q", expectedSessionKey, outbound.SessionKey)
|
||||
}
|
||||
if outbound.Scope == nil ||
|
||||
outbound.Scope.AgentID != routing.DefaultAgentID ||
|
||||
outbound.Scope.Channel != "test" {
|
||||
t.Fatalf(
|
||||
"expected /btw outbound scope for agent %q on test channel, got %+v",
|
||||
routing.DefaultAgentID,
|
||||
outbound.Scope,
|
||||
)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /btw outbound response")
|
||||
}
|
||||
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
|
||||
t.Fatalf("expected /btw to bypass steering queue, got %v", msgs)
|
||||
}
|
||||
|
||||
close(provider.releaseFirst)
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
t.Fatalf("expected busy turn final response to stay suppressed, got %q", outbound.Content)
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
callCount := provider.calls
|
||||
provider.mu.Unlock()
|
||||
if callCount != 2 {
|
||||
t.Fatalf("provider call count = %d, want 2", callCount)
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_BtwCommandSurvivesActiveTurnCompletion(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
secondStarted: make(chan struct{}),
|
||||
releaseSecond: make(chan struct{}),
|
||||
firstResp: "long turn finished",
|
||||
finalResp: "btw delayed reply",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "execute a long turn",
|
||||
}
|
||||
btw := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "/btw can you still answer?",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
|
||||
t.Fatalf("publish /btw inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.secondStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /btw LLM call to start")
|
||||
}
|
||||
|
||||
close(provider.releaseFirst)
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "long turn finished" {
|
||||
t.Fatalf("expected first outbound to be long turn response, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for long turn response")
|
||||
}
|
||||
|
||||
close(provider.releaseSecond)
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "btw delayed reply" {
|
||||
t.Fatalf("expected /btw response after drain cancellation, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for delayed /btw response")
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_BlockedBtwDoesNotBlockFollowupContinuation(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := &blockedBtwWithFollowupProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
secondStarted: make(chan struct{}),
|
||||
releaseSecond: make(chan struct{}),
|
||||
thirdStarted: make(chan struct{}),
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
useTestSideQuestionProvider(al, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "execute a long turn",
|
||||
}
|
||||
btw := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "/btw this side question blocks",
|
||||
}
|
||||
followup := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "normal follow-up while btw is blocked",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first LLM call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, btw); err != nil {
|
||||
t.Fatalf("publish /btw inbound: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-provider.secondStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for /btw LLM call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, followup); err != nil {
|
||||
t.Fatalf("publish follow-up inbound: %v", err)
|
||||
}
|
||||
close(provider.releaseFirst)
|
||||
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "continued after follow-up" {
|
||||
t.Fatalf("expected continuation response before /btw release, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for follow-up continuation response")
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
thirdMessages := append([]providers.Message(nil), provider.thirdMessages...)
|
||||
provider.mu.Unlock()
|
||||
foundFollowup := false
|
||||
for _, msg := range thirdMessages {
|
||||
if msg.Role == "user" && msg.Content == followup.Content {
|
||||
foundFollowup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundFollowup {
|
||||
t.Fatalf("continuation messages did not include follow-up: %+v", thirdMessages)
|
||||
}
|
||||
|
||||
close(provider.releaseSecond)
|
||||
select {
|
||||
case outbound := <-msgBus.OutboundChan():
|
||||
if outbound.Content != "btw delayed reply" {
|
||||
t.Fatalf("expected delayed /btw response, got %q", outbound.Content)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for delayed /btw response")
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user