diff --git a/config/config.example.json b/config/config.example.json index 30460c231..665a6fd11 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -14,7 +14,9 @@ "tool_feedback": { "enabled": false, "max_args_length": 300, - "separate_messages": false + "separate_messages": false, + "pretty_print": true, + "disable_escape_html": true } } }, diff --git a/go.mod b/go.mod index 4afbe9d85..e4326d8ab 100644 --- a/go.mod +++ b/go.mod @@ -122,7 +122,7 @@ require ( github.com/github/copilot-sdk/go v0.2.0 github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/jsonschema-go v0.4.2 // indirect + github.com/google/jsonschema-go v0.4.2 github.com/grbit/go-json v0.11.0 // indirect github.com/klauspost/compress v1.18.4 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect diff --git a/pkg/config/config.go b/pkg/config/config.go index 0cbc6fead..d4908aca3 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -251,7 +251,7 @@ type ToolFeedbackConfig struct { MaxArgsLength int `json:"max_args_length" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_MAX_ARGS_LENGTH"` SeparateMessages bool `json:"separate_messages" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_SEPARATE_MESSAGES"` PrettyPrint bool `json:"pretty_print" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_PRETTY_PRINT"` - DisableEscapeHTML bool `json:"disable_escape_html" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_DISABLE_ESCAPE_HTML"` + DisableEscapeHTML bool `json:"disable_escape_html" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_DISABLE_ESCAPE_HTML"` } type AgentDefaults struct { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 7725b040e..32e3cbbe1 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -39,7 +39,7 @@ func DefaultConfig() *Config { MaxArgsLength: 300, SeparateMessages: false, PrettyPrint: true, - DisableEscapeHTML: true, + DisableEscapeHTML: true, }, SplitOnMarker: false, }, diff --git a/tool_feedback.go b/tool_feedback.go new file mode 100644 index 000000000..b8cf3c384 --- /dev/null +++ b/tool_feedback.go @@ -0,0 +1,86 @@ +package utils + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" +) + +const ToolFeedbackContinuationHint = "Continuing the current task." + +func FormatArgsJSON(args map[string]any, prettyPrint, disableEscapeHTML bool) string { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + if prettyPrint { + enc.SetIndent("", " ") + } + if disableEscapeHTML { + enc.SetEscapeHTML(false) + } + if err := enc.Encode(args); err != nil { + return "{}" + } + return strings.TrimSpace(buf.String()) +} + +func FormatToolFeedbackMessage(toolName, explanation, argsPreview string) string { + toolName = strings.TrimSpace(toolName) + explanation = strings.TrimSpace(explanation) + argsPreview = strings.TrimSpace(argsPreview) + + bodyLines := make([]string, 0, 2) + if explanation != "" { + bodyLines = append(bodyLines, explanation) + } + if argsPreview != "" { + bodyLines = append(bodyLines, "```json\n"+argsPreview+"\n```") + } + body := strings.Join(bodyLines, "\n") + + if toolName == "" { + return body + } + if body == "" { + return fmt.Sprintf("\U0001f527 `%s`", toolName) + } + + return fmt.Sprintf("\U0001f527 `%s`\n%s", toolName, body) +} + +func FitToolFeedbackMessage(content string, maxLen int) string { + content = strings.TrimSpace(content) + if content == "" || maxLen <= 0 { + return "" + } + if len([]rune(content)) <= maxLen { + return content + } + + firstLine, rest, hasRest := strings.Cut(content, "\n") + firstLine = strings.TrimSpace(firstLine) + rest = strings.TrimSpace(rest) + + if !hasRest || rest == "" { + return Truncate(firstLine, maxLen) + } + + if len([]rune(firstLine)) >= maxLen { + return Truncate(firstLine, maxLen) + } + + remaining := maxLen - len([]rune(firstLine)) - 1 + if remaining <= 0 { + return Truncate(firstLine, maxLen) + } + + return firstLine + "\n" + Truncate(rest, remaining) +} + +func Truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) +} diff --git a/tool_feedback_test.go b/tool_feedback_test.go new file mode 100644 index 000000000..ef76c1506 --- /dev/null +++ b/tool_feedback_test.go @@ -0,0 +1,101 @@ +package utils + +import ( + "encoding/json" + "testing" +) + +func TestFormatArgsJSON_Defaults(t *testing.T) { + args := map[string]any{"path": "README.md", "line": 42} + got := FormatArgsJSON(args, false, false) + var gotVal, wantVal any + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + want := `{"path":"README.md","line":42}` + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_PrettyPrint(t *testing.T) { + args := map[string]any{"path": "README.md", "line": 42} + got := FormatArgsJSON(args, true, false) + var gotVal any + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + want := `{"path":"README.md","line":42}` + var wantVal any + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() prettyPrint = %q, want structure %q", got, want) + } +} + +func TestFormatArgsJSON_DisableEscapeHTML(t *testing.T) { + args := map[string]any{"msg": "a < b && c > d"} + got := FormatArgsJSON(args, false, true) + var gotVal, wantVal any + want := `{"msg":"a < b && c > d"}` + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() disableEscapeHTML = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_PrettyPrintAndDisableEscapeHTML(t *testing.T) { + args := map[string]any{"msg": "a < b && c > d"} + got := FormatArgsJSON(args, true, true) + var gotVal, wantVal any + want := `{"msg":"a < b && c > d"}` + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() combined = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_EscapeHTMLByDefault(t *testing.T) { + args := map[string]any{"msg": "a < b && c > d"} + got := FormatArgsJSON(args, false, false) + var gotVal, wantVal any + want := `{"msg":"a \u003c b \u0026\u0026 c \u003e d"}` + if err := json.Unmarshal([]byte(got), &gotVal); err != nil { + t.Fatalf("FormatArgsJSON() returned invalid JSON: %v", err) + } + if err := json.Unmarshal([]byte(want), &wantVal); err != nil { + t.Fatalf("invalid test want JSON: %v", err) + } + if !jsonValEq(gotVal, wantVal) { + t.Fatalf("FormatArgsJSON() default escape = %q, want %q", got, want) + } +} + +func TestFormatArgsJSON_NilArgs(t *testing.T) { + got := FormatArgsJSON(nil, false, false) + want := `null` + if got != want { + t.Fatalf("FormatArgsJSON() nil = %q, want %q", got, want) + } +} + +func jsonValEq(a, b any) bool { + aJSON, _ := json.Marshal(a) + bJSON, _ := json.Marshal(b) + return string(aJSON) == string(bJSON) +}