diff --git a/pkg/tools/hardware/serial.go b/pkg/tools/hardware/serial.go index 3f55da362..17f15ca94 100644 --- a/pkg/tools/hardware/serial.go +++ b/pkg/tools/hardware/serial.go @@ -4,6 +4,9 @@ import ( "context" "encoding/json" "fmt" + "math" + "regexp" + "runtime" "strings" "time" "unicode/utf8" @@ -18,6 +21,15 @@ const ( maxSerialReadBytes = 4096 ) +var ( + unixSerialPortPattern = regexp.MustCompile(`^(?:/dev/)?(?:ttyS\d+|ttyUSB\d+|ttyACM\d+|ttyAMA\d+|rfcomm\d+|tty\.[A-Za-z0-9._-]+|cu\.[A-Za-z0-9._-]+)$`) + windowsSerialPortPattern = regexp.MustCompile(`^(?:\\\\\.\\)?COM[1-9]\d*$`) + unixSerialBaudRates = map[int]struct{}{ + 50: {}, 75: {}, 110: {}, 134: {}, 150: {}, 200: {}, 300: {}, 600: {}, 1200: {}, 1800: {}, + 2400: {}, 4800: {}, 9600: {}, 19200: {}, 38400: {}, 57600: {}, 115200: {}, 230400: {}, + } +) + type SerialTool struct{} type serialPortInfo struct { @@ -60,7 +72,7 @@ func (t *SerialTool) Parameters() map[string]any { }, "baud": map[string]any{ "type": "integer", - "description": "Baud rate. Default: 115200.", + "description": "Baud rate. Default: 115200. Linux/macOS currently support standard termios rates up to 230400; Windows accepts configured rates up to 4000000.", }, "data_bits": map[string]any{ "type": "integer", @@ -209,8 +221,13 @@ func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) { return serialConfig{}, ErrorResult("port is required (for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3)") } + normalizedPort, err := normalizeSerialPort(port) + if err != nil { + return serialConfig{}, ErrorResult(err.Error()) + } + cfg := serialConfig{ - Port: port, + Port: normalizedPort, Baud: defaultSerialBaud, DataBits: defaultSerialDataBits, Parity: "none", @@ -220,8 +237,8 @@ func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) { if v, ok := args["baud"].(float64); ok { cfg.Baud = int(v) } - if cfg.Baud < 50 || cfg.Baud > 4000000 { - return serialConfig{}, ErrorResult("baud must be between 50 and 4000000") + if err := validateSerialBaud(cfg.Baud); err != nil { + return serialConfig{}, ErrorResult(err.Error()) } if v, ok := args["data_bits"].(float64); ok { @@ -288,6 +305,9 @@ func parseSerialWritePayload(args map[string]any) ([]byte, *ToolResult) { if !ok { return nil, ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) } + if f != math.Trunc(f) { + return nil, ErrorResult(fmt.Sprintf("data[%d] is not an integer byte value", i)) + } b := int(f) if b < 0 || b > 255 { return nil, ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) @@ -330,3 +350,56 @@ func serialPayloadSummary(data []byte) map[string]any { } return summary } + +func normalizeSerialPort(port string) (string, error) { + switch runtime.GOOS { + case "windows": + return normalizeWindowsSerialPath(port) + case "linux", "darwin": + return normalizeUnixSerialPath(port) + default: + if normalized, err := normalizeUnixSerialPath(port); err == nil { + return normalized, nil + } + return normalizeWindowsSerialPath(port) + } +} + +func normalizeUnixSerialPath(port string) (string, error) { + trimmed := strings.TrimSpace(port) + if !unixSerialPortPattern.MatchString(trimmed) { + return "", fmt.Errorf( + "invalid serial port: expected a safe Unix device name such as /dev/ttyUSB0 or /dev/cu.usbserial-0001", + ) + } + if strings.HasPrefix(trimmed, "/dev/") { + return trimmed, nil + } + return "/dev/" + trimmed, nil +} + +func normalizeWindowsSerialPath(port string) (string, error) { + trimmed := strings.ToUpper(strings.TrimSpace(port)) + if !windowsSerialPortPattern.MatchString(trimmed) { + return "", fmt.Errorf("invalid serial port: expected a COM port such as COM3") + } + if strings.HasPrefix(trimmed, `\\.\`) { + return trimmed, nil + } + return `\\.\` + trimmed, nil +} + +func validateSerialBaud(baud int) error { + if baud < 50 || baud > 4000000 { + return fmt.Errorf("baud must be between 50 and 4000000") + } + + switch runtime.GOOS { + case "linux", "darwin": + if _, ok := unixSerialBaudRates[baud]; !ok { + return fmt.Errorf("unsupported baud rate on this platform: %d (supported up to 230400)", baud) + } + } + + return nil +} diff --git a/pkg/tools/hardware/serial_other.go b/pkg/tools/hardware/serial_other.go index f57ce2fa3..c0f8573ee 100644 --- a/pkg/tools/hardware/serial_other.go +++ b/pkg/tools/hardware/serial_other.go @@ -8,7 +8,7 @@ import ( ) func serialListPorts() ([]serialPortInfo, error) { - return nil, nil + return nil, fmt.Errorf("serial is not supported on this platform") } func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) { diff --git a/pkg/tools/hardware/serial_other_test.go b/pkg/tools/hardware/serial_other_test.go new file mode 100644 index 000000000..ef04c4062 --- /dev/null +++ b/pkg/tools/hardware/serial_other_test.go @@ -0,0 +1,18 @@ +//go:build !linux && !darwin && !windows + +package hardwaretools + +import ( + "strings" + "testing" +) + +func TestSerialListPortsUnsupportedPlatform(t *testing.T) { + _, err := serialListPorts() + if err == nil { + t.Fatal("expected unsupported platform error") + } + if !strings.Contains(err.Error(), "not supported") { + t.Fatalf("serialListPorts() error = %v, want unsupported platform message", err) + } +} diff --git a/pkg/tools/hardware/serial_test.go b/pkg/tools/hardware/serial_test.go index 68657dffd..a9e1de6b0 100644 --- a/pkg/tools/hardware/serial_test.go +++ b/pkg/tools/hardware/serial_test.go @@ -1,13 +1,20 @@ package hardwaretools import ( + "runtime" + "strings" "testing" "time" ) func TestParseSerialConfig(t *testing.T) { + port := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + port = "COM3" + } + cfg, errResult := parseSerialConfig(map[string]any{ - "port": "COM3", + "port": port, "baud": float64(9600), "data_bits": float64(7), "parity": "even", @@ -17,14 +24,23 @@ func TestParseSerialConfig(t *testing.T) { t.Fatalf("parseSerialConfig() unexpected error = %v", errResult.ForLLM) } - if cfg.Port != "COM3" || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 { + wantPort := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + wantPort = `\\.\COM3` + } + if cfg.Port != wantPort || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 { t.Fatalf("parseSerialConfig() = %#v", cfg) } } func TestParseSerialConfigRejectsInvalidParity(t *testing.T) { + port := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + port = "COM3" + } + _, errResult := parseSerialConfig(map[string]any{ - "port": "/dev/ttyUSB0", + "port": port, "parity": "mark", }) if errResult == nil { @@ -32,6 +48,152 @@ func TestParseSerialConfigRejectsInvalidParity(t *testing.T) { } } +func TestParseSerialConfigRejectsUnsupportedUnixBaud(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skip("Unix baud validation only applies on Unix platforms") + } + + _, errResult := parseSerialConfig(map[string]any{ + "port": "/dev/ttyUSB0", + "baud": float64(460800), + }) + if errResult == nil { + t.Fatal("expected unsupported Unix baud rate to fail") + } +} + +func TestParseSerialWritePayloadRejectsFractionalBytes(t *testing.T) { + _, errResult := parseSerialWritePayload(map[string]any{ + "data": []any{65.9}, + }) + if errResult == nil { + t.Fatal("expected fractional byte value to fail") + } +} + +func TestValidateSerialBaud(t *testing.T) { + tests := []struct { + name string + baud int + wantErr bool + }{ + {name: "default-supported", baud: 115200}, + {name: "max-unix-supported", baud: 230400}, + {name: "too-low", baud: 49, wantErr: true}, + {name: "too-high", baud: 4000001, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSerialBaud(tt.baud) + if (err != nil) != tt.wantErr { + t.Fatalf("validateSerialBaud(%d) error = %v, wantErr %v", tt.baud, err, tt.wantErr) + } + }) + } +} + +func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) { + tests := []string{ + "../../../etc/passwd", + "/etc/passwd", + `C:\temp\device.txt`, + `\\.\C:\temp\device.txt`, + } + + for _, port := range tests { + t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) { + _, errResult := parseSerialConfig(map[string]any{ + "port": port, + }) + if errResult == nil { + t.Fatalf("expected unsafe port %q to be rejected", port) + } + }) + } +} + +func TestNormalizeUnixSerialPath(t *testing.T) { + tests := []struct { + port string + want string + }{ + {port: "ttyUSB0", want: "/dev/ttyUSB0"}, + {port: "/dev/ttyACM0", want: "/dev/ttyACM0"}, + {port: "/dev/cu.usbserial-0001", want: "/dev/cu.usbserial-0001"}, + } + + for _, tt := range tests { + got, err := normalizeUnixSerialPath(tt.port) + if err != nil { + t.Fatalf("normalizeUnixSerialPath(%q) unexpected error = %v", tt.port, err) + } + if got != tt.want { + t.Fatalf("normalizeUnixSerialPath(%q) = %q, want %q", tt.port, got, tt.want) + } + } +} + +func TestNormalizeUnixSerialPathRejectsInvalidNames(t *testing.T) { + tests := []string{ + "", + "ttyUSB0/../../passwd", + "/dev/../../etc/passwd", + "/tmp/ttyUSB0", + "ttyUSB", + "COM3", + } + + for _, port := range tests { + t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) { + if _, err := normalizeUnixSerialPath(port); err == nil { + t.Fatalf("expected %q to be rejected", port) + } + }) + } +} + +func TestNormalizeWindowsSerialPath(t *testing.T) { + tests := []struct { + port string + want string + }{ + {port: "COM3", want: `\\.\COM3`}, + {port: "com12", want: `\\.\COM12`}, + {port: `\\.\COM7`, want: `\\.\COM7`}, + } + + for _, tt := range tests { + got, err := normalizeWindowsSerialPath(tt.port) + if err != nil { + t.Fatalf("normalizeWindowsSerialPath(%q) unexpected error = %v", tt.port, err) + } + if got != tt.want { + t.Fatalf("normalizeWindowsSerialPath(%q) = %q, want %q", tt.port, got, tt.want) + } + } +} + +func TestNormalizeWindowsSerialPathRejectsInvalidNames(t *testing.T) { + tests := []string{ + "", + "COM0", + "COM", + "/dev/ttyUSB0", + `C:\temp\device.txt`, + `\\.\C:\temp\device.txt`, + `\\server\share\COM3`, + } + + for _, port := range tests { + t.Run(strings.ReplaceAll(strings.ReplaceAll(port, `\`, "_"), "/", "_"), func(t *testing.T) { + if _, err := normalizeWindowsSerialPath(port); err == nil { + t.Fatalf("expected %q to be rejected", port) + } + }) + } +} + func TestParseSerialTimeout(t *testing.T) { timeout, errResult := parseSerialTimeout(map[string]any{ "timeout_ms": float64(2500), diff --git a/pkg/tools/hardware/serial_unix.go b/pkg/tools/hardware/serial_unix.go index caef7e0e8..d5094b208 100644 --- a/pkg/tools/hardware/serial_unix.go +++ b/pkg/tools/hardware/serial_unix.go @@ -7,7 +7,6 @@ import ( "os" "path/filepath" "sort" - "strings" "time" "golang.org/x/sys/unix" @@ -112,7 +111,7 @@ func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, err } func openAndConfigureSerialPort(cfg serialConfig) (int, error) { - fd, err := unix.Open(normalizeUnixSerialPath(cfg.Port), unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0) + fd, err := unix.Open(cfg.Port, unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0) if err != nil { return -1, err } @@ -253,11 +252,3 @@ func durationToPollTimeout(timeout time.Duration) int { } return ms } - -func normalizeUnixSerialPath(port string) string { - trimmed := strings.TrimSpace(port) - if strings.HasPrefix(trimmed, "/dev/") { - return trimmed - } - return "/dev/" + trimmed -} diff --git a/pkg/tools/hardware/serial_windows.go b/pkg/tools/hardware/serial_windows.go index c4d119ac4..a109b8025 100644 --- a/pkg/tools/hardware/serial_windows.go +++ b/pkg/tools/hardware/serial_windows.go @@ -4,7 +4,6 @@ package hardwaretools import ( "fmt" - "path/filepath" "sort" "strings" "time" @@ -126,9 +125,8 @@ func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, err } func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) { - path := normalizeWindowsSerialPath(cfg.Port) handle, err := windows.CreateFile( - windows.StringToUTF16Ptr(path), + windows.StringToUTF16Ptr(cfg.Port), windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, @@ -201,14 +199,3 @@ func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear)) return nil } - -func normalizeWindowsSerialPath(port string) string { - trimmed := strings.ToUpper(strings.TrimSpace(port)) - if strings.HasPrefix(trimmed, `\\.\`) { - return trimmed - } - if filepath.VolumeName(trimmed) != "" { - return trimmed - } - return `\\.\` + trimmed -}