diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go index 69336b89f..2d60e8120 100644 --- a/pkg/agent/hooks.go +++ b/pkg/agent/hooks.go @@ -409,6 +409,7 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing if msg.Role != "system" { continue } + msg = providerVisibleMessage(msg) fingerprints = append(fingerprints, systemMessageFingerprint{ Index: i, Message: cloneProviderMessages([]providers.Message{msg})[0], @@ -418,7 +419,32 @@ func systemMessageFingerprints(messages []providers.Message) []systemMessageFing } func llmHookToolDefinitionsUnchanged(before, after []providers.ToolDefinition) bool { - return reflect.DeepEqual(cloneToolDefinitions(before), cloneToolDefinitions(after)) + return reflect.DeepEqual(providerVisibleToolDefinitions(before), providerVisibleToolDefinitions(after)) +} + +func providerVisibleMessage(msg providers.Message) providers.Message { + msg.PromptLayer = "" + msg.PromptSlot = "" + msg.PromptSource = "" + if len(msg.SystemParts) > 0 { + msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...) + for i := range msg.SystemParts { + msg.SystemParts[i].PromptLayer = "" + msg.SystemParts[i].PromptSlot = "" + msg.SystemParts[i].PromptSource = "" + } + } + return msg +} + +func providerVisibleToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition { + cloned := cloneToolDefinitions(defs) + for i := range cloned { + cloned[i].PromptLayer = "" + cloned[i].PromptSlot = "" + cloned[i].PromptSource = "" + } + return cloned } func (hm *HookManager) BeforeTool( diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go index a936296dd..65400e4b8 100644 --- a/pkg/agent/hooks_test.go +++ b/pkg/agent/hooks_test.go @@ -2,6 +2,7 @@ package agent import ( "context" + "encoding/json" "errors" "os" "strings" @@ -186,6 +187,46 @@ func (h *llmUserAppendHook) AfterLLM( return resp.Clone(), HookDecision{Action: HookActionContinue}, nil } +type llmJSONRoundTripUserAppendHook struct{} + +type jsonRoundTripLLMHookRequest struct { + Model string `json:"model"` + Messages []providers.Message `json:"messages,omitempty"` + Tools []providers.ToolDefinition `json:"tools,omitempty"` +} + +func (h *llmJSONRoundTripUserAppendHook) BeforeLLM( + ctx context.Context, + req *LLMHookRequest, +) (*LLMHookRequest, HookDecision, error) { + payload := jsonRoundTripLLMHookRequest{ + Model: req.Model, + Messages: req.Messages, + Tools: req.Tools, + } + data, err := json.Marshal(payload) + if err != nil { + return nil, HookDecision{}, err + } + var decoded jsonRoundTripLLMHookRequest + if err := json.Unmarshal(data, &decoded); err != nil { + return nil, HookDecision{}, err + } + next := req.Clone() + next.Model = decoded.Model + next.Messages = decoded.Messages + next.Tools = decoded.Tools + next.Messages = append(next.Messages, providers.Message{Role: "user", Content: "json extra user context"}) + return next, HookDecision{Action: HookActionModify}, nil +} + +func (h *llmJSONRoundTripUserAppendHook) AfterLLM( + ctx context.Context, + resp *LLMHookResponse, +) (*LLMHookResponse, HookDecision, error) { + return resp.Clone(), HookDecision{Action: HookActionContinue}, nil +} + type llmToolRewriteHook struct{} func (h *llmToolRewriteHook) BeforeLLM( @@ -274,6 +315,58 @@ func TestHookManager_BeforeLLMAllowsNonSystemMessageMutation(t *testing.T) { } } +func TestHookManager_BeforeLLMAllowsJSONRoundTripNonSystemMessageMutation(t *testing.T) { + hm := NewHookManager(nil) + if err := hm.Mount(NamedHook("json-append-user", &llmJSONRoundTripUserAppendHook{})); err != nil { + t.Fatalf("Mount() error = %v", err) + } + + req := &LLMHookRequest{ + Model: "model", + Messages: []providers.Message{ + { + Role: "system", + Content: "system", + PromptLayer: string(PromptLayerKernel), + PromptSlot: string(PromptSlotIdentity), + PromptSource: string(PromptSourceKernel), + SystemParts: []providers.ContentBlock{ + { + Type: "text", + Text: "system", + CacheControl: &providers.CacheControl{Type: "ephemeral"}, + PromptLayer: string(PromptLayerKernel), + PromptSlot: string(PromptSlotIdentity), + PromptSource: string(PromptSourceKernel), + }, + }, + }, + {Role: "user", Content: "hello"}, + }, + Tools: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "mcp_github_create_issue", + Description: "create issue", + Parameters: map[string]any{"type": "object"}, + }, + PromptLayer: string(PromptLayerCapability), + PromptSlot: string(PromptSlotMCP), + PromptSource: "mcp:github", + }, + }, + } + + got, _ := hm.BeforeLLM(context.Background(), req) + if len(got.Messages) != 3 { + t.Fatalf("messages len = %d, want 3", len(got.Messages)) + } + if got.Messages[2].Role != "user" || got.Messages[2].Content != "json extra user context" { + t.Fatalf("appended message = %#v, want json extra user context", got.Messages[2]) + } +} + func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) { hm := NewHookManager(nil) if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil { diff --git a/pkg/agent/prompt_test.go b/pkg/agent/prompt_test.go index b3f609610..b76b0040d 100644 --- a/pkg/agent/prompt_test.go +++ b/pkg/agent/prompt_test.go @@ -135,8 +135,14 @@ func TestBuildMessagesFromPrompt_AttachesInternalPromptMetadata(t *testing.T) { switch part.PromptSource { case string(PromptSourceRuntime): hasRuntime = true + if part.CacheControl != nil { + t.Fatalf("runtime cache control = %#v, want nil", part.CacheControl) + } case string(PromptSourceSummary): hasSummary = true + if part.CacheControl != nil { + t.Fatalf("summary cache control = %#v, want nil", part.CacheControl) + } } } if !hasRuntime { @@ -181,6 +187,9 @@ func TestContextBuilder_CollectsToolDiscoveryContributor(t *testing.T) { if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotTooling) { t.Fatalf("tool discovery metadata = %#v, want capability/tooling", part) } + if part.CacheControl == nil || part.CacheControl.Type != "ephemeral" { + t.Fatalf("tool discovery cache control = %#v, want ephemeral", part.CacheControl) + } } } if !found { @@ -213,6 +222,9 @@ func TestContextBuilder_CollectsMCPServerContributor(t *testing.T) { if part.PromptLayer != string(PromptLayerCapability) || part.PromptSlot != string(PromptSlotMCP) { t.Fatalf("mcp metadata = %#v, want capability/mcp", part) } + if part.CacheControl == nil || part.CacheControl.Type != "ephemeral" { + t.Fatalf("mcp cache control = %#v, want ephemeral", part.CacheControl) + } } } if !found { diff --git a/pkg/agent/prompt_turn.go b/pkg/agent/prompt_turn.go index 7b7d295c1..588a8f00f 100644 --- a/pkg/agent/prompt_turn.go +++ b/pkg/agent/prompt_turn.go @@ -49,6 +49,9 @@ func promptOverlaysForOptions(opts processOptions) []PromptPart { } func promptContentBlock(part PromptPart, cache *providers.CacheControl) providers.ContentBlock { + if cache == nil { + cache = cacheControlForPromptPart(part) + } return providers.ContentBlock{ Type: "text", Text: part.Content, @@ -59,6 +62,15 @@ func promptContentBlock(part PromptPart, cache *providers.CacheControl) provider } } +func cacheControlForPromptPart(part PromptPart) *providers.CacheControl { + switch part.Cache { + case PromptCacheEphemeral: + return &providers.CacheControl{Type: "ephemeral"} + default: + return nil + } +} + func promptMessageWithMetadata( msg providers.Message, layer PromptLayer,