From fcc20ec72ccc2f9c413aa46239c6ab21ac976e28 Mon Sep 17 00:00:00 2001 From: Sabyasachi Patra Date: Tue, 24 Mar 2026 16:05:56 +0530 Subject: [PATCH] feat(tools): add tool argument schema validation before execution (#1877) Validate tool call arguments against each tool's Parameters() JSON Schema in ExecuteWithContext() before calling Execute(). This prevents type confusion, argument injection, and missing-field errors from reaching tools. Validates: required fields, type matching (string/integer/number/boolean/ array/object), enum membership, nested objects (recursive), array element types. Rejects unexpected extra properties unless additionalProperties is set to true (for MCP tool compatibility). Returns ToolResult{IsError: true} on failure so the LLM can self-correct. Ref: Security Hardening > Tool abuse prevention via strict parameter validation --- .gitignore | 1 + pkg/tools/registry.go | 8 + pkg/tools/validate.go | 209 +++++++++++++++++ pkg/tools/validate_test.go | 465 +++++++++++++++++++++++++++++++++++++ 4 files changed, 683 insertions(+) create mode 100644 pkg/tools/validate.go create mode 100644 pkg/tools/validate_test.go diff --git a/.gitignore b/.gitignore index 8b5f95215..72f3b1761 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ tasks/ # Plans docs/plans/ +docs/superpowers/ # Editors .vscode/ diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index ed373a28f..2c634e673 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -180,6 +180,14 @@ func (r *ToolRegistry) ExecuteWithContext( return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } + // Validate arguments against the tool's declared schema. + if err := validateToolArgs(tool.Parameters(), args); err != nil { + logger.WarnCF("tool", "Tool argument validation failed", + map[string]any{"tool": name, "error": err.Error()}) + return ErrorResult(fmt.Sprintf("invalid arguments for tool %q: %s", name, err)). + WithError(fmt.Errorf("argument validation failed: %w", err)) + } + // Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx). // Always inject — tools validate what they require. ctx = WithToolContext(ctx, channel, chatID) diff --git a/pkg/tools/validate.go b/pkg/tools/validate.go new file mode 100644 index 000000000..940344708 --- /dev/null +++ b/pkg/tools/validate.go @@ -0,0 +1,209 @@ +package tools + +import ( + "fmt" + "math" +) + +// validateToolArgs validates args against a JSON Schema-like map. +// schema is expected to have optional keys: "properties", "required", "additionalProperties". +func validateToolArgs(schema map[string]any, args map[string]any) error { + if len(schema) == 0 { + return nil + } + + if args == nil { + args = map[string]any{} + } + + if err := checkRequired(schema, args); err != nil { + return err + } + + propsRaw, ok := schema["properties"] + if !ok { + return nil // no properties defined — accept any args + } + + props, ok := propsRaw.(map[string]any) + if !ok { + return nil + } + + additional := allowsAdditional(schema) + + for key, val := range args { + propSchemaRaw, known := props[key] + if !known { + if !additional { + return fmt.Errorf("unexpected property %q", key) + } + continue + } + propSchema, ok := propSchemaRaw.(map[string]any) + if !ok { + continue // can't validate without a proper schema map + } + if err := checkType(key, val, propSchema); err != nil { + return err + } + } + + return nil +} + +// checkRequired verifies that every field listed in schema["required"] is present in args. +func checkRequired(schema map[string]any, args map[string]any) error { + reqRaw, ok := schema["required"] + if !ok { + return nil + } + + var required []string + + switch r := reqRaw.(type) { + case []string: + required = r + case []any: + for _, v := range r { + s, ok := v.(string) + if ok { + required = append(required, s) + } + } + default: + return nil + } + + for _, field := range required { + if _, present := args[field]; !present { + return fmt.Errorf("missing required property %q", field) + } + } + return nil +} + +// allowsAdditional returns true when the schema explicitly sets +// "additionalProperties" to true, or when the key is absent (default: reject extras). +func allowsAdditional(schema map[string]any) bool { + v, ok := schema["additionalProperties"] + if !ok { + return false + } + b, ok := v.(bool) + return ok && b +} + +// checkType validates that val matches the JSON Schema type declared in propSchema. +func checkType(key string, val any, propSchema map[string]any) error { + typeRaw, ok := propSchema["type"] + if !ok { + return nil // no type constraint + } + typeName, ok := typeRaw.(string) + if !ok { + return nil + } + + switch typeName { + case "string": + if _, ok := val.(string); !ok { + return fmt.Errorf("property %q: expected string, got %T", key, val) + } + case "integer": + switch v := val.(type) { + case float64: + if v != math.Trunc(v) { + return fmt.Errorf("property %q: expected integer, got float64 with fractional part", key) + } + case int: + // ok + case int64: + // ok + default: + return fmt.Errorf("property %q: expected integer, got %T", key, val) + } + case "number": + switch val.(type) { + case float64, int, int64: + // ok + default: + return fmt.Errorf("property %q: expected number, got %T", key, val) + } + case "boolean": + if _, ok := val.(bool); !ok { + return fmt.Errorf("property %q: expected boolean, got %T", key, val) + } + case "array": + arr, ok := val.([]any) + if !ok { + return fmt.Errorf("property %q: expected array, got %T", key, val) + } + if err := checkArrayItems(key, arr, propSchema); err != nil { + return err + } + case "object": + obj, ok := val.(map[string]any) + if !ok { + return fmt.Errorf("property %q: expected object, got %T", key, val) + } + if err := validateToolArgs(propSchema, obj); err != nil { + return fmt.Errorf("property %q: %w", key, err) + } + } + + if err := checkEnum(key, val, propSchema); err != nil { + return err + } + + return nil +} + +// checkArrayItems validates each element of arr against the "items" sub-schema. +func checkArrayItems(key string, arr []any, propSchema map[string]any) error { + itemsRaw, ok := propSchema["items"] + if !ok { + return nil + } + itemSchema, ok := itemsRaw.(map[string]any) + if !ok { + return nil + } + for i, elem := range arr { + elemKey := fmt.Sprintf("%s[%d]", key, i) + if err := checkType(elemKey, elem, itemSchema); err != nil { + return err + } + } + return nil +} + +// checkEnum validates that val is one of the allowed enum values in propSchema. +func checkEnum(key string, val any, propSchema map[string]any) error { + enumRaw, ok := propSchema["enum"] + if !ok { + return nil + } + + switch ev := enumRaw.(type) { + case []any: + for _, allowed := range ev { + if val == allowed { + return nil + } + } + case []string: + s, ok := val.(string) + if ok { + for _, allowed := range ev { + if s == allowed { + return nil + } + } + } + default: + return nil // unknown enum format, skip + } + + return fmt.Errorf("property %q: value %v is not in enum", key, val) +} diff --git a/pkg/tools/validate_test.go b/pkg/tools/validate_test.go new file mode 100644 index 000000000..e7f4f619a --- /dev/null +++ b/pkg/tools/validate_test.go @@ -0,0 +1,465 @@ +package tools + +import ( + "context" + "strings" + "testing" +) + +// Ensure imports are used. +var ( + _ = context.Background + _ = strings.Contains +) + +func TestValidateToolArgs(t *testing.T) { + baseSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + "required": []string{"name"}, + } + + tests := []struct { + name string + schema map[string]any + args map[string]any + wantErr string // empty means no error expected + }{ + { + name: "valid args all required present", + schema: baseSchema, + args: map[string]any{"name": "alice", "age": float64(30)}, + }, + { + name: "missing required field", + schema: baseSchema, + args: map[string]any{"age": float64(30)}, + wantErr: "missing required property \"name\"", + }, + { + name: "wrong type string field gets number", + schema: baseSchema, + args: map[string]any{"name": float64(42)}, + wantErr: "expected string", + }, + { + name: "nil args with required fields", + schema: baseSchema, + args: nil, + wantErr: "missing required property \"name\"", + }, + { + name: "nil args no required fields", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + }, + args: nil, + }, + { + name: "empty args no required fields", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + }, + args: map[string]any{}, + }, + { + name: "optional field correct type", + schema: baseSchema, + args: map[string]any{"name": "bob", "age": float64(25)}, + }, + { + name: "optional field wrong type", + schema: baseSchema, + args: map[string]any{"name": "bob", "age": "twenty"}, + wantErr: "expected integer", + }, + { + name: "integer as float64 no fractional part", + schema: baseSchema, + args: map[string]any{"name": "carol", "age": float64(42)}, + }, + { + name: "actual float for integer field", + schema: baseSchema, + args: map[string]any{"name": "dave", "age": float64(42.5)}, + wantErr: "expected integer, got float64 with fractional part", + }, + { + name: "number type accepts float", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "score": map[string]any{"type": "number"}, + }, + }, + args: map[string]any{"score": float64(3.14)}, + }, + { + name: "number type accepts integer", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "score": map[string]any{"type": "number"}, + }, + }, + args: map[string]any{"score": float64(10)}, + }, + { + name: "boolean type valid", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean"}, + }, + }, + args: map[string]any{"flag": true}, + }, + { + name: "boolean type wrong", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean"}, + }, + }, + args: map[string]any{"flag": "true"}, + wantErr: "expected boolean", + }, + { + name: "required as []any from MCP deserialization", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "cmd": map[string]any{"type": "string"}, + }, + "required": []any{"cmd"}, + }, + args: map[string]any{}, + wantErr: "missing required property \"cmd\"", + }, + { + name: "enum valid value []any", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}}, + }, + }, + args: map[string]any{"color": "red"}, + }, + { + name: "enum invalid value []any", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []any{"red", "green", "blue"}}, + }, + }, + args: map[string]any{"color": "yellow"}, + wantErr: "not in enum", + }, + { + name: "enum valid value []string", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}}, + }, + }, + args: map[string]any{"color": "green"}, + }, + { + name: "enum invalid value []string", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "color": map[string]any{"type": "string", "enum": []string{"red", "green", "blue"}}, + }, + }, + args: map[string]any{"color": "yellow"}, + wantErr: "not in enum", + }, + { + name: "extra unexpected property rejected", + schema: baseSchema, + args: map[string]any{"name": "eve", "hobby": "chess"}, + wantErr: "unexpected property \"hobby\"", + }, + { + name: "extra property allowed with additionalProperties true", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + "additionalProperties": true, + }, + args: map[string]any{"name": "eve", "hobby": "chess"}, + }, + { + name: "nested object valid", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "address": map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + "required": []string{"city"}, + }, + }, + }, + args: map[string]any{ + "address": map[string]any{"city": "Berlin"}, + }, + }, + { + name: "nested object wrong type", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "address": map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + }, + }, + }, + args: map[string]any{"address": "not an object"}, + wantErr: "expected object", + }, + { + name: "array with valid element types", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "tags": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + }, + }, + }, + args: map[string]any{"tags": []any{"a", "b", "c"}}, + }, + { + name: "array with wrong element types", + schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "tags": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + }, + }, + }, + args: map[string]any{"tags": []any{"a", float64(2)}}, + wantErr: "expected string", + }, + { + name: "schema with no properties key accepts any args", + schema: map[string]any{ + "type": "object", + }, + args: map[string]any{"anything": "goes"}, + }, + { + name: "empty schema accepts anything", + schema: map[string]any{}, + args: map[string]any{"foo": "bar"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateToolArgs(tc.schema, tc.args) + if tc.wantErr == "" { + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tc.wantErr, err) + } + }) + } +} + +func TestValidateToolArgs_RegistryIntegration(t *testing.T) { + r := NewToolRegistry() + r.Register(&mockRegistryTool{ + name: "read_file", + desc: "reads a file", + params: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []string{"path"}, + }, + result: SilentResult("file contents"), + }) + + // Valid args — should succeed + result := r.Execute(context.Background(), "read_file", map[string]any{"path": "/tmp/x"}) + if result.IsError { + t.Errorf("expected success, got error: %s", result.ForLLM) + } + + // Missing required field — should fail with validation error + result = r.Execute(context.Background(), "read_file", map[string]any{}) + if !result.IsError { + t.Error("expected validation error for missing required field") + } + if !strings.Contains(result.ForLLM, "missing required p") { + t.Errorf("expected 'missing required p...' in error, got %q", result.ForLLM) + } + if result.Err == nil { + t.Error("expected Err to be set via WithError") + } + + // Wrong type — should fail with validation error + result = r.Execute(context.Background(), "read_file", map[string]any{"path": 123.0}) + if !result.IsError { + t.Error("expected validation error for wrong type") + } + if !strings.Contains(result.ForLLM, "expected string") { + t.Errorf("expected 'expected string' in error, got %q", result.ForLLM) + } + + // Extra property — should fail with validation error + result = r.Execute(context.Background(), "read_file", map[string]any{"path": "/x", "__inject": true}) + if !result.IsError { + t.Error("expected validation error for extra property") + } + if !strings.Contains(result.ForLLM, "unexpected prop") { + t.Errorf("expected 'unexpected prop...' in error, got %q", result.ForLLM) + } +} + +func TestValidateToolArgs_RealSchemas(t *testing.T) { + execSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{"type": "string"}, + "working_dir": map[string]any{"type": "string"}, + }, + "required": []string{"command"}, + } + + cronSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []any{"add", "list", "remove", "enable", "disable"}, + }, + }, + "required": []string{"action"}, + } + + webSearchSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + "count": map[string]any{"type": "integer"}, + }, + "required": []string{"query"}, + } + + tests := []struct { + name string + schema map[string]any + args map[string]any + wantErr string + }{ + // ExecTool + { + name: "exec valid args", + schema: execSchema, + args: map[string]any{"command": "ls -la", "working_dir": "/tmp"}, + }, + { + name: "exec missing required command", + schema: execSchema, + args: map[string]any{"working_dir": "/tmp"}, + wantErr: "missing required property \"command\"", + }, + { + name: "exec wrong type for command", + schema: execSchema, + args: map[string]any{"command": float64(123)}, + wantErr: "expected string", + }, + { + name: "exec extra injected arg", + schema: execSchema, + args: map[string]any{"command": "ls", "malicious": "payload"}, + wantErr: "unexpected property \"malicious\"", + }, + + // CronTool + { + name: "cron valid enum value", + schema: cronSchema, + args: map[string]any{"action": "add"}, + }, + { + name: "cron invalid enum value", + schema: cronSchema, + args: map[string]any{"action": "destroy"}, + wantErr: "not in enum", + }, + + // WebSearchTool + { + name: "websearch valid args", + schema: webSearchSchema, + args: map[string]any{"query": "golang testing", "count": float64(10)}, + }, + { + name: "websearch missing required query", + schema: webSearchSchema, + args: map[string]any{"count": float64(5)}, + wantErr: "missing required property \"query\"", + }, + { + name: "websearch wrong type for count", + schema: webSearchSchema, + args: map[string]any{"query": "test", "count": "ten"}, + wantErr: "expected integer", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateToolArgs(tc.schema, tc.args) + if tc.wantErr == "" { + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + return + } + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tc.wantErr, err) + } + }) + } +}