fix(agent): preserve prompt hook and cache semantics

This commit is contained in:
Hoshina
2026-04-25 01:25:17 +08:00
parent 48d8952591
commit 9ca73b944f
4 changed files with 144 additions and 1 deletions
+93
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"os"
"strings"
@@ -186,6 +187,46 @@ func (h *llmUserAppendHook) AfterLLM(
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
}
type llmJSONRoundTripUserAppendHook struct{}
type jsonRoundTripLLMHookRequest struct {
Model string `json:"model"`
Messages []providers.Message `json:"messages,omitempty"`
Tools []providers.ToolDefinition `json:"tools,omitempty"`
}
func (h *llmJSONRoundTripUserAppendHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
payload := jsonRoundTripLLMHookRequest{
Model: req.Model,
Messages: req.Messages,
Tools: req.Tools,
}
data, err := json.Marshal(payload)
if err != nil {
return nil, HookDecision{}, err
}
var decoded jsonRoundTripLLMHookRequest
if err := json.Unmarshal(data, &decoded); err != nil {
return nil, HookDecision{}, err
}
next := req.Clone()
next.Model = decoded.Model
next.Messages = decoded.Messages
next.Tools = decoded.Tools
next.Messages = append(next.Messages, providers.Message{Role: "user", Content: "json extra user context"})
return next, HookDecision{Action: HookActionModify}, nil
}
func (h *llmJSONRoundTripUserAppendHook) AfterLLM(
ctx context.Context,
resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
}
type llmToolRewriteHook struct{}
func (h *llmToolRewriteHook) BeforeLLM(
@@ -274,6 +315,58 @@ func TestHookManager_BeforeLLMAllowsNonSystemMessageMutation(t *testing.T) {
}
}
func TestHookManager_BeforeLLMAllowsJSONRoundTripNonSystemMessageMutation(t *testing.T) {
hm := NewHookManager(nil)
if err := hm.Mount(NamedHook("json-append-user", &llmJSONRoundTripUserAppendHook{})); err != nil {
t.Fatalf("Mount() error = %v", err)
}
req := &LLMHookRequest{
Model: "model",
Messages: []providers.Message{
{
Role: "system",
Content: "system",
PromptLayer: string(PromptLayerKernel),
PromptSlot: string(PromptSlotIdentity),
PromptSource: string(PromptSourceKernel),
SystemParts: []providers.ContentBlock{
{
Type: "text",
Text: "system",
CacheControl: &providers.CacheControl{Type: "ephemeral"},
PromptLayer: string(PromptLayerKernel),
PromptSlot: string(PromptSlotIdentity),
PromptSource: string(PromptSourceKernel),
},
},
},
{Role: "user", Content: "hello"},
},
Tools: []providers.ToolDefinition{
{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: "mcp_github_create_issue",
Description: "create issue",
Parameters: map[string]any{"type": "object"},
},
PromptLayer: string(PromptLayerCapability),
PromptSlot: string(PromptSlotMCP),
PromptSource: "mcp:github",
},
},
}
got, _ := hm.BeforeLLM(context.Background(), req)
if len(got.Messages) != 3 {
t.Fatalf("messages len = %d, want 3", len(got.Messages))
}
if got.Messages[2].Role != "user" || got.Messages[2].Content != "json extra user context" {
t.Fatalf("appended message = %#v, want json extra user context", got.Messages[2])
}
}
func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) {
hm := NewHookManager(nil)
if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil {