Add cross-platform serial tool support (#2673)

* feat(tools): add cross-platform serial hardware tool

* feat(config): wire serial tool into runtime and dashboard

* hardware/serial: tighten validation and error handling

* hardware/serial: improve unix cancellation and timeout polling

* hardware/serial: improve windows I/O handling

* hardware/serial: fix darwin cross-compilation build

* docs(design): summarize hardware support and serial limits

* build: keep go generate on host during cross builds

* onboard: drop unrelated go generate change from serial work

* style(tools): wrap serial lines for golines
This commit is contained in:
Guoguo
2026-04-28 13:10:32 +08:00
committed by GitHub
24 changed files with 1810 additions and 6 deletions
+16 -3
View File
@@ -171,6 +171,18 @@ ifeq ($(OS),Windows_NT)
EXT=.exe
endif
ifneq ($(strip $(GOOS)),)
PLATFORM:=$(GOOS)
endif
ifneq ($(strip $(GOARCH)),)
ARCH:=$(GOARCH)
endif
ifeq ($(PLATFORM),windows)
EXT=.exe
endif
BINARY_PATH=$(BUILD_DIR)/$(BINARY_NAME)-$(PLATFORM)-$(ARCH)
# Default target
@@ -181,10 +193,11 @@ generate:
@echo "Run generate..."
ifeq ($(OS),Windows_NT)
@$(POWERSHELL) "if (Test-Path -LiteralPath './$(CMD_DIR)/workspace') { Remove-Item -LiteralPath './$(CMD_DIR)/workspace' -Recurse -Force }"
@$(POWERSHELL) "$$env:GOOS=''; $$env:GOARCH=''; $(GO) generate ./..."
else
@rm -r ./$(CMD_DIR)/workspace 2>/dev/null || true
@GOOS=$$($(GO) env GOHOSTOS) GOARCH=$$($(GO) env GOHOSTARCH) $(GO) generate ./...
endif
@$(GO) generate ./...
@echo "Run generate complete"
## build: Build the picoclaw binary for current platform
@@ -196,7 +209,7 @@ ifeq ($(OS),Windows_NT)
@$(POWERSHELL) "Copy-Item -LiteralPath '$(BINARY_PATH)$(EXT)' -Destination '$(BUILD_DIR)/$(BINARY_NAME)$(EXT)' -Force"
else
@mkdir -p $(BUILD_DIR)
@GOARCH=${ARCH} $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH)$(EXT) ./$(CMD_DIR)
@GOOS=$(PLATFORM) GOARCH=$(ARCH) $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH)$(EXT) ./$(CMD_DIR)
@echo "Build complete: $(BINARY_PATH)$(EXT)"
@$(LNCMD) $(BINARY_NAME)-$(PLATFORM)-$(ARCH)$(EXT) $(BUILD_DIR)/$(BINARY_NAME)$(EXT)
endif
@@ -211,7 +224,7 @@ ifeq ($(OS),Windows_NT)
@$(POWERSHELL) "Copy-Item -LiteralPath '$(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH)$(EXT)' -Destination '$(BUILD_DIR)/picoclaw-launcher$(EXT)' -Force"
else
@mkdir -p $(BUILD_DIR)
@GOARCH=${ARCH} $(MAKE) -C web build \
@GOOS=$(PLATFORM) GOARCH=$(ARCH) $(MAKE) -C web build \
OUTPUT="$(CURDIR)/$(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH)$(EXT)" \
WEB_GO='$(WEB_GO)' \
GO_BUILD_TAGS='$(GO_BUILD_TAGS)' \
+1 -1
View File
@@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra"
)
//go:generate go run ../../../../scripts/copydir.go "${DOLLAR}{codespace}/workspace" ./workspace
//go:generate go run ../../../../scripts/copydir.go ../../../../workspace ./workspace
//go:embed workspace
var embeddedFiles embed.FS
+3
View File
@@ -437,6 +437,9 @@
"enabled": true,
"mode": "bytes"
},
"serial": {
"enabled": false
},
"send_tts": {
"enabled": false
},
@@ -0,0 +1,91 @@
# 当前硬件支持现状与串口 Tool 方案
## 现状结论
当前项目已有的硬件相关能力主要分为两条线:
1. 设备事件监控
- `pkg/devices` 已实现设备事件服务。
- 当前只有 Linux USB 热插拔事件源 `pkg/devices/sources/usb_linux.go`
- 能力定位是“发现和通知”,不是“总线读写控制”。
2. 硬件控制 Tool
- `pkg/tools/hardware/i2c*.go`I2C Tool,支持 `detect``scan``read``write`
- `pkg/tools/hardware/spi*.go`SPI Tool,支持 `list``transfer``read`
- 这两类 Tool 当前都只在 Linux 主机上启用,直接依赖 `/dev/i2c-*``/dev/spidev*`
因此,项目在“硬件支持能力”上已经具备:
- Linux USB 设备插拔感知
- Linux I2C 总线控制
- Linux SPI 总线控制
但还缺少:
- 串口/UART 控制
- macOS / Windows 下可直接使用的硬件控制 Tool
- 面向统一硬件抽象的跨总线能力模型
## 本次新增
本次新增内建 `serial` Tool,并接入现有 Tool 体系:
- 配置项:`tools.serial.enabled`
- Tool 注册:`pkg/agent/agent_init.go`
- Web 工具页:`/api/tools` 能展示与切换 `serial`
- 前端状态文案:新增 `requires_serial_platform`
## Serial Tool 设计
`serial` 采用无状态调用模型,每次请求都自行打开和关闭端口,避免在 agent 回合之间维护串口会话状态。
支持动作:
- `list`:枚举主机串口
- `read`:从串口读取指定长度字节
- `write`:向串口写入字节或文本
公共参数:
- `port`
- `baud`
- `data_bits`
- `parity`
- `stop_bits`
- `timeout_ms`
当前波特率实现边界:
- Windows 允许配置工具层接受的范围 `50-4000000`
- Linux / macOS 当前仅支持标准 termios 波特率,实际支持到 `230400`
- 因此 `baud` 的跨平台可移植取值应优先使用 `230400` 及以下的常见标准速率
安全约束:
- `write` 必须显式传 `confirm: true`
- 单次读写负载限制为 `4096` 字节
- `port` 只接受白名单串口名:
- Linux / macOS 仅允许 `/dev/tty*``/dev/cu.*` 及对应简写设备名
- Windows 仅允许 `COM\d+``\\.\COM\d+`
- 明确拒绝 `..`、普通文件绝对路径、盘符路径等非串口设备路径,避免路径穿越或误打开任意文件
## 跨平台实现边界
- Linux / macOS
- 基于 `golang.org/x/sys/unix` 和 termios 配置串口参数。
- 当前仅接入标准 termios 波特率映射,最高到 `230400`,尚未扩展 `460800``921600``1000000``2000000` 等更高速率。
- 通过 `/dev/...` 枚举和访问设备。
- Windows
- 基于 `kernel32` 串口 API 配置 `DCB``COMMTIMEOUTS`
- 当前读写仍使用同步 `ReadFile` / `WriteFile`;一旦 syscall 已进入执行,turn context cancellation 不能立即打断,只能等待 `COMMTIMEOUTS` 触发后返回。
- 通过注册表 `HARDWARE\\DEVICEMAP\\SERIALCOMM` 枚举端口。
- 其他平台:
- `serial` Tool 显式返回 unsupported,不做静默降级。
## 后续建议
1. 如果需要持续交互式串口会话,建议再增加 session 型 Tool,而不是让 LLM 反复做短连接轮询。
2. 如果后续要支持 CAN、GPIO、PWM,建议抽出统一的硬件 capability 描述层,而不是继续只靠 Tool 名称区分。
3. 若需要生产级稳定性,建议补真实串口回环测试,至少覆盖 Linux PTY 和 Windows COM 模拟场景。
+3
View File
@@ -128,6 +128,9 @@ func registerSharedTools(
if cfg.Tools.IsToolEnabled("spi") {
agent.Tools.Register(tools.NewSPITool())
}
if cfg.Tools.IsToolEnabled("serial") {
agent.Tools.Register(tools.NewSerialTool())
}
// Message tool
if cfg.Tools.IsToolEnabled("message") {
+3
View File
@@ -823,6 +823,7 @@ type ToolsConfig struct {
ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"`
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"`
Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
@@ -1548,6 +1549,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
return t.Message.Enabled
case "read_file":
return t.ReadFile.Enabled
case "serial":
return t.Serial.Enabled
case "spawn":
return t.Spawn.Enabled
case "spawn_status":
+3
View File
@@ -435,6 +435,9 @@ func DefaultConfig() *Config {
Mode: ReadFileModeBytes,
MaxReadFileSize: 64 * 1024, // 64KB
},
Serial: ToolConfig{
Enabled: false, // Hardware tool - requires host serial ports
},
Spawn: ToolConfig{
Enabled: true,
},
+3
View File
@@ -9,6 +9,9 @@ func TestFacadeConstructorsRemainAvailable(t *testing.T) {
if NewSPITool() == nil {
t.Fatal("NewSPITool should return a tool")
}
if NewSerialTool() == nil {
t.Fatal("NewSerialTool should return a tool")
}
if NewMessageTool() == nil {
t.Fatal("NewMessageTool should return a tool")
}
+453
View File
@@ -0,0 +1,453 @@
package hardwaretools
import (
"context"
"encoding/json"
"fmt"
"math"
"regexp"
"runtime"
"strings"
"time"
"unicode/utf8"
)
const (
defaultSerialBaud = 115200
defaultSerialDataBits = 8
defaultSerialStopBits = 1
defaultSerialTimeoutMS = 1000
maxSerialPayloadBytes = 4096
maxSerialReadBytes = 4096
serialPollInterval = 100 * time.Millisecond
)
var (
unixSerialPortPattern = regexp.MustCompile(
`^(?:/dev/)?(?:ttyS\d+|ttyUSB\d+|ttyACM\d+|ttyAMA\d+|rfcomm\d+|tty\.[A-Za-z0-9._-]+|cu\.[A-Za-z0-9._-]+)$`,
)
windowsSerialPortPattern = regexp.MustCompile(`^(?:\\\\\.\\)?COM[1-9]\d*$`)
unixSerialBaudRates = map[int]struct{}{
50: {}, 75: {}, 110: {}, 134: {}, 150: {}, 200: {}, 300: {}, 600: {}, 1200: {}, 1800: {},
2400: {}, 4800: {}, 9600: {}, 19200: {}, 38400: {}, 57600: {}, 115200: {}, 230400: {},
}
)
type SerialTool struct{}
type serialPortInfo struct {
Name string `json:"name"`
Path string `json:"path"`
}
type serialConfig struct {
Port string
Baud int
DataBits int
Parity string
StopBits int
}
func NewSerialTool() *SerialTool {
return &SerialTool{}
}
func (t *SerialTool) Name() string {
return "serial"
}
func (t *SerialTool) Description() string {
return "Interact with host serial ports. Actions: list (enumerate ports), read (receive bytes), write (send bytes with explicit confirmation). Supports Linux, macOS, and Windows."
}
func (t *SerialTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"action": map[string]any{
"type": "string",
"enum": []string{"list", "read", "write"},
"description": "Action to perform: list available serial ports, read bytes from a port, or write bytes to a port.",
},
"port": map[string]any{
"type": "string",
"description": "Serial port path or name, for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3. Required for read/write.",
},
"baud": map[string]any{
"type": "integer",
"description": "Baud rate. Default: 115200. Linux/macOS currently support standard termios rates up to 230400; Windows accepts configured rates up to 4000000.",
},
"data_bits": map[string]any{
"type": "integer",
"description": "Data bits. Supported values: 5, 6, 7, 8. Default: 8.",
},
"parity": map[string]any{
"type": "string",
"enum": []string{"none", "even", "odd"},
"description": "Parity mode. Default: none.",
},
"stop_bits": map[string]any{
"type": "integer",
"description": "Stop bits. Supported values: 1, 2. Default: 1.",
},
"timeout_ms": map[string]any{
"type": "integer",
"description": "Read/write timeout in milliseconds. Default: 1000.",
},
"length": map[string]any{
"type": "integer",
"description": "Number of bytes to read. Required for read. Range: 1-4096.",
},
"data": map[string]any{
"type": "array",
"items": map[string]any{"type": "integer"},
"description": "Bytes to write, each in range 0-255. Required for write unless text is provided.",
},
"text": map[string]any{
"type": "string",
"description": "UTF-8 text to write. Required for write if data is omitted.",
},
"confirm": map[string]any{
"type": "boolean",
"description": "Must be true for write operations.",
},
},
"required": []string{"action"},
}
}
func (t *SerialTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
action, ok := args["action"].(string)
if !ok || strings.TrimSpace(action) == "" {
return ErrorResult("action is required")
}
switch action {
case "list":
return t.list()
case "read":
return t.read(ctx, args)
case "write":
return t.write(ctx, args)
default:
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, read, write)", action))
}
}
func (t *SerialTool) list() *ToolResult {
ports, err := serialListPorts()
if err != nil {
return ErrorResult(fmt.Sprintf("failed to list serial ports: %v", err))
}
if len(ports) == 0 {
return SilentResult("No serial ports found on this host.")
}
result, _ := json.MarshalIndent(map[string]any{
"ports": ports,
"count": len(ports),
}, "", " ")
return SilentResult(string(result))
}
func (t *SerialTool) read(ctx context.Context, args map[string]any) *ToolResult {
cfg, errResult := parseSerialConfig(args)
if errResult != nil {
return errResult
}
length := 0
if v, ok := args["length"].(float64); ok {
length = int(v)
}
if length < 1 || length > maxSerialReadBytes {
return ErrorResult(fmt.Sprintf("length is required for read (1-%d)", maxSerialReadBytes))
}
timeout, errResult := parseSerialTimeout(args)
if errResult != nil {
return errResult
}
data, err := serialRead(ctx, cfg, length, timeout)
if err != nil {
return ErrorResult(fmt.Sprintf("serial read failed on %s: %v", cfg.Port, err))
}
return SilentResult(formatSerialPayload("read", cfg, data, timeout))
}
func (t *SerialTool) write(ctx context.Context, args map[string]any) *ToolResult {
confirm, _ := args["confirm"].(bool)
if !confirm {
return ErrorResult(
"write operations require confirm: true. Please confirm with the user before sending bytes to a serial device.",
)
}
cfg, errResult := parseSerialConfig(args)
if errResult != nil {
return errResult
}
timeout, errResult := parseSerialTimeout(args)
if errResult != nil {
return errResult
}
payload, errResult := parseSerialWritePayload(args)
if errResult != nil {
return errResult
}
written, err := serialWrite(ctx, cfg, payload, timeout)
if err != nil {
return ErrorResult(fmt.Sprintf("serial write failed on %s: %v", cfg.Port, err))
}
result, _ := json.MarshalIndent(map[string]any{
"action": "write",
"port": cfg.Port,
"baud": cfg.Baud,
"data_bits": cfg.DataBits,
"parity": cfg.Parity,
"stop_bits": cfg.StopBits,
"timeout_ms": timeout.Milliseconds(),
"written": written,
"payload": serialPayloadSummary(payload),
}, "", " ")
return SilentResult(string(result))
}
func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) {
port, ok := args["port"].(string)
port = strings.TrimSpace(port)
if !ok || port == "" {
return serialConfig{}, ErrorResult(
"port is required (for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3)",
)
}
normalizedPort, err := normalizeSerialPort(port)
if err != nil {
return serialConfig{}, ErrorResult(err.Error())
}
cfg := serialConfig{
Port: normalizedPort,
Baud: defaultSerialBaud,
DataBits: defaultSerialDataBits,
Parity: "none",
StopBits: defaultSerialStopBits,
}
if v, ok := args["baud"].(float64); ok {
cfg.Baud = int(v)
}
if err := validateSerialBaud(cfg.Baud); err != nil {
return serialConfig{}, ErrorResult(err.Error())
}
if v, ok := args["data_bits"].(float64); ok {
cfg.DataBits = int(v)
}
switch cfg.DataBits {
case 5, 6, 7, 8:
default:
return serialConfig{}, ErrorResult("data_bits must be one of 5, 6, 7, or 8")
}
if v, ok := args["parity"].(string); ok && strings.TrimSpace(v) != "" {
cfg.Parity = strings.ToLower(strings.TrimSpace(v))
}
switch cfg.Parity {
case "none", "even", "odd":
default:
return serialConfig{}, ErrorResult(`parity must be one of "none", "even", or "odd"`)
}
if v, ok := args["stop_bits"].(float64); ok {
cfg.StopBits = int(v)
}
if cfg.StopBits != 1 && cfg.StopBits != 2 {
return serialConfig{}, ErrorResult("stop_bits must be 1 or 2")
}
return cfg, nil
}
func parseSerialTimeout(args map[string]any) (time.Duration, *ToolResult) {
timeoutMS := defaultSerialTimeoutMS
if v, ok := args["timeout_ms"].(float64); ok {
timeoutMS = int(v)
}
if timeoutMS < 1 || timeoutMS > 60000 {
return 0, ErrorResult("timeout_ms must be between 1 and 60000")
}
return time.Duration(timeoutMS) * time.Millisecond, nil
}
func parseSerialWritePayload(args map[string]any) ([]byte, *ToolResult) {
if text, ok := args["text"].(string); ok && text != "" {
if !utf8.ValidString(text) {
return nil, ErrorResult("text must be valid UTF-8")
}
if len(text) > maxSerialPayloadBytes {
return nil, ErrorResult(fmt.Sprintf("text payload too large: maximum %d bytes", maxSerialPayloadBytes))
}
return []byte(text), nil
}
dataRaw, ok := args["data"].([]any)
if !ok || len(dataRaw) == 0 {
return nil, ErrorResult("write requires either text or data")
}
if len(dataRaw) > maxSerialPayloadBytes {
return nil, ErrorResult(fmt.Sprintf("data too long: maximum %d bytes", maxSerialPayloadBytes))
}
data := make([]byte, len(dataRaw))
for i, v := range dataRaw {
f, ok := v.(float64)
if !ok {
return nil, ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i))
}
if f != math.Trunc(f) {
return nil, ErrorResult(fmt.Sprintf("data[%d] is not an integer byte value", i))
}
b := int(f)
if b < 0 || b > 255 {
return nil, ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b))
}
data[i] = byte(b)
}
return data, nil
}
func formatSerialPayload(action string, cfg serialConfig, data []byte, timeout time.Duration) string {
result, _ := json.MarshalIndent(map[string]any{
"action": action,
"port": cfg.Port,
"baud": cfg.Baud,
"data_bits": cfg.DataBits,
"parity": cfg.Parity,
"stop_bits": cfg.StopBits,
"timeout_ms": timeout.Milliseconds(),
"payload": serialPayloadSummary(data),
}, "", " ")
return string(result)
}
func serialPayloadSummary(data []byte) map[string]any {
hexValues := make([]string, len(data))
intValues := make([]int, len(data))
for i, b := range data {
hexValues[i] = fmt.Sprintf("0x%02x", b)
intValues[i] = int(b)
}
summary := map[string]any{
"length": len(data),
"bytes": intValues,
"hex": hexValues,
}
if utf8.Valid(data) {
summary["text"] = string(data)
}
return summary
}
func normalizeSerialPort(port string) (string, error) {
switch runtime.GOOS {
case "windows":
return normalizeWindowsSerialPath(port)
case "linux", "darwin":
return normalizeUnixSerialPath(port)
default:
if normalized, err := normalizeUnixSerialPath(port); err == nil {
return normalized, nil
}
return normalizeWindowsSerialPath(port)
}
}
func normalizeUnixSerialPath(port string) (string, error) {
trimmed := strings.TrimSpace(port)
if !unixSerialPortPattern.MatchString(trimmed) {
return "", fmt.Errorf(
"invalid serial port: expected a safe Unix device name such as /dev/ttyUSB0 or /dev/cu.usbserial-0001",
)
}
if strings.HasPrefix(trimmed, "/dev/") {
return trimmed, nil
}
return "/dev/" + trimmed, nil
}
func normalizeWindowsSerialPath(port string) (string, error) {
trimmed := strings.ToUpper(strings.TrimSpace(port))
if !windowsSerialPortPattern.MatchString(trimmed) {
return "", fmt.Errorf("invalid serial port: expected a COM port such as COM3")
}
if strings.HasPrefix(trimmed, `\\.\`) {
return trimmed, nil
}
return `\\.\` + trimmed, nil
}
func validateSerialBaud(baud int) error {
if baud < 50 || baud > 4000000 {
return fmt.Errorf("baud must be between 50 and 4000000")
}
switch runtime.GOOS {
case "linux", "darwin":
if _, ok := unixSerialBaudRates[baud]; !ok {
return fmt.Errorf("unsupported baud rate on this platform: %d (supported up to 230400)", baud)
}
}
return nil
}
func serialContextErr(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
func serialWriteAll(
ctx context.Context,
data []byte,
timeout time.Duration,
now func() time.Time,
write func([]byte) (int, error),
) (int, error) {
if err := serialContextErr(ctx); err != nil {
return 0, err
}
total := 0
deadline := now().Add(timeout)
for total < len(data) {
if err := serialContextErr(ctx); err != nil {
return total, err
}
if deadline.Sub(now()) <= 0 {
return total, fmt.Errorf("timeout while writing serial data")
}
n, err := write(data[total:])
total += n
if err != nil {
return total, err
}
if n == 0 {
continue
}
}
return total, nil
}
+19
View File
@@ -0,0 +1,19 @@
//go:build darwin
package hardwaretools
import "golang.org/x/sys/unix"
func serialGetTermios(fd int) (*unix.Termios, error) {
return unix.IoctlGetTermios(fd, unix.TIOCGETA)
}
func serialSetSpeed(tio *unix.Termios, speed uint32) error {
tio.Ispeed = uint64(speed)
tio.Ospeed = uint64(speed)
return nil
}
func serialSetTermios(fd int, tio *unix.Termios) error {
return unix.IoctlSetTermios(fd, unix.TIOCSETA, tio)
}
+19
View File
@@ -0,0 +1,19 @@
//go:build linux
package hardwaretools
import "golang.org/x/sys/unix"
func serialGetTermios(fd int) (*unix.Termios, error) {
return unix.IoctlGetTermios(fd, unix.TCGETS)
}
func serialSetSpeed(tio *unix.Termios, speed uint32) error {
tio.Ispeed = speed
tio.Ospeed = speed
return nil
}
func serialSetTermios(fd int, tio *unix.Termios) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, tio)
}
+21
View File
@@ -0,0 +1,21 @@
//go:build !linux && !darwin && !windows
package hardwaretools
import (
"context"
"fmt"
"time"
)
func serialListPorts() ([]serialPortInfo, error) {
return nil, fmt.Errorf("serial is not supported on this platform")
}
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
return nil, fmt.Errorf("serial is not supported on this platform")
}
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
return 0, fmt.Errorf("serial is not supported on this platform")
}
+18
View File
@@ -0,0 +1,18 @@
//go:build !linux && !darwin && !windows
package hardwaretools
import (
"strings"
"testing"
)
func TestSerialListPortsUnsupportedPlatform(t *testing.T) {
_, err := serialListPorts()
if err == nil {
t.Fatal("expected unsupported platform error")
}
if !strings.Contains(err.Error(), "not supported") {
t.Fatalf("serialListPorts() error = %v, want unsupported platform message", err)
}
}
+269
View File
@@ -0,0 +1,269 @@
package hardwaretools
import (
"context"
"runtime"
"strings"
"testing"
"time"
)
func TestParseSerialConfig(t *testing.T) {
port := "/dev/ttyUSB0"
if runtime.GOOS == "windows" {
port = "COM3"
}
cfg, errResult := parseSerialConfig(map[string]any{
"port": port,
"baud": float64(9600),
"data_bits": float64(7),
"parity": "even",
"stop_bits": float64(2),
})
if errResult != nil {
t.Fatalf("parseSerialConfig() unexpected error = %v", errResult.ForLLM)
}
wantPort := "/dev/ttyUSB0"
if runtime.GOOS == "windows" {
wantPort = `\\.\COM3`
}
if cfg.Port != wantPort || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 {
t.Fatalf("parseSerialConfig() = %#v", cfg)
}
}
func TestParseSerialConfigRejectsInvalidParity(t *testing.T) {
port := "/dev/ttyUSB0"
if runtime.GOOS == "windows" {
port = "COM3"
}
_, errResult := parseSerialConfig(map[string]any{
"port": port,
"parity": "mark",
})
if errResult == nil {
t.Fatal("expected invalid parity to fail")
}
}
func TestParseSerialConfigRejectsUnsupportedUnixBaud(t *testing.T) {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.Skip("Unix baud validation only applies on Unix platforms")
}
_, errResult := parseSerialConfig(map[string]any{
"port": "/dev/ttyUSB0",
"baud": float64(460800),
})
if errResult == nil {
t.Fatal("expected unsupported Unix baud rate to fail")
}
}
func TestParseSerialWritePayloadRejectsFractionalBytes(t *testing.T) {
_, errResult := parseSerialWritePayload(map[string]any{
"data": []any{65.9},
})
if errResult == nil {
t.Fatal("expected fractional byte value to fail")
}
}
func TestValidateSerialBaud(t *testing.T) {
tests := []struct {
name string
baud int
wantErr bool
}{
{name: "default-supported", baud: 115200},
{name: "max-unix-supported", baud: 230400},
{name: "too-low", baud: 49, wantErr: true},
{name: "too-high", baud: 4000001, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateSerialBaud(tt.baud)
if (err != nil) != tt.wantErr {
t.Fatalf("validateSerialBaud(%d) error = %v, wantErr %v", tt.baud, err, tt.wantErr)
}
})
}
}
func TestSerialReadCanceledBeforeOpen(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
port := "/dev/ttyUSB0"
if runtime.GOOS == "windows" {
port = "COM3"
}
_, err := serialRead(
ctx,
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
1,
time.Second,
)
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
t.Fatalf("serialRead() error = %v, want context canceled", err)
}
}
func TestSerialWriteCanceledBeforeOpen(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
port := "/dev/ttyUSB0"
if runtime.GOOS == "windows" {
port = "COM3"
}
_, err := serialWrite(
ctx,
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
[]byte("AT"),
time.Second,
)
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
t.Fatalf("serialWrite() error = %v, want context canceled", err)
}
}
func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) {
tests := []string{
"../../../etc/passwd",
"/etc/passwd",
`C:\temp\device.txt`,
`\\.\C:\temp\device.txt`,
}
for _, port := range tests {
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
_, errResult := parseSerialConfig(map[string]any{
"port": port,
})
if errResult == nil {
t.Fatalf("expected unsafe port %q to be rejected", port)
}
})
}
}
func TestNormalizeUnixSerialPath(t *testing.T) {
tests := []struct {
port string
want string
}{
{port: "ttyUSB0", want: "/dev/ttyUSB0"},
{port: "/dev/ttyACM0", want: "/dev/ttyACM0"},
{port: "/dev/cu.usbserial-0001", want: "/dev/cu.usbserial-0001"},
}
for _, tt := range tests {
got, err := normalizeUnixSerialPath(tt.port)
if err != nil {
t.Fatalf("normalizeUnixSerialPath(%q) unexpected error = %v", tt.port, err)
}
if got != tt.want {
t.Fatalf("normalizeUnixSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
}
}
}
func TestNormalizeUnixSerialPathRejectsInvalidNames(t *testing.T) {
tests := []string{
"",
"ttyUSB0/../../passwd",
"/dev/../../etc/passwd",
"/tmp/ttyUSB0",
"ttyUSB",
"COM3",
}
for _, port := range tests {
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
if _, err := normalizeUnixSerialPath(port); err == nil {
t.Fatalf("expected %q to be rejected", port)
}
})
}
}
func TestNormalizeWindowsSerialPath(t *testing.T) {
tests := []struct {
port string
want string
}{
{port: "COM3", want: `\\.\COM3`},
{port: "com12", want: `\\.\COM12`},
{port: `\\.\COM7`, want: `\\.\COM7`},
}
for _, tt := range tests {
got, err := normalizeWindowsSerialPath(tt.port)
if err != nil {
t.Fatalf("normalizeWindowsSerialPath(%q) unexpected error = %v", tt.port, err)
}
if got != tt.want {
t.Fatalf("normalizeWindowsSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
}
}
}
func TestNormalizeWindowsSerialPathRejectsInvalidNames(t *testing.T) {
tests := []string{
"",
"COM0",
"COM",
"/dev/ttyUSB0",
`C:\temp\device.txt`,
`\\.\C:\temp\device.txt`,
`\\server\share\COM3`,
}
for _, port := range tests {
t.Run(strings.ReplaceAll(strings.ReplaceAll(port, `\`, "_"), "/", "_"), func(t *testing.T) {
if _, err := normalizeWindowsSerialPath(port); err == nil {
t.Fatalf("expected %q to be rejected", port)
}
})
}
}
func TestParseSerialTimeout(t *testing.T) {
timeout, errResult := parseSerialTimeout(map[string]any{
"timeout_ms": float64(2500),
})
if errResult != nil {
t.Fatalf("parseSerialTimeout() unexpected error = %v", errResult.ForLLM)
}
if timeout != 2500*time.Millisecond {
t.Fatalf("timeout = %v, want 2500ms", timeout)
}
}
func TestParseSerialWritePayloadSupportsText(t *testing.T) {
data, errResult := parseSerialWritePayload(map[string]any{
"text": "AT\r\n",
})
if errResult != nil {
t.Fatalf("parseSerialWritePayload() unexpected error = %v", errResult.ForLLM)
}
if string(data) != "AT\r\n" {
t.Fatalf("payload = %q, want %q", string(data), "AT\r\n")
}
}
func TestParseSerialWritePayloadRejectsOutOfRangeByte(t *testing.T) {
_, errResult := parseSerialWritePayload(map[string]any{
"data": []any{float64(256)},
})
if errResult == nil {
t.Fatal("expected payload validation failure")
}
}
+286
View File
@@ -0,0 +1,286 @@
//go:build linux || darwin
package hardwaretools
import (
"context"
"fmt"
"os"
"path/filepath"
"sort"
"time"
"golang.org/x/sys/unix"
)
var (
unixSerialNow = time.Now
unixSerialOpenPort = openAndConfigureSerialPort
unixSerialClosePort = unix.Close
unixSerialPollRead = pollRead
unixSerialPollWrite = pollWrite
)
func serialListPorts() ([]serialPortInfo, error) {
patterns := []string{
"/dev/ttyS*",
"/dev/ttyUSB*",
"/dev/ttyACM*",
"/dev/ttyAMA*",
"/dev/rfcomm*",
"/dev/tty.*",
"/dev/cu.*",
}
seen := make(map[string]struct{})
ports := make([]serialPortInfo, 0)
for _, pattern := range patterns {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, match := range matches {
if _, ok := seen[match]; ok {
continue
}
info, err := os.Stat(match)
if err != nil || info.IsDir() {
continue
}
seen[match] = struct{}{}
ports = append(ports, serialPortInfo{
Name: filepath.Base(match),
Path: match,
})
}
}
sort.Slice(ports, func(i, j int) bool {
return ports[i].Path < ports[j].Path
})
return ports, nil
}
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
if err := serialContextErr(ctx); err != nil {
return nil, err
}
fd, err := unixSerialOpenPort(cfg)
if err != nil {
return nil, err
}
defer unixSerialClosePort(fd)
buf := make([]byte, length)
total := 0
deadline := unixSerialNow().Add(timeout)
for total < length {
if err := serialContextErr(ctx); err != nil {
return nil, err
}
remaining := deadline.Sub(unixSerialNow())
if remaining <= 0 {
break
}
n, err := unixSerialPollRead(fd, buf[total:], minSerialPollTimeout(remaining))
if err != nil {
return nil, err
}
if n == 0 {
continue
}
total += n
}
return buf[:total], nil
}
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
if err := serialContextErr(ctx); err != nil {
return 0, err
}
fd, err := unixSerialOpenPort(cfg)
if err != nil {
return 0, err
}
defer unixSerialClosePort(fd)
total := 0
deadline := unixSerialNow().Add(timeout)
for total < len(data) {
if err := serialContextErr(ctx); err != nil {
return total, err
}
remaining := deadline.Sub(unixSerialNow())
if remaining <= 0 {
return total, fmt.Errorf("timeout while writing serial data")
}
n, err := unixSerialPollWrite(fd, data[total:], minSerialPollTimeout(remaining))
if err != nil {
return total, err
}
if n == 0 {
continue
}
total += n
}
return total, nil
}
func openAndConfigureSerialPort(cfg serialConfig) (int, error) {
fd, err := unix.Open(cfg.Port, unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0)
if err != nil {
return -1, err
}
if err := unix.SetNonblock(fd, false); err != nil {
unix.Close(fd)
return -1, err
}
if err := configureUnixSerialPort(fd, cfg); err != nil {
unix.Close(fd)
return -1, err
}
return fd, nil
}
func configureUnixSerialPort(fd int, cfg serialConfig) error {
tio, err := serialGetTermios(fd)
if err != nil {
return err
}
tio.Iflag = 0
tio.Oflag = 0
tio.Lflag = 0
tio.Cflag = unix.CREAD | unix.CLOCAL
tio.Cc[unix.VMIN] = 0
tio.Cc[unix.VTIME] = 0
switch cfg.DataBits {
case 5:
tio.Cflag |= unix.CS5
case 6:
tio.Cflag |= unix.CS6
case 7:
tio.Cflag |= unix.CS7
default:
tio.Cflag |= unix.CS8
}
switch cfg.Parity {
case "even":
tio.Cflag |= unix.PARENB
case "odd":
tio.Cflag |= unix.PARENB | unix.PARODD
}
if cfg.StopBits == 2 {
tio.Cflag |= unix.CSTOPB
}
speed, err := serialBaudToUnix(cfg.Baud)
if err != nil {
return err
}
if err := serialSetSpeed(tio, speed); err != nil {
return err
}
return serialSetTermios(fd, tio)
}
func serialBaudToUnix(baud int) (uint32, error) {
switch baud {
case 50:
return unix.B50, nil
case 75:
return unix.B75, nil
case 110:
return unix.B110, nil
case 134:
return unix.B134, nil
case 150:
return unix.B150, nil
case 200:
return unix.B200, nil
case 300:
return unix.B300, nil
case 600:
return unix.B600, nil
case 1200:
return unix.B1200, nil
case 1800:
return unix.B1800, nil
case 2400:
return unix.B2400, nil
case 4800:
return unix.B4800, nil
case 9600:
return unix.B9600, nil
case 19200:
return unix.B19200, nil
case 38400:
return unix.B38400, nil
case 57600:
return unix.B57600, nil
case 115200:
return unix.B115200, nil
case 230400:
return unix.B230400, nil
default:
return 0, fmt.Errorf("unsupported baud rate on this platform: %d", baud)
}
}
func pollRead(fd int, dst []byte, timeout time.Duration) (int, error) {
pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
n, err := unix.Poll(pfd, durationToPollTimeout(timeout))
if err != nil {
return 0, err
}
if n == 0 {
return 0, nil
}
return unix.Read(fd, dst)
}
func pollWrite(fd int, src []byte, timeout time.Duration) (int, error) {
pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}}
n, err := unix.Poll(pfd, durationToPollTimeout(timeout))
if err != nil {
return 0, err
}
if n == 0 {
return 0, nil
}
return unix.Write(fd, src)
}
func durationToPollTimeout(timeout time.Duration) int {
if timeout <= 0 {
return 0
}
ms := int(timeout / time.Millisecond)
if ms == 0 {
return 1
}
return ms
}
func minSerialPollTimeout(timeout time.Duration) time.Duration {
if timeout > serialPollInterval {
return serialPollInterval
}
return timeout
}
+140
View File
@@ -0,0 +1,140 @@
//go:build linux || darwin
package hardwaretools
import (
"context"
"errors"
"testing"
"time"
)
func stubUnixSerialIO(t *testing.T, now *time.Time) {
t.Helper()
prevNow := unixSerialNow
prevOpen := unixSerialOpenPort
prevClose := unixSerialClosePort
prevPollRead := unixSerialPollRead
prevPollWrite := unixSerialPollWrite
unixSerialNow = func() time.Time {
return *now
}
unixSerialOpenPort = func(cfg serialConfig) (int, error) {
return 42, nil
}
unixSerialClosePort = func(fd int) error {
return nil
}
unixSerialPollRead = prevPollRead
unixSerialPollWrite = prevPollWrite
t.Cleanup(func() {
unixSerialNow = prevNow
unixSerialOpenPort = prevOpen
unixSerialClosePort = prevClose
unixSerialPollRead = prevPollRead
unixSerialPollWrite = prevPollWrite
})
}
func TestSerialReadWaitsPastEmptyPollsUntilDeadline(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
pollCalls := 0
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
pollCalls++
if timeout > serialPollInterval {
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
}
now = now.Add(timeout)
if pollCalls < 4 {
return 0, nil
}
return copy(dst, []byte("OK")), nil
}
got, err := serialRead(context.Background(), serialConfig{}, 2, 500*time.Millisecond)
if err != nil {
t.Fatalf("serialRead() error = %v", err)
}
if string(got) != "OK" {
t.Fatalf("serialRead() = %q, want %q", got, "OK")
}
if pollCalls != 4 {
t.Fatalf("poll calls = %d, want 4", pollCalls)
}
}
func TestSerialReadReturnsPromptlyOnContextCancelBetweenPolls(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
ctx, cancel := context.WithCancel(context.Background())
pollCalls := 0
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
pollCalls++
now = now.Add(timeout)
cancel()
return 0, nil
}
_, err := serialRead(ctx, serialConfig{}, 1, time.Second)
if !errors.Is(err, context.Canceled) {
t.Fatalf("serialRead() error = %v, want context canceled", err)
}
if pollCalls != 1 {
t.Fatalf("poll calls = %d, want 1", pollCalls)
}
}
func TestSerialWriteWaitsPastEmptyPollsUntilReady(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
pollCalls := 0
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
pollCalls++
if timeout > serialPollInterval {
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
}
now = now.Add(timeout)
switch pollCalls {
case 1, 2:
return 0, nil
default:
return 1, nil
}
}
written, err := serialWrite(context.Background(), serialConfig{}, []byte("OK"), 500*time.Millisecond)
if err != nil {
t.Fatalf("serialWrite() error = %v", err)
}
if written != 2 {
t.Fatalf("serialWrite() wrote %d bytes, want 2", written)
}
if pollCalls != 4 {
t.Fatalf("poll calls = %d, want 4", pollCalls)
}
}
func TestSerialWriteTimesOutAfterRepeatedEmptyPolls(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
now = now.Add(timeout)
return 0, nil
}
written, err := serialWrite(context.Background(), serialConfig{}, []byte("A"), 250*time.Millisecond)
if err == nil || err.Error() != "timeout while writing serial data" {
t.Fatalf("serialWrite() error = %v, want timeout", err)
}
if written != 0 {
t.Fatalf("serialWrite() wrote %d bytes, want 0", written)
}
}
+248
View File
@@ -0,0 +1,248 @@
//go:build windows
package hardwaretools
import (
"context"
"fmt"
"sort"
"strings"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
procGetCommState = kernel32.NewProc("GetCommState")
procSetCommState = kernel32.NewProc("SetCommState")
procSetCommTimeouts = kernel32.NewProc("SetCommTimeouts")
procPurgeComm = kernel32.NewProc("PurgeComm")
)
const (
purgeTxClear = 0x0004
purgeRxClear = 0x0008
dcbFlagBinary = 0x00000001
dcbFlagParity = 0x00000002
dcbFlagOutxCtsFlow = 0x00000004
dcbFlagOutxDsrFlow = 0x00000008
dcbFlagDtrControlMask = 0x00000030
dcbFlagDsrSensitivity = 0x00000040
dcbFlagTXContinueOnXoff = 0x00000080
dcbFlagOutX = 0x00000100
dcbFlagInX = 0x00000200
dcbFlagRtsControlMask = 0x00003000
)
type dcb struct {
DCBlength uint32
BaudRate uint32
Flags uint32
Reserved uint16
XonLim uint16
XoffLim uint16
ByteSize byte
Parity byte
StopBits byte
XonChar byte
XoffChar byte
ErrorChar byte
EofChar byte
EvtChar byte
wReserved1 uint16
}
type commTimeouts struct {
ReadIntervalTimeout uint32
ReadTotalTimeoutMultiplier uint32
ReadTotalTimeoutConstant uint32
WriteTotalTimeoutMultiplier uint32
WriteTotalTimeoutConstant uint32
}
func serialListPorts() ([]serialPortInfo, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, `HARDWARE\DEVICEMAP\SERIALCOMM`, registry.QUERY_VALUE)
if err != nil {
if err == registry.ErrNotExist {
return nil, nil
}
return nil, err
}
defer key.Close()
names, err := key.ReadValueNames(-1)
if err != nil {
return nil, err
}
ports := make([]serialPortInfo, 0, len(names))
seen := make(map[string]struct{})
for _, name := range names {
value, _, err := key.GetStringValue(name)
if err != nil {
continue
}
portName := strings.TrimSpace(value)
if portName == "" {
continue
}
normalized := strings.ToUpper(portName)
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
ports = append(ports, serialPortInfo{
Name: normalized,
Path: normalized,
})
}
sort.Slice(ports, func(i, j int) bool {
return ports[i].Path < ports[j].Path
})
return ports, nil
}
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
if err := serialContextErr(ctx); err != nil {
return nil, err
}
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
if err != nil {
return nil, err
}
defer windows.CloseHandle(handle)
if err := serialContextErr(ctx); err != nil {
return nil, err
}
buf := make([]byte, length)
var read uint32
// Synchronous serial I/O on Windows cannot be interrupted once the syscall starts.
// COMMTIMEOUTS bounds how long turn cancellation may take to surface.
if err := windows.ReadFile(handle, buf, &read, nil); err != nil {
return nil, err
}
return buf[:read], nil
}
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
if err := serialContextErr(ctx); err != nil {
return 0, err
}
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
if err != nil {
return 0, err
}
defer windows.CloseHandle(handle)
if err := serialContextErr(ctx); err != nil {
return 0, err
}
return serialWriteAll(ctx, data, timeout, time.Now, func(chunk []byte) (int, error) {
var written uint32
// Like ReadFile above, this synchronous WriteFile call relies on COMMTIMEOUTS
// rather than context preemption once the syscall is in flight.
if err := windows.WriteFile(handle, chunk, &written, nil); err != nil {
return int(written), err
}
return int(written), nil
})
}
func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) {
handle, err := windows.CreateFile(
windows.StringToUTF16Ptr(cfg.Port),
windows.GENERIC_READ|windows.GENERIC_WRITE,
0,
nil,
windows.OPEN_EXISTING,
0,
0,
)
if err != nil {
return 0, err
}
if err := configureWindowsSerialPort(handle, cfg, timeout); err != nil {
windows.CloseHandle(handle)
return 0, err
}
return handle, nil
}
func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout time.Duration) error {
state := dcb{DCBlength: uint32(unsafe.Sizeof(dcb{}))}
r1, _, err := procGetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state)))
if r1 == 0 {
return err
}
state.BaudRate = uint32(cfg.Baud)
state.ByteSize = byte(cfg.DataBits)
state.Flags = sanitizeWindowsSerialFlags(state.Flags)
state.Flags |= dcbFlagBinary
switch cfg.Parity {
case "even":
state.Parity = 2
state.Flags |= dcbFlagParity
case "odd":
state.Parity = 1
state.Flags |= dcbFlagParity
default:
state.Parity = 0
state.Flags &^= dcbFlagParity
}
switch cfg.StopBits {
case 2:
state.StopBits = 2
default:
state.StopBits = 0
}
r1, _, err = procSetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state)))
if r1 == 0 {
return err
}
timeoutMS := uint32(timeout / time.Millisecond)
if timeoutMS == 0 {
timeoutMS = 1
}
timeouts := commTimeouts{
ReadIntervalTimeout: timeoutMS,
ReadTotalTimeoutConstant: timeoutMS,
WriteTotalTimeoutConstant: timeoutMS,
ReadTotalTimeoutMultiplier: 0,
WriteTotalTimeoutMultiplier: 0,
}
r1, _, err = procSetCommTimeouts.Call(uintptr(handle), uintptr(unsafe.Pointer(&timeouts)))
if r1 == 0 {
return err
}
procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear))
return nil
}
func sanitizeWindowsSerialFlags(flags uint32) uint32 {
flags &^= dcbFlagOutxCtsFlow |
dcbFlagOutxDsrFlow |
dcbFlagDtrControlMask |
dcbFlagDsrSensitivity |
dcbFlagTXContinueOnXoff |
dcbFlagOutX |
dcbFlagInX |
dcbFlagRtsControlMask
return flags
}
+39
View File
@@ -0,0 +1,39 @@
//go:build windows
package hardwaretools
import "testing"
func TestSanitizeWindowsSerialFlags(t *testing.T) {
flags := uint32(
dcbFlagBinary |
dcbFlagParity |
dcbFlagOutxCtsFlow |
dcbFlagOutxDsrFlow |
dcbFlagDtrControlMask |
dcbFlagDsrSensitivity |
dcbFlagTXContinueOnXoff |
dcbFlagOutX |
dcbFlagInX |
dcbFlagRtsControlMask,
)
got := sanitizeWindowsSerialFlags(flags)
if got&dcbFlagBinary == 0 {
t.Fatal("sanitizeWindowsSerialFlags() should preserve fBinary")
}
if got&dcbFlagParity == 0 {
t.Fatal("sanitizeWindowsSerialFlags() should preserve fParity")
}
if got&(dcbFlagOutxCtsFlow|
dcbFlagOutxDsrFlow|
dcbFlagDtrControlMask|
dcbFlagDsrSensitivity|
dcbFlagTXContinueOnXoff|
dcbFlagOutX|
dcbFlagInX|
dcbFlagRtsControlMask) != 0 {
t.Fatalf("sanitizeWindowsSerialFlags() = %#x, want flow-control bits cleared", got)
}
}
@@ -0,0 +1,87 @@
package hardwaretools
import (
"context"
"errors"
"testing"
"time"
)
func TestSerialWriteAllRetriesPartialWritesUntilComplete(t *testing.T) {
now := time.Unix(0, 0)
calls := 0
written, err := serialWriteAll(context.Background(), []byte("PING"), time.Second, func() time.Time {
return now
}, func(chunk []byte) (int, error) {
calls++
now = now.Add(100 * time.Millisecond)
switch calls {
case 1:
if string(chunk) != "PING" {
t.Fatalf("first chunk = %q, want %q", chunk, "PING")
}
return 2, nil
case 2:
if string(chunk) != "NG" {
t.Fatalf("second chunk = %q, want %q", chunk, "NG")
}
return 2, nil
default:
t.Fatalf("unexpected extra write call %d", calls)
return 0, nil
}
})
if err != nil {
t.Fatalf("serialWriteAll() error = %v", err)
}
if written != 4 {
t.Fatalf("serialWriteAll() wrote %d bytes, want 4", written)
}
}
func TestSerialWriteAllTimesOutAfterZeroByteWrites(t *testing.T) {
now := time.Unix(0, 0)
calls := 0
written, err := serialWriteAll(context.Background(), []byte("A"), 250*time.Millisecond, func() time.Time {
return now
}, func(chunk []byte) (int, error) {
calls++
now = now.Add(100 * time.Millisecond)
return 0, nil
})
if err == nil || err.Error() != "timeout while writing serial data" {
t.Fatalf("serialWriteAll() error = %v, want timeout", err)
}
if written != 0 {
t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written)
}
if calls != 3 {
t.Fatalf("write calls = %d, want 3", calls)
}
}
func TestSerialWriteAllReturnsContextCancellationAfterRetryBoundary(t *testing.T) {
now := time.Unix(0, 0)
ctx, cancel := context.WithCancel(context.Background())
calls := 0
written, err := serialWriteAll(ctx, []byte("A"), time.Second, func() time.Time {
return now
}, func(chunk []byte) (int, error) {
calls++
now = now.Add(100 * time.Millisecond)
cancel()
return 0, nil
})
if !errors.Is(err, context.Canceled) {
t.Fatalf("serialWriteAll() error = %v, want context canceled", err)
}
if written != 0 {
t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written)
}
if calls != 1 {
t.Fatalf("write calls = %d, want 1", calls)
}
}
+7 -2
View File
@@ -3,8 +3,9 @@ package tools
import hardwaretools "github.com/sipeed/picoclaw/pkg/tools/hardware"
type (
I2CTool = hardwaretools.I2CTool
SPITool = hardwaretools.SPITool
I2CTool = hardwaretools.I2CTool
SerialTool = hardwaretools.SerialTool
SPITool = hardwaretools.SPITool
)
func NewI2CTool() *I2CTool {
@@ -14,3 +15,7 @@ func NewI2CTool() *I2CTool {
func NewSPITool() *SPITool {
return hardwaretools.NewSPITool()
}
func NewSerialTool() *SerialTool {
return hardwaretools.NewSerialTool()
}
+22
View File
@@ -171,6 +171,12 @@ var toolCatalog = []toolCatalogEntry{
Category: "hardware",
ConfigKey: "spi",
},
{
Name: "serial",
Description: "Interact with serial ports exposed on the host.",
Category: "hardware",
ConfigKey: "serial",
},
{
Name: "tool_search_tool_regex",
Description: "Discover hidden MCP tools by regex search when tool discovery is enabled.",
@@ -265,6 +271,8 @@ func buildToolSupport(cfg *config.Config) []toolSupportItem {
status, reasonCode = resolveWebSearchToolSupport(cfg)
case "i2c", "spi":
status, reasonCode = resolveHardwareToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey))
case "serial":
status, reasonCode = resolveSerialToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey))
default:
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
status = "enabled"
@@ -293,6 +301,18 @@ func resolveHardwareToolSupport(enabled bool) (string, string) {
return "enabled", ""
}
func resolveSerialToolSupport(enabled bool) (string, string) {
if !enabled {
return "disabled", ""
}
switch runtime.GOOS {
case "linux", "darwin", "windows":
return "enabled", ""
default:
return "blocked", "requires_serial_platform"
}
}
func resolveDiscoveryToolSupport(cfg *config.Config, methodEnabled bool) (string, string) {
if !cfg.Tools.IsToolEnabled("mcp") {
return "disabled", ""
@@ -362,6 +382,8 @@ func applyToolState(cfg *config.Config, toolName string, enabled bool) error {
cfg.Tools.I2C.Enabled = enabled
case "spi":
cfg.Tools.SPI.Enabled = enabled
case "serial":
cfg.Tools.Serial.Enabled = enabled
case "tool_search_tool_regex":
cfg.Tools.MCP.Discovery.UseRegex = enabled
if enabled {
+57
View File
@@ -92,9 +92,36 @@ func TestHandleListTools(t *testing.T) {
if gotTools["i2c"].Status != "disabled" {
t.Fatalf("i2c status = %q, want disabled on linux when config is off", gotTools["i2c"].Status)
}
if gotTools["serial"].Status != "disabled" {
t.Fatalf("serial status = %q, want disabled when config is off", gotTools["serial"].Status)
}
cfg.Tools.Serial.Enabled = true
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/tools", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
gotTools = make(map[string]toolSupportItem, len(resp.Tools))
for _, tool := range resp.Tools {
gotTools[tool.Name] = tool
}
if gotTools["serial"].Status != "enabled" {
t.Fatalf("serial = %#v, want enabled on linux when config is on", gotTools["serial"])
}
} else {
cfg.Tools.I2C.Enabled = true
cfg.Tools.SPI.Enabled = true
cfg.Tools.Serial.Enabled = true
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
@@ -120,6 +147,16 @@ func TestHandleListTools(t *testing.T) {
if gotTools["spi"].Status != "blocked" || gotTools["spi"].ReasonCode != "requires_linux" {
t.Fatalf("spi = %#v, want blocked/requires_linux", gotTools["spi"])
}
switch runtime.GOOS {
case "darwin", "windows":
if gotTools["serial"].Status != "enabled" {
t.Fatalf("serial = %#v, want enabled on supported host", gotTools["serial"])
}
default:
if gotTools["serial"].Status != "blocked" || gotTools["serial"].ReasonCode != "requires_serial_platform" {
t.Fatalf("serial = %#v, want blocked/requires_serial_platform", gotTools["serial"])
}
}
}
}
@@ -195,6 +232,26 @@ func TestHandleUpdateToolState(t *testing.T) {
if !updated.Tools.Cron.Enabled {
t.Fatalf("cron should be enabled: %#v", updated.Tools.Cron)
}
rec4 := httptest.NewRecorder()
req4 := httptest.NewRequest(
http.MethodPut,
"/api/tools/serial/state",
bytes.NewBufferString(`{"enabled":true}`),
)
req4.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec4, req4)
if rec4.Code != http.StatusOK {
t.Fatalf("serial status = %d, want %d, body=%s", rec4.Code, http.StatusOK, rec4.Body.String())
}
updated, err = config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig(updated serial) error = %v", err)
}
if !updated.Tools.Serial.Enabled {
t.Fatalf("serial should be enabled: %#v", updated.Tools.Serial)
}
}
func TestHandleListTools_ReportsWebSearchEnabledWhenToolIsOn(t *testing.T) {
+1
View File
@@ -603,6 +603,7 @@
},
"reasons": {
"requires_linux": "This tool only works on Linux hosts with the required device files exposed.",
"requires_serial_platform": "This tool currently supports Linux, macOS, and Windows hosts with accessible serial ports.",
"requires_skills": "Enable `tools.skills` before this skill-registry tool can be used.",
"requires_subagent": "Enable `tools.subagent` before the spawn tool can delegate work.",
"requires_mcp_discovery": "Enable `tools.mcp.discovery` before MCP discovery tools become available.",
+1
View File
@@ -603,6 +603,7 @@
},
"reasons": {
"requires_linux": "该工具仅在 Linux 主机上可用,并且需要暴露对应的设备文件。",
"requires_serial_platform": "该工具当前支持 Linux、macOS 和 Windows,且要求主机可访问对应串口。",
"requires_skills": "需要先启用 `tools.skills`,该技能注册表工具才能使用。",
"requires_subagent": "需要先启用 `tools.subagent``spawn` 才能委派任务。",
"requires_mcp_discovery": "需要先启用 `tools.mcp.discovery`MCP 发现工具才会可用。",