mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): preserve prompt hook and cache semantics
This commit is contained in:
+27
-1
@@ -409,6 +409,7 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing
|
|||||||
if msg.Role != "system" {
|
if msg.Role != "system" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
msg = providerVisibleMessage(msg)
|
||||||
fingerprints = append(fingerprints, systemMessageFingerprint{
|
fingerprints = append(fingerprints, systemMessageFingerprint{
|
||||||
Index: i,
|
Index: i,
|
||||||
Message: cloneProviderMessages([]providers.Message{msg})[0],
|
Message: cloneProviderMessages([]providers.Message{msg})[0],
|
||||||
@@ -418,7 +419,32 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing
|
|||||||
}
|
}
|
||||||
|
|
||||||
func llmHookToolDefinitionsUnchanged(before, after []providers.ToolDefinition) bool {
|
func llmHookToolDefinitionsUnchanged(before, after []providers.ToolDefinition) bool {
|
||||||
return reflect.DeepEqual(cloneToolDefinitions(before), cloneToolDefinitions(after))
|
return reflect.DeepEqual(providerVisibleToolDefinitions(before), providerVisibleToolDefinitions(after))
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerVisibleMessage(msg providers.Message) providers.Message {
|
||||||
|
msg.PromptLayer = ""
|
||||||
|
msg.PromptSlot = ""
|
||||||
|
msg.PromptSource = ""
|
||||||
|
if len(msg.SystemParts) > 0 {
|
||||||
|
msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
||||||
|
for i := range msg.SystemParts {
|
||||||
|
msg.SystemParts[i].PromptLayer = ""
|
||||||
|
msg.SystemParts[i].PromptSlot = ""
|
||||||
|
msg.SystemParts[i].PromptSource = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerVisibleToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
|
||||||
|
cloned := cloneToolDefinitions(defs)
|
||||||
|
for i := range cloned {
|
||||||
|
cloned[i].PromptLayer = ""
|
||||||
|
cloned[i].PromptSlot = ""
|
||||||
|
cloned[i].PromptSource = ""
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hm *HookManager) BeforeTool(
|
func (hm *HookManager) BeforeTool(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package agent
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -186,6 +187,46 @@ func (h *llmUserAppendHook) AfterLLM(
|
|||||||
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type llmJSONRoundTripUserAppendHook struct{}
|
||||||
|
|
||||||
|
type jsonRoundTripLLMHookRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []providers.Message `json:"messages,omitempty"`
|
||||||
|
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *llmJSONRoundTripUserAppendHook) BeforeLLM(
|
||||||
|
ctx context.Context,
|
||||||
|
req *LLMHookRequest,
|
||||||
|
) (*LLMHookRequest, HookDecision, error) {
|
||||||
|
payload := jsonRoundTripLLMHookRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Messages: req.Messages,
|
||||||
|
Tools: req.Tools,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, HookDecision{}, err
|
||||||
|
}
|
||||||
|
var decoded jsonRoundTripLLMHookRequest
|
||||||
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||||
|
return nil, HookDecision{}, err
|
||||||
|
}
|
||||||
|
next := req.Clone()
|
||||||
|
next.Model = decoded.Model
|
||||||
|
next.Messages = decoded.Messages
|
||||||
|
next.Tools = decoded.Tools
|
||||||
|
next.Messages = append(next.Messages, providers.Message{Role: "user", Content: "json extra user context"})
|
||||||
|
return next, HookDecision{Action: HookActionModify}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *llmJSONRoundTripUserAppendHook) AfterLLM(
|
||||||
|
ctx context.Context,
|
||||||
|
resp *LLMHookResponse,
|
||||||
|
) (*LLMHookResponse, HookDecision, error) {
|
||||||
|
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
||||||
|
}
|
||||||
|
|
||||||
type llmToolRewriteHook struct{}
|
type llmToolRewriteHook struct{}
|
||||||
|
|
||||||
func (h *llmToolRewriteHook) BeforeLLM(
|
func (h *llmToolRewriteHook) BeforeLLM(
|
||||||
@@ -274,6 +315,58 @@ func TestHookManager_BeforeLLMAllowsNonSystemMessageMutation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHookManager_BeforeLLMAllowsJSONRoundTripNonSystemMessageMutation(t *testing.T) {
|
||||||
|
hm := NewHookManager(nil)
|
||||||
|
if err := hm.Mount(NamedHook("json-append-user", &llmJSONRoundTripUserAppendHook{})); err != nil {
|
||||||
|
t.Fatalf("Mount() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &LLMHookRequest{
|
||||||
|
Model: "model",
|
||||||
|
Messages: []providers.Message{
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "system",
|
||||||
|
PromptLayer: string(PromptLayerKernel),
|
||||||
|
PromptSlot: string(PromptSlotIdentity),
|
||||||
|
PromptSource: string(PromptSourceKernel),
|
||||||
|
SystemParts: []providers.ContentBlock{
|
||||||
|
{
|
||||||
|
Type: "text",
|
||||||
|
Text: "system",
|
||||||
|
CacheControl: &providers.CacheControl{Type: "ephemeral"},
|
||||||
|
PromptLayer: string(PromptLayerKernel),
|
||||||
|
PromptSlot: string(PromptSlotIdentity),
|
||||||
|
PromptSource: string(PromptSourceKernel),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "user", Content: "hello"},
|
||||||
|
},
|
||||||
|
Tools: []providers.ToolDefinition{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: providers.ToolFunctionDefinition{
|
||||||
|
Name: "mcp_github_create_issue",
|
||||||
|
Description: "create issue",
|
||||||
|
Parameters: map[string]any{"type": "object"},
|
||||||
|
},
|
||||||
|
PromptLayer: string(PromptLayerCapability),
|
||||||
|
PromptSlot: string(PromptSlotMCP),
|
||||||
|
PromptSource: "mcp:github",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _ := hm.BeforeLLM(context.Background(), req)
|
||||||
|
if len(got.Messages) != 3 {
|
||||||
|
t.Fatalf("messages len = %d, want 3", len(got.Messages))
|
||||||
|
}
|
||||||
|
if got.Messages[2].Role != "user" || got.Messages[2].Content != "json extra user context" {
|
||||||
|
t.Fatalf("appended message = %#v, want json extra user context", got.Messages[2])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) {
|
func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) {
|
||||||
hm := NewHookManager(nil)
|
hm := NewHookManager(nil)
|
||||||
if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil {
|
if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil {
|
||||||
|
|||||||
@@ -135,8 +135,14 @@ func TestBuildMessagesFromPrompt_AttachesInternalPromptMetadata(t *testing.T) {
|
|||||||
switch part.PromptSource {
|
switch part.PromptSource {
|
||||||
case string(PromptSourceRuntime):
|
case string(PromptSourceRuntime):
|
||||||
hasRuntime = true
|
hasRuntime = true
|
||||||
|
if part.CacheControl != nil {
|
||||||
|
t.Fatalf("runtime cache control = %#v, want nil", part.CacheControl)
|
||||||
|
}
|
||||||
case string(PromptSourceSummary):
|
case string(PromptSourceSummary):
|
||||||
hasSummary = true
|
hasSummary = true
|
||||||
|
if part.CacheControl != nil {
|
||||||
|
t.Fatalf("summary cache control = %#v, want nil", part.CacheControl)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !hasRuntime {
|
if !hasRuntime {
|
||||||
@@ -181,6 +187,9 @@ func TestContextBuilder_CollectsToolDiscoveryContributor(t *testing.T) {
|
|||||||
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotTooling) {
|
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotTooling) {
|
||||||
t.Fatalf("tool discovery metadata = %#v, want capability/tooling", part)
|
t.Fatalf("tool discovery metadata = %#v, want capability/tooling", part)
|
||||||
}
|
}
|
||||||
|
if part.CacheControl == nil || part.CacheControl.Type != "ephemeral" {
|
||||||
|
t.Fatalf("tool discovery cache control = %#v, want ephemeral", part.CacheControl)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
@@ -213,6 +222,9 @@ func TestContextBuilder_CollectsMCPServerContributor(t *testing.T) {
|
|||||||
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotMCP) {
|
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotMCP) {
|
||||||
t.Fatalf("mcp metadata = %#v, want capability/mcp", part)
|
t.Fatalf("mcp metadata = %#v, want capability/mcp", part)
|
||||||
}
|
}
|
||||||
|
if part.CacheControl == nil || part.CacheControl.Type != "ephemeral" {
|
||||||
|
t.Fatalf("mcp cache control = %#v, want ephemeral", part.CacheControl)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
|
|||||||
@@ -49,6 +49,9 @@ func promptOverlaysForOptions(opts processOptions) []PromptPart {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func promptContentBlock(part PromptPart, cache *providers.CacheControl) providers.ContentBlock {
|
func promptContentBlock(part PromptPart, cache *providers.CacheControl) providers.ContentBlock {
|
||||||
|
if cache == nil {
|
||||||
|
cache = cacheControlForPromptPart(part)
|
||||||
|
}
|
||||||
return providers.ContentBlock{
|
return providers.ContentBlock{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: part.Content,
|
Text: part.Content,
|
||||||
@@ -59,6 +62,15 @@ func promptContentBlock(part PromptPart, cache *providers.CacheControl) provider
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cacheControlForPromptPart(part PromptPart) *providers.CacheControl {
|
||||||
|
switch part.Cache {
|
||||||
|
case PromptCacheEphemeral:
|
||||||
|
return &providers.CacheControl{Type: "ephemeral"}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func promptMessageWithMetadata(
|
func promptMessageWithMetadata(
|
||||||
msg providers.Message,
|
msg providers.Message,
|
||||||
layer PromptLayer,
|
layer PromptLayer,
|
||||||
|
|||||||
Reference in New Issue
Block a user