diff --git a/pkg/agent/agent_utils.go b/pkg/agent/agent_utils.go index bbfb3f2ae..2651fb2db 100644 --- a/pkg/agent/agent_utils.go +++ b/pkg/agent/agent_utils.go @@ -4,7 +4,6 @@ package agent import ( "context" - "encoding/json" "fmt" "maps" "path/filepath" @@ -171,15 +170,8 @@ func toolFeedbackExplanationFromMessages(messages []providers.Message) string { } func toolFeedbackArgsPreview(args map[string]any, maxLen int) string { - if args == nil { - args = map[string]any{} - } - - argsJSON, err := json.MarshalIndent(args, "", " ") - if err != nil { - return utils.Truncate(fmt.Sprintf("%v", args), maxLen) - } - return utils.Truncate(string(argsJSON), maxLen) + argsJSON := utils.FormatArgsJSON(args, true, false) + return utils.Truncate(argsJSON, maxLen) } func shouldPublishToolFeedback(cfg *config.Config, ts *turnState) bool { diff --git a/pkg/utils/tool_feedback.go b/pkg/utils/tool_feedback.go index de7cb467e..1834d7f78 100644 --- a/pkg/utils/tool_feedback.go +++ b/pkg/utils/tool_feedback.go @@ -1,12 +1,35 @@ package utils import ( + "bytes" + "encoding/json" "fmt" "strings" ) const ToolFeedbackContinuationHint = "Continuing the current task." +func FormatArgsJSON(args map[string]any, prettyPrint, disableEscapeHTML bool) string { + // Normalize nil to empty map for consistent output + if args == nil { + args = map[string]any{} + } + + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + if prettyPrint { + enc.SetIndent("", " ") + } + if disableEscapeHTML { + enc.SetEscapeHTML(false) + } + if err := enc.Encode(args); err != nil { + // Fallback to fmt.Sprintf to preserve visibility of problematic args + return fmt.Sprintf("%v", args) + } + return strings.TrimSpace(buf.String()) +} + // FormatToolFeedbackMessage renders a tool feedback message for chat channels. // It keeps the tool name on the first line for animation and can include both // a human explanation and the serialized tool arguments in the body. diff --git a/pkg/utils/tool_feedback_test.go b/pkg/utils/tool_feedback_test.go index c30f53827..da4accce4 100644 --- a/pkg/utils/tool_feedback_test.go +++ b/pkg/utils/tool_feedback_test.go @@ -1,6 +1,9 @@ package utils -import "testing" +import ( + "encoding/json" + "testing" +) func TestFormatToolFeedbackMessage(t *testing.T) { got := FormatToolFeedbackMessage( @@ -56,3 +59,98 @@ func TestFitToolFeedbackMessage_TruncatesSingleLineMessage(t *testing.T) { t.Fatalf("FitToolFeedbackMessage() = %q, want %q", got, want) } } + +func TestFormatArgsJSON_Defaults(t *testing.T) { + args := map[string]any{"path": "README.md", "line": 42} + got := FormatArgsJSON(args, false, false) + var gotVal, wantVal any + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + want := `{"path":"README.md","line":42}` + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_PrettyPrint(t *testing.T) { + args := map[string]any{"path": "README.md", "line": 42} + got := FormatArgsJSON(args, true, false) + var gotVal any + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + want := `{"path":"README.md","line":42}` + var wantVal any + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() prettyPrint = %q, want structure %q", got, want) + } +} + +func TestFormatArgsJSON_DisableEscapeHTML(t *testing.T) { + args := map[string]any{"msg": "a < b && c > d"} + got := FormatArgsJSON(args, false, true) + var gotVal, wantVal any + want := `{"msg":"a < b && c > d"}` + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() disableEscapeHTML = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_PrettyPrintAndDisableEscapeHTML(t *testing.T) { + args := map[string]any{"msg": "a < b && c > d"} + got := FormatArgsJSON(args, true, true) + var gotVal, wantVal any + want := `{"msg":"a < b && c > d"}` + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() combined = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_EscapeHTMLByDefault(t *testing.T) { + args := map[string]any{"msg": "a < b && c > d"} + got := FormatArgsJSON(args, false, false) + var gotVal, wantVal any + want := `{"msg":"a \u003c b \u0026\u0026 c \u003e d"}` + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() default escape = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_NilArgs(t *testing.T) { + got := FormatArgsJSON(nil, false, false) + want := `{}` + if got != want { + t.Fatalf("FormatArgsJSON() nil = %q, want %q", got, want) + } +} + +func jsonValEq(a, b any) bool { + aJSON, _ := json.Marshal(a) + bJSON, _ := json.Marshal(b) + return string(aJSON) == string(bJSON) +}