fix(agent): preserve prompt hook and cache semantics

This commit is contained in:
Hoshina
2026-04-25 01:25:17 +08:00
parent 48d8952591
commit 9ca73b944f
4 changed files with 144 additions and 1 deletions
+27 -1
View File
@@ -409,6 +409,7 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing
if msg.Role != "system" {
continue
}
msg = providerVisibleMessage(msg)
fingerprints = append(fingerprints, systemMessageFingerprint{
Index: i,
Message: cloneProviderMessages([]providers.Message{msg})[0],
@@ -418,7 +419,32 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing
}
func llmHookToolDefinitionsUnchanged(before, after []providers.ToolDefinition) bool {
return reflect.DeepEqual(cloneToolDefinitions(before), cloneToolDefinitions(after))
return reflect.DeepEqual(providerVisibleToolDefinitions(before), providerVisibleToolDefinitions(after))
}
func providerVisibleMessage(msg providers.Message) providers.Message {
msg.PromptLayer = ""
msg.PromptSlot = ""
msg.PromptSource = ""
if len(msg.SystemParts) > 0 {
msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
for i := range msg.SystemParts {
msg.SystemParts[i].PromptLayer = ""
msg.SystemParts[i].PromptSlot = ""
msg.SystemParts[i].PromptSource = ""
}
}
return msg
}
func providerVisibleToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
cloned := cloneToolDefinitions(defs)
for i := range cloned {
cloned[i].PromptLayer = ""
cloned[i].PromptSlot = ""
cloned[i].PromptSource = ""
}
return cloned
}
func (hm *HookManager) BeforeTool(
+93
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"os"
"strings"
@@ -186,6 +187,46 @@ func (h *llmUserAppendHook) AfterLLM(
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
}
type llmJSONRoundTripUserAppendHook struct{}
type jsonRoundTripLLMHookRequest struct {
Model string `json:"model"`
Messages []providers.Message `json:"messages,omitempty"`
Tools []providers.ToolDefinition `json:"tools,omitempty"`
}
func (h *llmJSONRoundTripUserAppendHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
payload := jsonRoundTripLLMHookRequest{
Model: req.Model,
Messages: req.Messages,
Tools: req.Tools,
}
data, err := json.Marshal(payload)
if err != nil {
return nil, HookDecision{}, err
}
var decoded jsonRoundTripLLMHookRequest
if err := json.Unmarshal(data, &decoded); err != nil {
return nil, HookDecision{}, err
}
next := req.Clone()
next.Model = decoded.Model
next.Messages = decoded.Messages
next.Tools = decoded.Tools
next.Messages = append(next.Messages, providers.Message{Role: "user", Content: "json extra user context"})
return next, HookDecision{Action: HookActionModify}, nil
}
func (h *llmJSONRoundTripUserAppendHook) AfterLLM(
ctx context.Context,
resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
}
type llmToolRewriteHook struct{}
func (h *llmToolRewriteHook) BeforeLLM(
@@ -274,6 +315,58 @@ func TestHookManager_BeforeLLMAllowsNonSystemMessageMutation(t *testing.T) {
}
}
func TestHookManager_BeforeLLMAllowsJSONRoundTripNonSystemMessageMutation(t *testing.T) {
hm := NewHookManager(nil)
if err := hm.Mount(NamedHook("json-append-user", &llmJSONRoundTripUserAppendHook{})); err != nil {
t.Fatalf("Mount() error = %v", err)
}
req := &LLMHookRequest{
Model: "model",
Messages: []providers.Message{
{
Role: "system",
Content: "system",
PromptLayer: string(PromptLayerKernel),
PromptSlot: string(PromptSlotIdentity),
PromptSource: string(PromptSourceKernel),
SystemParts: []providers.ContentBlock{
{
Type: "text",
Text: "system",
CacheControl: &providers.CacheControl{Type: "ephemeral"},
PromptLayer: string(PromptLayerKernel),
PromptSlot: string(PromptSlotIdentity),
PromptSource: string(PromptSourceKernel),
},
},
},
{Role: "user", Content: "hello"},
},
Tools: []providers.ToolDefinition{
{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: "mcp_github_create_issue",
Description: "create issue",
Parameters: map[string]any{"type": "object"},
},
PromptLayer: string(PromptLayerCapability),
PromptSlot: string(PromptSlotMCP),
PromptSource: "mcp:github",
},
},
}
got, _ := hm.BeforeLLM(context.Background(), req)
if len(got.Messages) != 3 {
t.Fatalf("messages len = %d, want 3", len(got.Messages))
}
if got.Messages[2].Role != "user" || got.Messages[2].Content != "json extra user context" {
t.Fatalf("appended message = %#v, want json extra user context", got.Messages[2])
}
}
func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) {
hm := NewHookManager(nil)
if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil {
+12
View File
@@ -135,8 +135,14 @@ func TestBuildMessagesFromPrompt_AttachesInternalPromptMetadata(t *testing.T) {
switch part.PromptSource {
case string(PromptSourceRuntime):
hasRuntime = true
if part.CacheControl != nil {
t.Fatalf("runtime cache control = %#v, want nil", part.CacheControl)
}
case string(PromptSourceSummary):
hasSummary = true
if part.CacheControl != nil {
t.Fatalf("summary cache control = %#v, want nil", part.CacheControl)
}
}
}
if !hasRuntime {
@@ -181,6 +187,9 @@ func TestContextBuilder_CollectsToolDiscoveryContributor(t *testing.T) {
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotTooling) {
t.Fatalf("tool discovery metadata = %#v, want capability/tooling", part)
}
if part.CacheControl == nil || part.CacheControl.Type != "ephemeral" {
t.Fatalf("tool discovery cache control = %#v, want ephemeral", part.CacheControl)
}
}
}
if !found {
@@ -213,6 +222,9 @@ func TestContextBuilder_CollectsMCPServerContributor(t *testing.T) {
if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotMCP) {
t.Fatalf("mcp metadata = %#v, want capability/mcp", part)
}
if part.CacheControl == nil || part.CacheControl.Type != "ephemeral" {
t.Fatalf("mcp cache control = %#v, want ephemeral", part.CacheControl)
}
}
}
if !found {
+12
View File
@@ -49,6 +49,9 @@ func promptOverlaysForOptions(opts processOptions) []PromptPart {
}
func promptContentBlock(part PromptPart, cache *providers.CacheControl) providers.ContentBlock {
if cache == nil {
cache = cacheControlForPromptPart(part)
}
return providers.ContentBlock{
Type: "text",
Text: part.Content,
@@ -59,6 +62,15 @@ func promptContentBlock(part PromptPart, cache *providers.CacheControl) provider
}
}
func cacheControlForPromptPart(part PromptPart) *providers.CacheControl {
switch part.Cache {
case PromptCacheEphemeral:
return &providers.CacheControl{Type: "ephemeral"}
default:
return nil
}
}
func promptMessageWithMetadata(
msg providers.Message,
layer PromptLayer,