mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Add cross-platform serial tool support (#2673)
* feat(tools): add cross-platform serial hardware tool * feat(config): wire serial tool into runtime and dashboard * hardware/serial: tighten validation and error handling * hardware/serial: improve unix cancellation and timeout polling * hardware/serial: improve windows I/O handling * hardware/serial: fix darwin cross-compilation build * docs(design): summarize hardware support and serial limits * build: keep go generate on host during cross builds * onboard: drop unrelated go generate change from serial work * style(tools): wrap serial lines for golines
This commit is contained in:
@@ -128,6 +128,9 @@ func registerSharedTools(
|
||||
if cfg.Tools.IsToolEnabled("spi") {
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("serial") {
|
||||
agent.Tools.Register(tools.NewSerialTool())
|
||||
}
|
||||
|
||||
// Message tool
|
||||
if cfg.Tools.IsToolEnabled("message") {
|
||||
|
||||
@@ -823,6 +823,7 @@ type ToolsConfig struct {
|
||||
ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"`
|
||||
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"`
|
||||
Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
@@ -1548,6 +1549,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.Message.Enabled
|
||||
case "read_file":
|
||||
return t.ReadFile.Enabled
|
||||
case "serial":
|
||||
return t.Serial.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spawn_status":
|
||||
|
||||
@@ -435,6 +435,9 @@ func DefaultConfig() *Config {
|
||||
Mode: ReadFileModeBytes,
|
||||
MaxReadFileSize: 64 * 1024, // 64KB
|
||||
},
|
||||
Serial: ToolConfig{
|
||||
Enabled: false, // Hardware tool - requires host serial ports
|
||||
},
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
|
||||
@@ -9,6 +9,9 @@ func TestFacadeConstructorsRemainAvailable(t *testing.T) {
|
||||
if NewSPITool() == nil {
|
||||
t.Fatal("NewSPITool should return a tool")
|
||||
}
|
||||
if NewSerialTool() == nil {
|
||||
t.Fatal("NewSerialTool should return a tool")
|
||||
}
|
||||
if NewMessageTool() == nil {
|
||||
t.Fatal("NewMessageTool should return a tool")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,453 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSerialBaud = 115200
|
||||
defaultSerialDataBits = 8
|
||||
defaultSerialStopBits = 1
|
||||
defaultSerialTimeoutMS = 1000
|
||||
maxSerialPayloadBytes = 4096
|
||||
maxSerialReadBytes = 4096
|
||||
serialPollInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
unixSerialPortPattern = regexp.MustCompile(
|
||||
`^(?:/dev/)?(?:ttyS\d+|ttyUSB\d+|ttyACM\d+|ttyAMA\d+|rfcomm\d+|tty\.[A-Za-z0-9._-]+|cu\.[A-Za-z0-9._-]+)$`,
|
||||
)
|
||||
windowsSerialPortPattern = regexp.MustCompile(`^(?:\\\\\.\\)?COM[1-9]\d*$`)
|
||||
unixSerialBaudRates = map[int]struct{}{
|
||||
50: {}, 75: {}, 110: {}, 134: {}, 150: {}, 200: {}, 300: {}, 600: {}, 1200: {}, 1800: {},
|
||||
2400: {}, 4800: {}, 9600: {}, 19200: {}, 38400: {}, 57600: {}, 115200: {}, 230400: {},
|
||||
}
|
||||
)
|
||||
|
||||
type SerialTool struct{}
|
||||
|
||||
type serialPortInfo struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type serialConfig struct {
|
||||
Port string
|
||||
Baud int
|
||||
DataBits int
|
||||
Parity string
|
||||
StopBits int
|
||||
}
|
||||
|
||||
func NewSerialTool() *SerialTool {
|
||||
return &SerialTool{}
|
||||
}
|
||||
|
||||
func (t *SerialTool) Name() string {
|
||||
return "serial"
|
||||
}
|
||||
|
||||
func (t *SerialTool) Description() string {
|
||||
return "Interact with host serial ports. Actions: list (enumerate ports), read (receive bytes), write (send bytes with explicit confirmation). Supports Linux, macOS, and Windows."
|
||||
}
|
||||
|
||||
func (t *SerialTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"list", "read", "write"},
|
||||
"description": "Action to perform: list available serial ports, read bytes from a port, or write bytes to a port.",
|
||||
},
|
||||
"port": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Serial port path or name, for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3. Required for read/write.",
|
||||
},
|
||||
"baud": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Baud rate. Default: 115200. Linux/macOS currently support standard termios rates up to 230400; Windows accepts configured rates up to 4000000.",
|
||||
},
|
||||
"data_bits": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Data bits. Supported values: 5, 6, 7, 8. Default: 8.",
|
||||
},
|
||||
"parity": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"none", "even", "odd"},
|
||||
"description": "Parity mode. Default: none.",
|
||||
},
|
||||
"stop_bits": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Stop bits. Supported values: 1, 2. Default: 1.",
|
||||
},
|
||||
"timeout_ms": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Read/write timeout in milliseconds. Default: 1000.",
|
||||
},
|
||||
"length": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Number of bytes to read. Required for read. Range: 1-4096.",
|
||||
},
|
||||
"data": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"type": "integer"},
|
||||
"description": "Bytes to write, each in range 0-255. Required for write unless text is provided.",
|
||||
},
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "UTF-8 text to write. Required for write if data is omitted.",
|
||||
},
|
||||
"confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Must be true for write operations.",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SerialTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok || strings.TrimSpace(action) == "" {
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
return t.list()
|
||||
case "read":
|
||||
return t.read(ctx, args)
|
||||
case "write":
|
||||
return t.write(ctx, args)
|
||||
default:
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, read, write)", action))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SerialTool) list() *ToolResult {
|
||||
ports, err := serialListPorts()
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to list serial ports: %v", err))
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
return SilentResult("No serial ports found on this host.")
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"ports": ports,
|
||||
"count": len(ports),
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
func (t *SerialTool) read(ctx context.Context, args map[string]any) *ToolResult {
|
||||
cfg, errResult := parseSerialConfig(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
length := 0
|
||||
if v, ok := args["length"].(float64); ok {
|
||||
length = int(v)
|
||||
}
|
||||
if length < 1 || length > maxSerialReadBytes {
|
||||
return ErrorResult(fmt.Sprintf("length is required for read (1-%d)", maxSerialReadBytes))
|
||||
}
|
||||
|
||||
timeout, errResult := parseSerialTimeout(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
data, err := serialRead(ctx, cfg, length, timeout)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("serial read failed on %s: %v", cfg.Port, err))
|
||||
}
|
||||
|
||||
return SilentResult(formatSerialPayload("read", cfg, data, timeout))
|
||||
}
|
||||
|
||||
func (t *SerialTool) write(ctx context.Context, args map[string]any) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult(
|
||||
"write operations require confirm: true. Please confirm with the user before sending bytes to a serial device.",
|
||||
)
|
||||
}
|
||||
|
||||
cfg, errResult := parseSerialConfig(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
timeout, errResult := parseSerialTimeout(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
payload, errResult := parseSerialWritePayload(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
written, err := serialWrite(ctx, cfg, payload, timeout)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("serial write failed on %s: %v", cfg.Port, err))
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"action": "write",
|
||||
"port": cfg.Port,
|
||||
"baud": cfg.Baud,
|
||||
"data_bits": cfg.DataBits,
|
||||
"parity": cfg.Parity,
|
||||
"stop_bits": cfg.StopBits,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
"written": written,
|
||||
"payload": serialPayloadSummary(payload),
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) {
|
||||
port, ok := args["port"].(string)
|
||||
port = strings.TrimSpace(port)
|
||||
if !ok || port == "" {
|
||||
return serialConfig{}, ErrorResult(
|
||||
"port is required (for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3)",
|
||||
)
|
||||
}
|
||||
|
||||
normalizedPort, err := normalizeSerialPort(port)
|
||||
if err != nil {
|
||||
return serialConfig{}, ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
cfg := serialConfig{
|
||||
Port: normalizedPort,
|
||||
Baud: defaultSerialBaud,
|
||||
DataBits: defaultSerialDataBits,
|
||||
Parity: "none",
|
||||
StopBits: defaultSerialStopBits,
|
||||
}
|
||||
|
||||
if v, ok := args["baud"].(float64); ok {
|
||||
cfg.Baud = int(v)
|
||||
}
|
||||
if err := validateSerialBaud(cfg.Baud); err != nil {
|
||||
return serialConfig{}, ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if v, ok := args["data_bits"].(float64); ok {
|
||||
cfg.DataBits = int(v)
|
||||
}
|
||||
switch cfg.DataBits {
|
||||
case 5, 6, 7, 8:
|
||||
default:
|
||||
return serialConfig{}, ErrorResult("data_bits must be one of 5, 6, 7, or 8")
|
||||
}
|
||||
|
||||
if v, ok := args["parity"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
cfg.Parity = strings.ToLower(strings.TrimSpace(v))
|
||||
}
|
||||
switch cfg.Parity {
|
||||
case "none", "even", "odd":
|
||||
default:
|
||||
return serialConfig{}, ErrorResult(`parity must be one of "none", "even", or "odd"`)
|
||||
}
|
||||
|
||||
if v, ok := args["stop_bits"].(float64); ok {
|
||||
cfg.StopBits = int(v)
|
||||
}
|
||||
if cfg.StopBits != 1 && cfg.StopBits != 2 {
|
||||
return serialConfig{}, ErrorResult("stop_bits must be 1 or 2")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func parseSerialTimeout(args map[string]any) (time.Duration, *ToolResult) {
|
||||
timeoutMS := defaultSerialTimeoutMS
|
||||
if v, ok := args["timeout_ms"].(float64); ok {
|
||||
timeoutMS = int(v)
|
||||
}
|
||||
if timeoutMS < 1 || timeoutMS > 60000 {
|
||||
return 0, ErrorResult("timeout_ms must be between 1 and 60000")
|
||||
}
|
||||
return time.Duration(timeoutMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func parseSerialWritePayload(args map[string]any) ([]byte, *ToolResult) {
|
||||
if text, ok := args["text"].(string); ok && text != "" {
|
||||
if !utf8.ValidString(text) {
|
||||
return nil, ErrorResult("text must be valid UTF-8")
|
||||
}
|
||||
if len(text) > maxSerialPayloadBytes {
|
||||
return nil, ErrorResult(fmt.Sprintf("text payload too large: maximum %d bytes", maxSerialPayloadBytes))
|
||||
}
|
||||
return []byte(text), nil
|
||||
}
|
||||
|
||||
dataRaw, ok := args["data"].([]any)
|
||||
if !ok || len(dataRaw) == 0 {
|
||||
return nil, ErrorResult("write requires either text or data")
|
||||
}
|
||||
if len(dataRaw) > maxSerialPayloadBytes {
|
||||
return nil, ErrorResult(fmt.Sprintf("data too long: maximum %d bytes", maxSerialPayloadBytes))
|
||||
}
|
||||
|
||||
data := make([]byte, len(dataRaw))
|
||||
for i, v := range dataRaw {
|
||||
f, ok := v.(float64)
|
||||
if !ok {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i))
|
||||
}
|
||||
if f != math.Trunc(f) {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] is not an integer byte value", i))
|
||||
}
|
||||
b := int(f)
|
||||
if b < 0 || b > 255 {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b))
|
||||
}
|
||||
data[i] = byte(b)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func formatSerialPayload(action string, cfg serialConfig, data []byte, timeout time.Duration) string {
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"action": action,
|
||||
"port": cfg.Port,
|
||||
"baud": cfg.Baud,
|
||||
"data_bits": cfg.DataBits,
|
||||
"parity": cfg.Parity,
|
||||
"stop_bits": cfg.StopBits,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
"payload": serialPayloadSummary(data),
|
||||
}, "", " ")
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func serialPayloadSummary(data []byte) map[string]any {
|
||||
hexValues := make([]string, len(data))
|
||||
intValues := make([]int, len(data))
|
||||
for i, b := range data {
|
||||
hexValues[i] = fmt.Sprintf("0x%02x", b)
|
||||
intValues[i] = int(b)
|
||||
}
|
||||
|
||||
summary := map[string]any{
|
||||
"length": len(data),
|
||||
"bytes": intValues,
|
||||
"hex": hexValues,
|
||||
}
|
||||
if utf8.Valid(data) {
|
||||
summary["text"] = string(data)
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func normalizeSerialPort(port string) (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return normalizeWindowsSerialPath(port)
|
||||
case "linux", "darwin":
|
||||
return normalizeUnixSerialPath(port)
|
||||
default:
|
||||
if normalized, err := normalizeUnixSerialPath(port); err == nil {
|
||||
return normalized, nil
|
||||
}
|
||||
return normalizeWindowsSerialPath(port)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeUnixSerialPath(port string) (string, error) {
|
||||
trimmed := strings.TrimSpace(port)
|
||||
if !unixSerialPortPattern.MatchString(trimmed) {
|
||||
return "", fmt.Errorf(
|
||||
"invalid serial port: expected a safe Unix device name such as /dev/ttyUSB0 or /dev/cu.usbserial-0001",
|
||||
)
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "/dev/") {
|
||||
return trimmed, nil
|
||||
}
|
||||
return "/dev/" + trimmed, nil
|
||||
}
|
||||
|
||||
func normalizeWindowsSerialPath(port string) (string, error) {
|
||||
trimmed := strings.ToUpper(strings.TrimSpace(port))
|
||||
if !windowsSerialPortPattern.MatchString(trimmed) {
|
||||
return "", fmt.Errorf("invalid serial port: expected a COM port such as COM3")
|
||||
}
|
||||
if strings.HasPrefix(trimmed, `\\.\`) {
|
||||
return trimmed, nil
|
||||
}
|
||||
return `\\.\` + trimmed, nil
|
||||
}
|
||||
|
||||
func validateSerialBaud(baud int) error {
|
||||
if baud < 50 || baud > 4000000 {
|
||||
return fmt.Errorf("baud must be between 50 and 4000000")
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "darwin":
|
||||
if _, ok := unixSerialBaudRates[baud]; !ok {
|
||||
return fmt.Errorf("unsupported baud rate on this platform: %d (supported up to 230400)", baud)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialContextErr(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func serialWriteAll(
|
||||
ctx context.Context,
|
||||
data []byte,
|
||||
timeout time.Duration,
|
||||
now func() time.Time,
|
||||
write func([]byte) (int, error),
|
||||
) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
deadline := now().Add(timeout)
|
||||
for total < len(data) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return total, err
|
||||
}
|
||||
if deadline.Sub(now()) <= 0 {
|
||||
return total, fmt.Errorf("timeout while writing serial data")
|
||||
}
|
||||
|
||||
n, err := write(data[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
//go:build darwin
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func serialGetTermios(fd int) (*unix.Termios, error) {
|
||||
return unix.IoctlGetTermios(fd, unix.TIOCGETA)
|
||||
}
|
||||
|
||||
func serialSetSpeed(tio *unix.Termios, speed uint32) error {
|
||||
tio.Ispeed = uint64(speed)
|
||||
tio.Ospeed = uint64(speed)
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialSetTermios(fd int, tio *unix.Termios) error {
|
||||
return unix.IoctlSetTermios(fd, unix.TIOCSETA, tio)
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
//go:build linux
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func serialGetTermios(fd int) (*unix.Termios, error) {
|
||||
return unix.IoctlGetTermios(fd, unix.TCGETS)
|
||||
}
|
||||
|
||||
func serialSetSpeed(tio *unix.Termios, speed uint32) error {
|
||||
tio.Ispeed = speed
|
||||
tio.Ospeed = speed
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialSetTermios(fd int, tio *unix.Termios) error {
|
||||
return unix.IoctlSetTermios(fd, unix.TCSETS, tio)
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
//go:build !linux && !darwin && !windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
return nil, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
return nil, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
return 0, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//go:build !linux && !darwin && !windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSerialListPortsUnsupportedPlatform(t *testing.T) {
|
||||
_, err := serialListPorts()
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported platform error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not supported") {
|
||||
t.Fatalf("serialListPorts() error = %v, want unsupported platform message", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseSerialConfig(t *testing.T) {
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
cfg, errResult := parseSerialConfig(map[string]any{
|
||||
"port": port,
|
||||
"baud": float64(9600),
|
||||
"data_bits": float64(7),
|
||||
"parity": "even",
|
||||
"stop_bits": float64(2),
|
||||
})
|
||||
if errResult != nil {
|
||||
t.Fatalf("parseSerialConfig() unexpected error = %v", errResult.ForLLM)
|
||||
}
|
||||
|
||||
wantPort := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
wantPort = `\\.\COM3`
|
||||
}
|
||||
if cfg.Port != wantPort || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 {
|
||||
t.Fatalf("parseSerialConfig() = %#v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsInvalidParity(t *testing.T) {
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": port,
|
||||
"parity": "mark",
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected invalid parity to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsUnsupportedUnixBaud(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
||||
t.Skip("Unix baud validation only applies on Unix platforms")
|
||||
}
|
||||
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": "/dev/ttyUSB0",
|
||||
"baud": float64(460800),
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected unsupported Unix baud rate to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialWritePayloadRejectsFractionalBytes(t *testing.T) {
|
||||
_, errResult := parseSerialWritePayload(map[string]any{
|
||||
"data": []any{65.9},
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected fractional byte value to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSerialBaud(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baud int
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "default-supported", baud: 115200},
|
||||
{name: "max-unix-supported", baud: 230400},
|
||||
{name: "too-low", baud: 49, wantErr: true},
|
||||
{name: "too-high", baud: 4000001, wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateSerialBaud(tt.baud)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("validateSerialBaud(%d) error = %v, wantErr %v", tt.baud, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialReadCanceledBeforeOpen(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, err := serialRead(
|
||||
ctx,
|
||||
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
|
||||
1,
|
||||
time.Second,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
||||
t.Fatalf("serialRead() error = %v, want context canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteCanceledBeforeOpen(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, err := serialWrite(
|
||||
ctx,
|
||||
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
|
||||
[]byte("AT"),
|
||||
time.Second,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
||||
t.Fatalf("serialWrite() error = %v, want context canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) {
|
||||
tests := []string{
|
||||
"../../../etc/passwd",
|
||||
"/etc/passwd",
|
||||
`C:\temp\device.txt`,
|
||||
`\\.\C:\temp\device.txt`,
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": port,
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatalf("expected unsafe port %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUnixSerialPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
want string
|
||||
}{
|
||||
{port: "ttyUSB0", want: "/dev/ttyUSB0"},
|
||||
{port: "/dev/ttyACM0", want: "/dev/ttyACM0"},
|
||||
{port: "/dev/cu.usbserial-0001", want: "/dev/cu.usbserial-0001"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := normalizeUnixSerialPath(tt.port)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeUnixSerialPath(%q) unexpected error = %v", tt.port, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("normalizeUnixSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUnixSerialPathRejectsInvalidNames(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"ttyUSB0/../../passwd",
|
||||
"/dev/../../etc/passwd",
|
||||
"/tmp/ttyUSB0",
|
||||
"ttyUSB",
|
||||
"COM3",
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
|
||||
if _, err := normalizeUnixSerialPath(port); err == nil {
|
||||
t.Fatalf("expected %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWindowsSerialPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
want string
|
||||
}{
|
||||
{port: "COM3", want: `\\.\COM3`},
|
||||
{port: "com12", want: `\\.\COM12`},
|
||||
{port: `\\.\COM7`, want: `\\.\COM7`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := normalizeWindowsSerialPath(tt.port)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeWindowsSerialPath(%q) unexpected error = %v", tt.port, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("normalizeWindowsSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWindowsSerialPathRejectsInvalidNames(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"COM0",
|
||||
"COM",
|
||||
"/dev/ttyUSB0",
|
||||
`C:\temp\device.txt`,
|
||||
`\\.\C:\temp\device.txt`,
|
||||
`\\server\share\COM3`,
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(strings.ReplaceAll(port, `\`, "_"), "/", "_"), func(t *testing.T) {
|
||||
if _, err := normalizeWindowsSerialPath(port); err == nil {
|
||||
t.Fatalf("expected %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialTimeout(t *testing.T) {
|
||||
timeout, errResult := parseSerialTimeout(map[string]any{
|
||||
"timeout_ms": float64(2500),
|
||||
})
|
||||
if errResult != nil {
|
||||
t.Fatalf("parseSerialTimeout() unexpected error = %v", errResult.ForLLM)
|
||||
}
|
||||
if timeout != 2500*time.Millisecond {
|
||||
t.Fatalf("timeout = %v, want 2500ms", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialWritePayloadSupportsText(t *testing.T) {
|
||||
data, errResult := parseSerialWritePayload(map[string]any{
|
||||
"text": "AT\r\n",
|
||||
})
|
||||
if errResult != nil {
|
||||
t.Fatalf("parseSerialWritePayload() unexpected error = %v", errResult.ForLLM)
|
||||
}
|
||||
if string(data) != "AT\r\n" {
|
||||
t.Fatalf("payload = %q, want %q", string(data), "AT\r\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialWritePayloadRejectsOutOfRangeByte(t *testing.T) {
|
||||
_, errResult := parseSerialWritePayload(map[string]any{
|
||||
"data": []any{float64(256)},
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected payload validation failure")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
unixSerialNow = time.Now
|
||||
unixSerialOpenPort = openAndConfigureSerialPort
|
||||
unixSerialClosePort = unix.Close
|
||||
unixSerialPollRead = pollRead
|
||||
unixSerialPollWrite = pollWrite
|
||||
)
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
patterns := []string{
|
||||
"/dev/ttyS*",
|
||||
"/dev/ttyUSB*",
|
||||
"/dev/ttyACM*",
|
||||
"/dev/ttyAMA*",
|
||||
"/dev/rfcomm*",
|
||||
"/dev/tty.*",
|
||||
"/dev/cu.*",
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
ports := make([]serialPortInfo, 0)
|
||||
for _, pattern := range patterns {
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, match := range matches {
|
||||
if _, ok := seen[match]; ok {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(match)
|
||||
if err != nil || info.IsDir() {
|
||||
continue
|
||||
}
|
||||
seen[match] = struct{}{}
|
||||
ports = append(ports, serialPortInfo{
|
||||
Name: filepath.Base(match),
|
||||
Path: match,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(ports, func(i, j int) bool {
|
||||
return ports[i].Path < ports[j].Path
|
||||
})
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fd, err := unixSerialOpenPort(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer unixSerialClosePort(fd)
|
||||
|
||||
buf := make([]byte, length)
|
||||
total := 0
|
||||
deadline := unixSerialNow().Add(timeout)
|
||||
|
||||
for total < length {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remaining := deadline.Sub(unixSerialNow())
|
||||
if remaining <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := unixSerialPollRead(fd, buf[total:], minSerialPollTimeout(remaining))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
return buf[:total], nil
|
||||
}
|
||||
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
fd, err := unixSerialOpenPort(cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer unixSerialClosePort(fd)
|
||||
|
||||
total := 0
|
||||
deadline := unixSerialNow().Add(timeout)
|
||||
for total < len(data) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
remaining := deadline.Sub(unixSerialNow())
|
||||
if remaining <= 0 {
|
||||
return total, fmt.Errorf("timeout while writing serial data")
|
||||
}
|
||||
|
||||
n, err := unixSerialPollWrite(fd, data[total:], minSerialPollTimeout(remaining))
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func openAndConfigureSerialPort(cfg serialConfig) (int, error) {
|
||||
fd, err := unix.Open(cfg.Port, unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if err := unix.SetNonblock(fd, false); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if err := configureUnixSerialPort(fd, cfg); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, err
|
||||
}
|
||||
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func configureUnixSerialPort(fd int, cfg serialConfig) error {
|
||||
tio, err := serialGetTermios(fd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tio.Iflag = 0
|
||||
tio.Oflag = 0
|
||||
tio.Lflag = 0
|
||||
tio.Cflag = unix.CREAD | unix.CLOCAL
|
||||
tio.Cc[unix.VMIN] = 0
|
||||
tio.Cc[unix.VTIME] = 0
|
||||
|
||||
switch cfg.DataBits {
|
||||
case 5:
|
||||
tio.Cflag |= unix.CS5
|
||||
case 6:
|
||||
tio.Cflag |= unix.CS6
|
||||
case 7:
|
||||
tio.Cflag |= unix.CS7
|
||||
default:
|
||||
tio.Cflag |= unix.CS8
|
||||
}
|
||||
|
||||
switch cfg.Parity {
|
||||
case "even":
|
||||
tio.Cflag |= unix.PARENB
|
||||
case "odd":
|
||||
tio.Cflag |= unix.PARENB | unix.PARODD
|
||||
}
|
||||
|
||||
if cfg.StopBits == 2 {
|
||||
tio.Cflag |= unix.CSTOPB
|
||||
}
|
||||
|
||||
speed, err := serialBaudToUnix(cfg.Baud)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := serialSetSpeed(tio, speed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return serialSetTermios(fd, tio)
|
||||
}
|
||||
|
||||
func serialBaudToUnix(baud int) (uint32, error) {
|
||||
switch baud {
|
||||
case 50:
|
||||
return unix.B50, nil
|
||||
case 75:
|
||||
return unix.B75, nil
|
||||
case 110:
|
||||
return unix.B110, nil
|
||||
case 134:
|
||||
return unix.B134, nil
|
||||
case 150:
|
||||
return unix.B150, nil
|
||||
case 200:
|
||||
return unix.B200, nil
|
||||
case 300:
|
||||
return unix.B300, nil
|
||||
case 600:
|
||||
return unix.B600, nil
|
||||
case 1200:
|
||||
return unix.B1200, nil
|
||||
case 1800:
|
||||
return unix.B1800, nil
|
||||
case 2400:
|
||||
return unix.B2400, nil
|
||||
case 4800:
|
||||
return unix.B4800, nil
|
||||
case 9600:
|
||||
return unix.B9600, nil
|
||||
case 19200:
|
||||
return unix.B19200, nil
|
||||
case 38400:
|
||||
return unix.B38400, nil
|
||||
case 57600:
|
||||
return unix.B57600, nil
|
||||
case 115200:
|
||||
return unix.B115200, nil
|
||||
case 230400:
|
||||
return unix.B230400, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported baud rate on this platform: %d", baud)
|
||||
}
|
||||
}
|
||||
|
||||
func pollRead(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
|
||||
n, err := unix.Poll(pfd, durationToPollTimeout(timeout))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return unix.Read(fd, dst)
|
||||
}
|
||||
|
||||
func pollWrite(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}}
|
||||
n, err := unix.Poll(pfd, durationToPollTimeout(timeout))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return unix.Write(fd, src)
|
||||
}
|
||||
|
||||
func durationToPollTimeout(timeout time.Duration) int {
|
||||
if timeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
ms := int(timeout / time.Millisecond)
|
||||
if ms == 0 {
|
||||
return 1
|
||||
}
|
||||
return ms
|
||||
}
|
||||
|
||||
func minSerialPollTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout > serialPollInterval {
|
||||
return serialPollInterval
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func stubUnixSerialIO(t *testing.T, now *time.Time) {
|
||||
t.Helper()
|
||||
|
||||
prevNow := unixSerialNow
|
||||
prevOpen := unixSerialOpenPort
|
||||
prevClose := unixSerialClosePort
|
||||
prevPollRead := unixSerialPollRead
|
||||
prevPollWrite := unixSerialPollWrite
|
||||
|
||||
unixSerialNow = func() time.Time {
|
||||
return *now
|
||||
}
|
||||
unixSerialOpenPort = func(cfg serialConfig) (int, error) {
|
||||
return 42, nil
|
||||
}
|
||||
unixSerialClosePort = func(fd int) error {
|
||||
return nil
|
||||
}
|
||||
unixSerialPollRead = prevPollRead
|
||||
unixSerialPollWrite = prevPollWrite
|
||||
|
||||
t.Cleanup(func() {
|
||||
unixSerialNow = prevNow
|
||||
unixSerialOpenPort = prevOpen
|
||||
unixSerialClosePort = prevClose
|
||||
unixSerialPollRead = prevPollRead
|
||||
unixSerialPollWrite = prevPollWrite
|
||||
})
|
||||
}
|
||||
|
||||
func TestSerialReadWaitsPastEmptyPollsUntilDeadline(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
pollCalls := 0
|
||||
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
if timeout > serialPollInterval {
|
||||
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
|
||||
}
|
||||
now = now.Add(timeout)
|
||||
if pollCalls < 4 {
|
||||
return 0, nil
|
||||
}
|
||||
return copy(dst, []byte("OK")), nil
|
||||
}
|
||||
|
||||
got, err := serialRead(context.Background(), serialConfig{}, 2, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("serialRead() error = %v", err)
|
||||
}
|
||||
if string(got) != "OK" {
|
||||
t.Fatalf("serialRead() = %q, want %q", got, "OK")
|
||||
}
|
||||
if pollCalls != 4 {
|
||||
t.Fatalf("poll calls = %d, want 4", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialReadReturnsPromptlyOnContextCancelBetweenPolls(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pollCalls := 0
|
||||
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
now = now.Add(timeout)
|
||||
cancel()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
_, err := serialRead(ctx, serialConfig{}, 1, time.Second)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("serialRead() error = %v, want context canceled", err)
|
||||
}
|
||||
if pollCalls != 1 {
|
||||
t.Fatalf("poll calls = %d, want 1", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteWaitsPastEmptyPollsUntilReady(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
pollCalls := 0
|
||||
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
if timeout > serialPollInterval {
|
||||
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
|
||||
}
|
||||
now = now.Add(timeout)
|
||||
switch pollCalls {
|
||||
case 1, 2:
|
||||
return 0, nil
|
||||
default:
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
written, err := serialWrite(context.Background(), serialConfig{}, []byte("OK"), 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("serialWrite() error = %v", err)
|
||||
}
|
||||
if written != 2 {
|
||||
t.Fatalf("serialWrite() wrote %d bytes, want 2", written)
|
||||
}
|
||||
if pollCalls != 4 {
|
||||
t.Fatalf("poll calls = %d, want 4", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteTimesOutAfterRepeatedEmptyPolls(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
now = now.Add(timeout)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
written, err := serialWrite(context.Background(), serialConfig{}, []byte("A"), 250*time.Millisecond)
|
||||
if err == nil || err.Error() != "timeout while writing serial data" {
|
||||
t.Fatalf("serialWrite() error = %v, want timeout", err)
|
||||
}
|
||||
if written != 0 {
|
||||
t.Fatalf("serialWrite() wrote %d bytes, want 0", written)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
//go:build windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
procGetCommState = kernel32.NewProc("GetCommState")
|
||||
procSetCommState = kernel32.NewProc("SetCommState")
|
||||
procSetCommTimeouts = kernel32.NewProc("SetCommTimeouts")
|
||||
procPurgeComm = kernel32.NewProc("PurgeComm")
|
||||
)
|
||||
|
||||
const (
|
||||
purgeTxClear = 0x0004
|
||||
purgeRxClear = 0x0008
|
||||
|
||||
dcbFlagBinary = 0x00000001
|
||||
dcbFlagParity = 0x00000002
|
||||
dcbFlagOutxCtsFlow = 0x00000004
|
||||
dcbFlagOutxDsrFlow = 0x00000008
|
||||
dcbFlagDtrControlMask = 0x00000030
|
||||
dcbFlagDsrSensitivity = 0x00000040
|
||||
dcbFlagTXContinueOnXoff = 0x00000080
|
||||
dcbFlagOutX = 0x00000100
|
||||
dcbFlagInX = 0x00000200
|
||||
dcbFlagRtsControlMask = 0x00003000
|
||||
)
|
||||
|
||||
type dcb struct {
|
||||
DCBlength uint32
|
||||
BaudRate uint32
|
||||
Flags uint32
|
||||
Reserved uint16
|
||||
XonLim uint16
|
||||
XoffLim uint16
|
||||
ByteSize byte
|
||||
Parity byte
|
||||
StopBits byte
|
||||
XonChar byte
|
||||
XoffChar byte
|
||||
ErrorChar byte
|
||||
EofChar byte
|
||||
EvtChar byte
|
||||
wReserved1 uint16
|
||||
}
|
||||
|
||||
type commTimeouts struct {
|
||||
ReadIntervalTimeout uint32
|
||||
ReadTotalTimeoutMultiplier uint32
|
||||
ReadTotalTimeoutConstant uint32
|
||||
WriteTotalTimeoutMultiplier uint32
|
||||
WriteTotalTimeoutConstant uint32
|
||||
}
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, `HARDWARE\DEVICEMAP\SERIALCOMM`, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
names, err := key.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ports := make([]serialPortInfo, 0, len(names))
|
||||
seen := make(map[string]struct{})
|
||||
for _, name := range names {
|
||||
value, _, err := key.GetStringValue(name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
portName := strings.TrimSpace(value)
|
||||
if portName == "" {
|
||||
continue
|
||||
}
|
||||
normalized := strings.ToUpper(portName)
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
ports = append(ports, serialPortInfo{
|
||||
Name: normalized,
|
||||
Path: normalized,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(ports, func(i, j int) bool {
|
||||
return ports[i].Path < ports[j].Path
|
||||
})
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer windows.CloseHandle(handle)
|
||||
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := make([]byte, length)
|
||||
var read uint32
|
||||
// Synchronous serial I/O on Windows cannot be interrupted once the syscall starts.
|
||||
// COMMTIMEOUTS bounds how long turn cancellation may take to surface.
|
||||
if err := windows.ReadFile(handle, buf, &read, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf[:read], nil
|
||||
}
|
||||
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer windows.CloseHandle(handle)
|
||||
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return serialWriteAll(ctx, data, timeout, time.Now, func(chunk []byte) (int, error) {
|
||||
var written uint32
|
||||
// Like ReadFile above, this synchronous WriteFile call relies on COMMTIMEOUTS
|
||||
// rather than context preemption once the syscall is in flight.
|
||||
if err := windows.WriteFile(handle, chunk, &written, nil); err != nil {
|
||||
return int(written), err
|
||||
}
|
||||
return int(written), nil
|
||||
})
|
||||
}
|
||||
|
||||
func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
windows.StringToUTF16Ptr(cfg.Port),
|
||||
windows.GENERIC_READ|windows.GENERIC_WRITE,
|
||||
0,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := configureWindowsSerialPort(handle, cfg, timeout); err != nil {
|
||||
windows.CloseHandle(handle)
|
||||
return 0, err
|
||||
}
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout time.Duration) error {
|
||||
state := dcb{DCBlength: uint32(unsafe.Sizeof(dcb{}))}
|
||||
r1, _, err := procGetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state)))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
state.BaudRate = uint32(cfg.Baud)
|
||||
state.ByteSize = byte(cfg.DataBits)
|
||||
state.Flags = sanitizeWindowsSerialFlags(state.Flags)
|
||||
state.Flags |= dcbFlagBinary
|
||||
|
||||
switch cfg.Parity {
|
||||
case "even":
|
||||
state.Parity = 2
|
||||
state.Flags |= dcbFlagParity
|
||||
case "odd":
|
||||
state.Parity = 1
|
||||
state.Flags |= dcbFlagParity
|
||||
default:
|
||||
state.Parity = 0
|
||||
state.Flags &^= dcbFlagParity
|
||||
}
|
||||
|
||||
switch cfg.StopBits {
|
||||
case 2:
|
||||
state.StopBits = 2
|
||||
default:
|
||||
state.StopBits = 0
|
||||
}
|
||||
|
||||
r1, _, err = procSetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state)))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
timeoutMS := uint32(timeout / time.Millisecond)
|
||||
if timeoutMS == 0 {
|
||||
timeoutMS = 1
|
||||
}
|
||||
timeouts := commTimeouts{
|
||||
ReadIntervalTimeout: timeoutMS,
|
||||
ReadTotalTimeoutConstant: timeoutMS,
|
||||
WriteTotalTimeoutConstant: timeoutMS,
|
||||
ReadTotalTimeoutMultiplier: 0,
|
||||
WriteTotalTimeoutMultiplier: 0,
|
||||
}
|
||||
r1, _, err = procSetCommTimeouts.Call(uintptr(handle), uintptr(unsafe.Pointer(&timeouts)))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear))
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitizeWindowsSerialFlags(flags uint32) uint32 {
|
||||
flags &^= dcbFlagOutxCtsFlow |
|
||||
dcbFlagOutxDsrFlow |
|
||||
dcbFlagDtrControlMask |
|
||||
dcbFlagDsrSensitivity |
|
||||
dcbFlagTXContinueOnXoff |
|
||||
dcbFlagOutX |
|
||||
dcbFlagInX |
|
||||
dcbFlagRtsControlMask
|
||||
return flags
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
//go:build windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeWindowsSerialFlags(t *testing.T) {
|
||||
flags := uint32(
|
||||
dcbFlagBinary |
|
||||
dcbFlagParity |
|
||||
dcbFlagOutxCtsFlow |
|
||||
dcbFlagOutxDsrFlow |
|
||||
dcbFlagDtrControlMask |
|
||||
dcbFlagDsrSensitivity |
|
||||
dcbFlagTXContinueOnXoff |
|
||||
dcbFlagOutX |
|
||||
dcbFlagInX |
|
||||
dcbFlagRtsControlMask,
|
||||
)
|
||||
|
||||
got := sanitizeWindowsSerialFlags(flags)
|
||||
|
||||
if got&dcbFlagBinary == 0 {
|
||||
t.Fatal("sanitizeWindowsSerialFlags() should preserve fBinary")
|
||||
}
|
||||
if got&dcbFlagParity == 0 {
|
||||
t.Fatal("sanitizeWindowsSerialFlags() should preserve fParity")
|
||||
}
|
||||
if got&(dcbFlagOutxCtsFlow|
|
||||
dcbFlagOutxDsrFlow|
|
||||
dcbFlagDtrControlMask|
|
||||
dcbFlagDsrSensitivity|
|
||||
dcbFlagTXContinueOnXoff|
|
||||
dcbFlagOutX|
|
||||
dcbFlagInX|
|
||||
dcbFlagRtsControlMask) != 0 {
|
||||
t.Fatalf("sanitizeWindowsSerialFlags() = %#x, want flow-control bits cleared", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSerialWriteAllRetriesPartialWritesUntilComplete(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
calls := 0
|
||||
|
||||
written, err := serialWriteAll(context.Background(), []byte("PING"), time.Second, func() time.Time {
|
||||
return now
|
||||
}, func(chunk []byte) (int, error) {
|
||||
calls++
|
||||
now = now.Add(100 * time.Millisecond)
|
||||
switch calls {
|
||||
case 1:
|
||||
if string(chunk) != "PING" {
|
||||
t.Fatalf("first chunk = %q, want %q", chunk, "PING")
|
||||
}
|
||||
return 2, nil
|
||||
case 2:
|
||||
if string(chunk) != "NG" {
|
||||
t.Fatalf("second chunk = %q, want %q", chunk, "NG")
|
||||
}
|
||||
return 2, nil
|
||||
default:
|
||||
t.Fatalf("unexpected extra write call %d", calls)
|
||||
return 0, nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("serialWriteAll() error = %v", err)
|
||||
}
|
||||
if written != 4 {
|
||||
t.Fatalf("serialWriteAll() wrote %d bytes, want 4", written)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteAllTimesOutAfterZeroByteWrites(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
calls := 0
|
||||
|
||||
written, err := serialWriteAll(context.Background(), []byte("A"), 250*time.Millisecond, func() time.Time {
|
||||
return now
|
||||
}, func(chunk []byte) (int, error) {
|
||||
calls++
|
||||
now = now.Add(100 * time.Millisecond)
|
||||
return 0, nil
|
||||
})
|
||||
if err == nil || err.Error() != "timeout while writing serial data" {
|
||||
t.Fatalf("serialWriteAll() error = %v, want timeout", err)
|
||||
}
|
||||
if written != 0 {
|
||||
t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written)
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Fatalf("write calls = %d, want 3", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteAllReturnsContextCancellationAfterRetryBoundary(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
calls := 0
|
||||
|
||||
written, err := serialWriteAll(ctx, []byte("A"), time.Second, func() time.Time {
|
||||
return now
|
||||
}, func(chunk []byte) (int, error) {
|
||||
calls++
|
||||
now = now.Add(100 * time.Millisecond)
|
||||
cancel()
|
||||
return 0, nil
|
||||
})
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("serialWriteAll() error = %v, want context canceled", err)
|
||||
}
|
||||
if written != 0 {
|
||||
t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("write calls = %d, want 1", calls)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,9 @@ package tools
|
||||
import hardwaretools "github.com/sipeed/picoclaw/pkg/tools/hardware"
|
||||
|
||||
type (
|
||||
I2CTool = hardwaretools.I2CTool
|
||||
SPITool = hardwaretools.SPITool
|
||||
I2CTool = hardwaretools.I2CTool
|
||||
SerialTool = hardwaretools.SerialTool
|
||||
SPITool = hardwaretools.SPITool
|
||||
)
|
||||
|
||||
func NewI2CTool() *I2CTool {
|
||||
@@ -14,3 +15,7 @@ func NewI2CTool() *I2CTool {
|
||||
func NewSPITool() *SPITool {
|
||||
return hardwaretools.NewSPITool()
|
||||
}
|
||||
|
||||
func NewSerialTool() *SerialTool {
|
||||
return hardwaretools.NewSerialTool()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user