mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
hardware/serial: tighten validation and error handling
This commit is contained in:
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -18,6 +21,15 @@ const (
|
||||
maxSerialReadBytes = 4096
|
||||
)
|
||||
|
||||
var (
|
||||
unixSerialPortPattern = regexp.MustCompile(`^(?:/dev/)?(?:ttyS\d+|ttyUSB\d+|ttyACM\d+|ttyAMA\d+|rfcomm\d+|tty\.[A-Za-z0-9._-]+|cu\.[A-Za-z0-9._-]+)$`)
|
||||
windowsSerialPortPattern = regexp.MustCompile(`^(?:\\\\\.\\)?COM[1-9]\d*$`)
|
||||
unixSerialBaudRates = map[int]struct{}{
|
||||
50: {}, 75: {}, 110: {}, 134: {}, 150: {}, 200: {}, 300: {}, 600: {}, 1200: {}, 1800: {},
|
||||
2400: {}, 4800: {}, 9600: {}, 19200: {}, 38400: {}, 57600: {}, 115200: {}, 230400: {},
|
||||
}
|
||||
)
|
||||
|
||||
type SerialTool struct{}
|
||||
|
||||
type serialPortInfo struct {
|
||||
@@ -60,7 +72,7 @@ func (t *SerialTool) Parameters() map[string]any {
|
||||
},
|
||||
"baud": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Baud rate. Default: 115200.",
|
||||
"description": "Baud rate. Default: 115200. Linux/macOS currently support standard termios rates up to 230400; Windows accepts configured rates up to 4000000.",
|
||||
},
|
||||
"data_bits": map[string]any{
|
||||
"type": "integer",
|
||||
@@ -209,8 +221,13 @@ func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) {
|
||||
return serialConfig{}, ErrorResult("port is required (for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3)")
|
||||
}
|
||||
|
||||
normalizedPort, err := normalizeSerialPort(port)
|
||||
if err != nil {
|
||||
return serialConfig{}, ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
cfg := serialConfig{
|
||||
Port: port,
|
||||
Port: normalizedPort,
|
||||
Baud: defaultSerialBaud,
|
||||
DataBits: defaultSerialDataBits,
|
||||
Parity: "none",
|
||||
@@ -220,8 +237,8 @@ func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) {
|
||||
if v, ok := args["baud"].(float64); ok {
|
||||
cfg.Baud = int(v)
|
||||
}
|
||||
if cfg.Baud < 50 || cfg.Baud > 4000000 {
|
||||
return serialConfig{}, ErrorResult("baud must be between 50 and 4000000")
|
||||
if err := validateSerialBaud(cfg.Baud); err != nil {
|
||||
return serialConfig{}, ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if v, ok := args["data_bits"].(float64); ok {
|
||||
@@ -288,6 +305,9 @@ func parseSerialWritePayload(args map[string]any) ([]byte, *ToolResult) {
|
||||
if !ok {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i))
|
||||
}
|
||||
if f != math.Trunc(f) {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] is not an integer byte value", i))
|
||||
}
|
||||
b := int(f)
|
||||
if b < 0 || b > 255 {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b))
|
||||
@@ -330,3 +350,56 @@ func serialPayloadSummary(data []byte) map[string]any {
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func normalizeSerialPort(port string) (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return normalizeWindowsSerialPath(port)
|
||||
case "linux", "darwin":
|
||||
return normalizeUnixSerialPath(port)
|
||||
default:
|
||||
if normalized, err := normalizeUnixSerialPath(port); err == nil {
|
||||
return normalized, nil
|
||||
}
|
||||
return normalizeWindowsSerialPath(port)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeUnixSerialPath(port string) (string, error) {
|
||||
trimmed := strings.TrimSpace(port)
|
||||
if !unixSerialPortPattern.MatchString(trimmed) {
|
||||
return "", fmt.Errorf(
|
||||
"invalid serial port: expected a safe Unix device name such as /dev/ttyUSB0 or /dev/cu.usbserial-0001",
|
||||
)
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "/dev/") {
|
||||
return trimmed, nil
|
||||
}
|
||||
return "/dev/" + trimmed, nil
|
||||
}
|
||||
|
||||
func normalizeWindowsSerialPath(port string) (string, error) {
|
||||
trimmed := strings.ToUpper(strings.TrimSpace(port))
|
||||
if !windowsSerialPortPattern.MatchString(trimmed) {
|
||||
return "", fmt.Errorf("invalid serial port: expected a COM port such as COM3")
|
||||
}
|
||||
if strings.HasPrefix(trimmed, `\\.\`) {
|
||||
return trimmed, nil
|
||||
}
|
||||
return `\\.\` + trimmed, nil
|
||||
}
|
||||
|
||||
func validateSerialBaud(baud int) error {
|
||||
if baud < 50 || baud > 4000000 {
|
||||
return fmt.Errorf("baud must be between 50 and 4000000")
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "darwin":
|
||||
if _, ok := unixSerialBaudRates[baud]; !ok {
|
||||
return fmt.Errorf("unsupported baud rate on this platform: %d (supported up to 230400)", baud)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
return nil, nil
|
||||
return nil, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
//go:build !linux && !darwin && !windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSerialListPortsUnsupportedPlatform(t *testing.T) {
|
||||
_, err := serialListPorts()
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported platform error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not supported") {
|
||||
t.Fatalf("serialListPorts() error = %v, want unsupported platform message", err)
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,20 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseSerialConfig(t *testing.T) {
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
cfg, errResult := parseSerialConfig(map[string]any{
|
||||
"port": "COM3",
|
||||
"port": port,
|
||||
"baud": float64(9600),
|
||||
"data_bits": float64(7),
|
||||
"parity": "even",
|
||||
@@ -17,14 +24,23 @@ func TestParseSerialConfig(t *testing.T) {
|
||||
t.Fatalf("parseSerialConfig() unexpected error = %v", errResult.ForLLM)
|
||||
}
|
||||
|
||||
if cfg.Port != "COM3" || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 {
|
||||
wantPort := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
wantPort = `\\.\COM3`
|
||||
}
|
||||
if cfg.Port != wantPort || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 {
|
||||
t.Fatalf("parseSerialConfig() = %#v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsInvalidParity(t *testing.T) {
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": "/dev/ttyUSB0",
|
||||
"port": port,
|
||||
"parity": "mark",
|
||||
})
|
||||
if errResult == nil {
|
||||
@@ -32,6 +48,152 @@ func TestParseSerialConfigRejectsInvalidParity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsUnsupportedUnixBaud(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
||||
t.Skip("Unix baud validation only applies on Unix platforms")
|
||||
}
|
||||
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": "/dev/ttyUSB0",
|
||||
"baud": float64(460800),
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected unsupported Unix baud rate to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialWritePayloadRejectsFractionalBytes(t *testing.T) {
|
||||
_, errResult := parseSerialWritePayload(map[string]any{
|
||||
"data": []any{65.9},
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected fractional byte value to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSerialBaud(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baud int
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "default-supported", baud: 115200},
|
||||
{name: "max-unix-supported", baud: 230400},
|
||||
{name: "too-low", baud: 49, wantErr: true},
|
||||
{name: "too-high", baud: 4000001, wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateSerialBaud(tt.baud)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("validateSerialBaud(%d) error = %v, wantErr %v", tt.baud, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) {
|
||||
tests := []string{
|
||||
"../../../etc/passwd",
|
||||
"/etc/passwd",
|
||||
`C:\temp\device.txt`,
|
||||
`\\.\C:\temp\device.txt`,
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": port,
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatalf("expected unsafe port %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUnixSerialPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
want string
|
||||
}{
|
||||
{port: "ttyUSB0", want: "/dev/ttyUSB0"},
|
||||
{port: "/dev/ttyACM0", want: "/dev/ttyACM0"},
|
||||
{port: "/dev/cu.usbserial-0001", want: "/dev/cu.usbserial-0001"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := normalizeUnixSerialPath(tt.port)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeUnixSerialPath(%q) unexpected error = %v", tt.port, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("normalizeUnixSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUnixSerialPathRejectsInvalidNames(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"ttyUSB0/../../passwd",
|
||||
"/dev/../../etc/passwd",
|
||||
"/tmp/ttyUSB0",
|
||||
"ttyUSB",
|
||||
"COM3",
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
|
||||
if _, err := normalizeUnixSerialPath(port); err == nil {
|
||||
t.Fatalf("expected %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWindowsSerialPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
want string
|
||||
}{
|
||||
{port: "COM3", want: `\\.\COM3`},
|
||||
{port: "com12", want: `\\.\COM12`},
|
||||
{port: `\\.\COM7`, want: `\\.\COM7`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := normalizeWindowsSerialPath(tt.port)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeWindowsSerialPath(%q) unexpected error = %v", tt.port, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("normalizeWindowsSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWindowsSerialPathRejectsInvalidNames(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"COM0",
|
||||
"COM",
|
||||
"/dev/ttyUSB0",
|
||||
`C:\temp\device.txt`,
|
||||
`\\.\C:\temp\device.txt`,
|
||||
`\\server\share\COM3`,
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(strings.ReplaceAll(port, `\`, "_"), "/", "_"), func(t *testing.T) {
|
||||
if _, err := normalizeWindowsSerialPath(port); err == nil {
|
||||
t.Fatalf("expected %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialTimeout(t *testing.T) {
|
||||
timeout, errResult := parseSerialTimeout(map[string]any{
|
||||
"timeout_ms": float64(2500),
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -112,7 +111,7 @@ func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, err
|
||||
}
|
||||
|
||||
func openAndConfigureSerialPort(cfg serialConfig) (int, error) {
|
||||
fd, err := unix.Open(normalizeUnixSerialPath(cfg.Port), unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0)
|
||||
fd, err := unix.Open(cfg.Port, unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
@@ -253,11 +252,3 @@ func durationToPollTimeout(timeout time.Duration) int {
|
||||
}
|
||||
return ms
|
||||
}
|
||||
|
||||
func normalizeUnixSerialPath(port string) string {
|
||||
trimmed := strings.TrimSpace(port)
|
||||
if strings.HasPrefix(trimmed, "/dev/") {
|
||||
return trimmed
|
||||
}
|
||||
return "/dev/" + trimmed
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ package hardwaretools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -126,9 +125,8 @@ func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, err
|
||||
}
|
||||
|
||||
func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) {
|
||||
path := normalizeWindowsSerialPath(cfg.Port)
|
||||
handle, err := windows.CreateFile(
|
||||
windows.StringToUTF16Ptr(path),
|
||||
windows.StringToUTF16Ptr(cfg.Port),
|
||||
windows.GENERIC_READ|windows.GENERIC_WRITE,
|
||||
0,
|
||||
nil,
|
||||
@@ -201,14 +199,3 @@ func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout
|
||||
procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear))
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeWindowsSerialPath(port string) string {
|
||||
trimmed := strings.ToUpper(strings.TrimSpace(port))
|
||||
if strings.HasPrefix(trimmed, `\\.\`) {
|
||||
return trimmed
|
||||
}
|
||||
if filepath.VolumeName(trimmed) != "" {
|
||||
return trimmed
|
||||
}
|
||||
return `\\.\` + trimmed
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user