mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
Add cross-platform serial tool support (#2673)
* feat(tools): add cross-platform serial hardware tool * feat(config): wire serial tool into runtime and dashboard * hardware/serial: tighten validation and error handling * hardware/serial: improve unix cancellation and timeout polling * hardware/serial: improve windows I/O handling * hardware/serial: fix darwin cross-compilation build * docs(design): summarize hardware support and serial limits * build: keep go generate on host during cross builds * onboard: drop unrelated go generate change from serial work * style(tools): wrap serial lines for golines
This commit is contained in:
@@ -171,6 +171,18 @@ ifeq ($(OS),Windows_NT)
|
||||
EXT=.exe
|
||||
endif
|
||||
|
||||
ifneq ($(strip $(GOOS)),)
|
||||
PLATFORM:=$(GOOS)
|
||||
endif
|
||||
|
||||
ifneq ($(strip $(GOARCH)),)
|
||||
ARCH:=$(GOARCH)
|
||||
endif
|
||||
|
||||
ifeq ($(PLATFORM),windows)
|
||||
EXT=.exe
|
||||
endif
|
||||
|
||||
BINARY_PATH=$(BUILD_DIR)/$(BINARY_NAME)-$(PLATFORM)-$(ARCH)
|
||||
|
||||
# Default target
|
||||
@@ -181,10 +193,11 @@ generate:
|
||||
@echo "Run generate..."
|
||||
ifeq ($(OS),Windows_NT)
|
||||
@$(POWERSHELL) "if (Test-Path -LiteralPath './$(CMD_DIR)/workspace') { Remove-Item -LiteralPath './$(CMD_DIR)/workspace' -Recurse -Force }"
|
||||
@$(POWERSHELL) "$$env:GOOS=''; $$env:GOARCH=''; $(GO) generate ./..."
|
||||
else
|
||||
@rm -r ./$(CMD_DIR)/workspace 2>/dev/null || true
|
||||
@GOOS=$$($(GO) env GOHOSTOS) GOARCH=$$($(GO) env GOHOSTARCH) $(GO) generate ./...
|
||||
endif
|
||||
@$(GO) generate ./...
|
||||
@echo "Run generate complete"
|
||||
|
||||
## build: Build the picoclaw binary for current platform
|
||||
@@ -196,7 +209,7 @@ ifeq ($(OS),Windows_NT)
|
||||
@$(POWERSHELL) "Copy-Item -LiteralPath '$(BINARY_PATH)$(EXT)' -Destination '$(BUILD_DIR)/$(BINARY_NAME)$(EXT)' -Force"
|
||||
else
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
@GOARCH=${ARCH} $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH)$(EXT) ./$(CMD_DIR)
|
||||
@GOOS=$(PLATFORM) GOARCH=$(ARCH) $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH)$(EXT) ./$(CMD_DIR)
|
||||
@echo "Build complete: $(BINARY_PATH)$(EXT)"
|
||||
@$(LNCMD) $(BINARY_NAME)-$(PLATFORM)-$(ARCH)$(EXT) $(BUILD_DIR)/$(BINARY_NAME)$(EXT)
|
||||
endif
|
||||
@@ -211,7 +224,7 @@ ifeq ($(OS),Windows_NT)
|
||||
@$(POWERSHELL) "Copy-Item -LiteralPath '$(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH)$(EXT)' -Destination '$(BUILD_DIR)/picoclaw-launcher$(EXT)' -Force"
|
||||
else
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
@GOARCH=${ARCH} $(MAKE) -C web build \
|
||||
@GOOS=$(PLATFORM) GOARCH=$(ARCH) $(MAKE) -C web build \
|
||||
OUTPUT="$(CURDIR)/$(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH)$(EXT)" \
|
||||
WEB_GO='$(WEB_GO)' \
|
||||
GO_BUILD_TAGS='$(GO_BUILD_TAGS)' \
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
//go:generate go run ../../../../scripts/copydir.go "${DOLLAR}{codespace}/workspace" ./workspace
|
||||
//go:generate go run ../../../../scripts/copydir.go ../../../../workspace ./workspace
|
||||
//go:embed workspace
|
||||
var embeddedFiles embed.FS
|
||||
|
||||
|
||||
@@ -437,6 +437,9 @@
|
||||
"enabled": true,
|
||||
"mode": "bytes"
|
||||
},
|
||||
"serial": {
|
||||
"enabled": false
|
||||
},
|
||||
"send_tts": {
|
||||
"enabled": false
|
||||
},
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
# 当前硬件支持现状与串口 Tool 方案
|
||||
|
||||
## 现状结论
|
||||
|
||||
当前项目已有的硬件相关能力主要分为两条线:
|
||||
|
||||
1. 设备事件监控
|
||||
- `pkg/devices` 已实现设备事件服务。
|
||||
- 当前只有 Linux USB 热插拔事件源 `pkg/devices/sources/usb_linux.go`。
|
||||
- 能力定位是“发现和通知”,不是“总线读写控制”。
|
||||
|
||||
2. 硬件控制 Tool
|
||||
- `pkg/tools/hardware/i2c*.go`:I2C Tool,支持 `detect`、`scan`、`read`、`write`。
|
||||
- `pkg/tools/hardware/spi*.go`:SPI Tool,支持 `list`、`transfer`、`read`。
|
||||
- 这两类 Tool 当前都只在 Linux 主机上启用,直接依赖 `/dev/i2c-*` 与 `/dev/spidev*`。
|
||||
|
||||
因此,项目在“硬件支持能力”上已经具备:
|
||||
|
||||
- Linux USB 设备插拔感知
|
||||
- Linux I2C 总线控制
|
||||
- Linux SPI 总线控制
|
||||
|
||||
但还缺少:
|
||||
|
||||
- 串口/UART 控制
|
||||
- macOS / Windows 下可直接使用的硬件控制 Tool
|
||||
- 面向统一硬件抽象的跨总线能力模型
|
||||
|
||||
## 本次新增
|
||||
|
||||
本次新增内建 `serial` Tool,并接入现有 Tool 体系:
|
||||
|
||||
- 配置项:`tools.serial.enabled`
|
||||
- Tool 注册:`pkg/agent/agent_init.go`
|
||||
- Web 工具页:`/api/tools` 能展示与切换 `serial`
|
||||
- 前端状态文案:新增 `requires_serial_platform`
|
||||
|
||||
## Serial Tool 设计
|
||||
|
||||
`serial` 采用无状态调用模型,每次请求都自行打开和关闭端口,避免在 agent 回合之间维护串口会话状态。
|
||||
|
||||
支持动作:
|
||||
|
||||
- `list`:枚举主机串口
|
||||
- `read`:从串口读取指定长度字节
|
||||
- `write`:向串口写入字节或文本
|
||||
|
||||
公共参数:
|
||||
|
||||
- `port`
|
||||
- `baud`
|
||||
- `data_bits`
|
||||
- `parity`
|
||||
- `stop_bits`
|
||||
- `timeout_ms`
|
||||
|
||||
当前波特率实现边界:
|
||||
|
||||
- Windows 允许配置工具层接受的范围 `50-4000000`
|
||||
- Linux / macOS 当前仅支持标准 termios 波特率,实际支持到 `230400`
|
||||
- 因此 `baud` 的跨平台可移植取值应优先使用 `230400` 及以下的常见标准速率
|
||||
|
||||
安全约束:
|
||||
|
||||
- `write` 必须显式传 `confirm: true`
|
||||
- 单次读写负载限制为 `4096` 字节
|
||||
- `port` 只接受白名单串口名:
|
||||
- Linux / macOS 仅允许 `/dev/tty*`、`/dev/cu.*` 及对应简写设备名
|
||||
- Windows 仅允许 `COM\d+` 或 `\\.\COM\d+`
|
||||
- 明确拒绝 `..`、普通文件绝对路径、盘符路径等非串口设备路径,避免路径穿越或误打开任意文件
|
||||
|
||||
## 跨平台实现边界
|
||||
|
||||
- Linux / macOS:
|
||||
- 基于 `golang.org/x/sys/unix` 和 termios 配置串口参数。
|
||||
- 当前仅接入标准 termios 波特率映射,最高到 `230400`,尚未扩展 `460800`、`921600`、`1000000`、`2000000` 等更高速率。
|
||||
- 通过 `/dev/...` 枚举和访问设备。
|
||||
|
||||
- Windows:
|
||||
- 基于 `kernel32` 串口 API 配置 `DCB` 和 `COMMTIMEOUTS`。
|
||||
- 当前读写仍使用同步 `ReadFile` / `WriteFile`;一旦 syscall 已进入执行,turn context cancellation 不能立即打断,只能等待 `COMMTIMEOUTS` 触发后返回。
|
||||
- 通过注册表 `HARDWARE\\DEVICEMAP\\SERIALCOMM` 枚举端口。
|
||||
|
||||
- 其他平台:
|
||||
- `serial` Tool 显式返回 unsupported,不做静默降级。
|
||||
|
||||
## 后续建议
|
||||
|
||||
1. 如果需要持续交互式串口会话,建议再增加 session 型 Tool,而不是让 LLM 反复做短连接轮询。
|
||||
2. 如果后续要支持 CAN、GPIO、PWM,建议抽出统一的硬件 capability 描述层,而不是继续只靠 Tool 名称区分。
|
||||
3. 若需要生产级稳定性,建议补真实串口回环测试,至少覆盖 Linux PTY 和 Windows COM 模拟场景。
|
||||
@@ -128,6 +128,9 @@ func registerSharedTools(
|
||||
if cfg.Tools.IsToolEnabled("spi") {
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("serial") {
|
||||
agent.Tools.Register(tools.NewSerialTool())
|
||||
}
|
||||
|
||||
// Message tool
|
||||
if cfg.Tools.IsToolEnabled("message") {
|
||||
|
||||
@@ -823,6 +823,7 @@ type ToolsConfig struct {
|
||||
ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"`
|
||||
SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"`
|
||||
Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
@@ -1548,6 +1549,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.Message.Enabled
|
||||
case "read_file":
|
||||
return t.ReadFile.Enabled
|
||||
case "serial":
|
||||
return t.Serial.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spawn_status":
|
||||
|
||||
@@ -435,6 +435,9 @@ func DefaultConfig() *Config {
|
||||
Mode: ReadFileModeBytes,
|
||||
MaxReadFileSize: 64 * 1024, // 64KB
|
||||
},
|
||||
Serial: ToolConfig{
|
||||
Enabled: false, // Hardware tool - requires host serial ports
|
||||
},
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
|
||||
@@ -9,6 +9,9 @@ func TestFacadeConstructorsRemainAvailable(t *testing.T) {
|
||||
if NewSPITool() == nil {
|
||||
t.Fatal("NewSPITool should return a tool")
|
||||
}
|
||||
if NewSerialTool() == nil {
|
||||
t.Fatal("NewSerialTool should return a tool")
|
||||
}
|
||||
if NewMessageTool() == nil {
|
||||
t.Fatal("NewMessageTool should return a tool")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,453 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSerialBaud = 115200
|
||||
defaultSerialDataBits = 8
|
||||
defaultSerialStopBits = 1
|
||||
defaultSerialTimeoutMS = 1000
|
||||
maxSerialPayloadBytes = 4096
|
||||
maxSerialReadBytes = 4096
|
||||
serialPollInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
unixSerialPortPattern = regexp.MustCompile(
|
||||
`^(?:/dev/)?(?:ttyS\d+|ttyUSB\d+|ttyACM\d+|ttyAMA\d+|rfcomm\d+|tty\.[A-Za-z0-9._-]+|cu\.[A-Za-z0-9._-]+)$`,
|
||||
)
|
||||
windowsSerialPortPattern = regexp.MustCompile(`^(?:\\\\\.\\)?COM[1-9]\d*$`)
|
||||
unixSerialBaudRates = map[int]struct{}{
|
||||
50: {}, 75: {}, 110: {}, 134: {}, 150: {}, 200: {}, 300: {}, 600: {}, 1200: {}, 1800: {},
|
||||
2400: {}, 4800: {}, 9600: {}, 19200: {}, 38400: {}, 57600: {}, 115200: {}, 230400: {},
|
||||
}
|
||||
)
|
||||
|
||||
type SerialTool struct{}
|
||||
|
||||
type serialPortInfo struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type serialConfig struct {
|
||||
Port string
|
||||
Baud int
|
||||
DataBits int
|
||||
Parity string
|
||||
StopBits int
|
||||
}
|
||||
|
||||
func NewSerialTool() *SerialTool {
|
||||
return &SerialTool{}
|
||||
}
|
||||
|
||||
func (t *SerialTool) Name() string {
|
||||
return "serial"
|
||||
}
|
||||
|
||||
func (t *SerialTool) Description() string {
|
||||
return "Interact with host serial ports. Actions: list (enumerate ports), read (receive bytes), write (send bytes with explicit confirmation). Supports Linux, macOS, and Windows."
|
||||
}
|
||||
|
||||
func (t *SerialTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"action": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"list", "read", "write"},
|
||||
"description": "Action to perform: list available serial ports, read bytes from a port, or write bytes to a port.",
|
||||
},
|
||||
"port": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Serial port path or name, for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3. Required for read/write.",
|
||||
},
|
||||
"baud": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Baud rate. Default: 115200. Linux/macOS currently support standard termios rates up to 230400; Windows accepts configured rates up to 4000000.",
|
||||
},
|
||||
"data_bits": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Data bits. Supported values: 5, 6, 7, 8. Default: 8.",
|
||||
},
|
||||
"parity": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"none", "even", "odd"},
|
||||
"description": "Parity mode. Default: none.",
|
||||
},
|
||||
"stop_bits": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Stop bits. Supported values: 1, 2. Default: 1.",
|
||||
},
|
||||
"timeout_ms": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Read/write timeout in milliseconds. Default: 1000.",
|
||||
},
|
||||
"length": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Number of bytes to read. Required for read. Range: 1-4096.",
|
||||
},
|
||||
"data": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"type": "integer"},
|
||||
"description": "Bytes to write, each in range 0-255. Required for write unless text is provided.",
|
||||
},
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "UTF-8 text to write. Required for write if data is omitted.",
|
||||
},
|
||||
"confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Must be true for write operations.",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SerialTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok || strings.TrimSpace(action) == "" {
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
return t.list()
|
||||
case "read":
|
||||
return t.read(ctx, args)
|
||||
case "write":
|
||||
return t.write(ctx, args)
|
||||
default:
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, read, write)", action))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SerialTool) list() *ToolResult {
|
||||
ports, err := serialListPorts()
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to list serial ports: %v", err))
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
return SilentResult("No serial ports found on this host.")
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"ports": ports,
|
||||
"count": len(ports),
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
func (t *SerialTool) read(ctx context.Context, args map[string]any) *ToolResult {
|
||||
cfg, errResult := parseSerialConfig(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
length := 0
|
||||
if v, ok := args["length"].(float64); ok {
|
||||
length = int(v)
|
||||
}
|
||||
if length < 1 || length > maxSerialReadBytes {
|
||||
return ErrorResult(fmt.Sprintf("length is required for read (1-%d)", maxSerialReadBytes))
|
||||
}
|
||||
|
||||
timeout, errResult := parseSerialTimeout(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
data, err := serialRead(ctx, cfg, length, timeout)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("serial read failed on %s: %v", cfg.Port, err))
|
||||
}
|
||||
|
||||
return SilentResult(formatSerialPayload("read", cfg, data, timeout))
|
||||
}
|
||||
|
||||
func (t *SerialTool) write(ctx context.Context, args map[string]any) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult(
|
||||
"write operations require confirm: true. Please confirm with the user before sending bytes to a serial device.",
|
||||
)
|
||||
}
|
||||
|
||||
cfg, errResult := parseSerialConfig(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
timeout, errResult := parseSerialTimeout(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
payload, errResult := parseSerialWritePayload(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
written, err := serialWrite(ctx, cfg, payload, timeout)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("serial write failed on %s: %v", cfg.Port, err))
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"action": "write",
|
||||
"port": cfg.Port,
|
||||
"baud": cfg.Baud,
|
||||
"data_bits": cfg.DataBits,
|
||||
"parity": cfg.Parity,
|
||||
"stop_bits": cfg.StopBits,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
"written": written,
|
||||
"payload": serialPayloadSummary(payload),
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) {
|
||||
port, ok := args["port"].(string)
|
||||
port = strings.TrimSpace(port)
|
||||
if !ok || port == "" {
|
||||
return serialConfig{}, ErrorResult(
|
||||
"port is required (for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3)",
|
||||
)
|
||||
}
|
||||
|
||||
normalizedPort, err := normalizeSerialPort(port)
|
||||
if err != nil {
|
||||
return serialConfig{}, ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
cfg := serialConfig{
|
||||
Port: normalizedPort,
|
||||
Baud: defaultSerialBaud,
|
||||
DataBits: defaultSerialDataBits,
|
||||
Parity: "none",
|
||||
StopBits: defaultSerialStopBits,
|
||||
}
|
||||
|
||||
if v, ok := args["baud"].(float64); ok {
|
||||
cfg.Baud = int(v)
|
||||
}
|
||||
if err := validateSerialBaud(cfg.Baud); err != nil {
|
||||
return serialConfig{}, ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if v, ok := args["data_bits"].(float64); ok {
|
||||
cfg.DataBits = int(v)
|
||||
}
|
||||
switch cfg.DataBits {
|
||||
case 5, 6, 7, 8:
|
||||
default:
|
||||
return serialConfig{}, ErrorResult("data_bits must be one of 5, 6, 7, or 8")
|
||||
}
|
||||
|
||||
if v, ok := args["parity"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
cfg.Parity = strings.ToLower(strings.TrimSpace(v))
|
||||
}
|
||||
switch cfg.Parity {
|
||||
case "none", "even", "odd":
|
||||
default:
|
||||
return serialConfig{}, ErrorResult(`parity must be one of "none", "even", or "odd"`)
|
||||
}
|
||||
|
||||
if v, ok := args["stop_bits"].(float64); ok {
|
||||
cfg.StopBits = int(v)
|
||||
}
|
||||
if cfg.StopBits != 1 && cfg.StopBits != 2 {
|
||||
return serialConfig{}, ErrorResult("stop_bits must be 1 or 2")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func parseSerialTimeout(args map[string]any) (time.Duration, *ToolResult) {
|
||||
timeoutMS := defaultSerialTimeoutMS
|
||||
if v, ok := args["timeout_ms"].(float64); ok {
|
||||
timeoutMS = int(v)
|
||||
}
|
||||
if timeoutMS < 1 || timeoutMS > 60000 {
|
||||
return 0, ErrorResult("timeout_ms must be between 1 and 60000")
|
||||
}
|
||||
return time.Duration(timeoutMS) * time.Millisecond, nil
|
||||
}
|
||||
|
||||
func parseSerialWritePayload(args map[string]any) ([]byte, *ToolResult) {
|
||||
if text, ok := args["text"].(string); ok && text != "" {
|
||||
if !utf8.ValidString(text) {
|
||||
return nil, ErrorResult("text must be valid UTF-8")
|
||||
}
|
||||
if len(text) > maxSerialPayloadBytes {
|
||||
return nil, ErrorResult(fmt.Sprintf("text payload too large: maximum %d bytes", maxSerialPayloadBytes))
|
||||
}
|
||||
return []byte(text), nil
|
||||
}
|
||||
|
||||
dataRaw, ok := args["data"].([]any)
|
||||
if !ok || len(dataRaw) == 0 {
|
||||
return nil, ErrorResult("write requires either text or data")
|
||||
}
|
||||
if len(dataRaw) > maxSerialPayloadBytes {
|
||||
return nil, ErrorResult(fmt.Sprintf("data too long: maximum %d bytes", maxSerialPayloadBytes))
|
||||
}
|
||||
|
||||
data := make([]byte, len(dataRaw))
|
||||
for i, v := range dataRaw {
|
||||
f, ok := v.(float64)
|
||||
if !ok {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i))
|
||||
}
|
||||
if f != math.Trunc(f) {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] is not an integer byte value", i))
|
||||
}
|
||||
b := int(f)
|
||||
if b < 0 || b > 255 {
|
||||
return nil, ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b))
|
||||
}
|
||||
data[i] = byte(b)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func formatSerialPayload(action string, cfg serialConfig, data []byte, timeout time.Duration) string {
|
||||
result, _ := json.MarshalIndent(map[string]any{
|
||||
"action": action,
|
||||
"port": cfg.Port,
|
||||
"baud": cfg.Baud,
|
||||
"data_bits": cfg.DataBits,
|
||||
"parity": cfg.Parity,
|
||||
"stop_bits": cfg.StopBits,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
"payload": serialPayloadSummary(data),
|
||||
}, "", " ")
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func serialPayloadSummary(data []byte) map[string]any {
|
||||
hexValues := make([]string, len(data))
|
||||
intValues := make([]int, len(data))
|
||||
for i, b := range data {
|
||||
hexValues[i] = fmt.Sprintf("0x%02x", b)
|
||||
intValues[i] = int(b)
|
||||
}
|
||||
|
||||
summary := map[string]any{
|
||||
"length": len(data),
|
||||
"bytes": intValues,
|
||||
"hex": hexValues,
|
||||
}
|
||||
if utf8.Valid(data) {
|
||||
summary["text"] = string(data)
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func normalizeSerialPort(port string) (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return normalizeWindowsSerialPath(port)
|
||||
case "linux", "darwin":
|
||||
return normalizeUnixSerialPath(port)
|
||||
default:
|
||||
if normalized, err := normalizeUnixSerialPath(port); err == nil {
|
||||
return normalized, nil
|
||||
}
|
||||
return normalizeWindowsSerialPath(port)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeUnixSerialPath(port string) (string, error) {
|
||||
trimmed := strings.TrimSpace(port)
|
||||
if !unixSerialPortPattern.MatchString(trimmed) {
|
||||
return "", fmt.Errorf(
|
||||
"invalid serial port: expected a safe Unix device name such as /dev/ttyUSB0 or /dev/cu.usbserial-0001",
|
||||
)
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "/dev/") {
|
||||
return trimmed, nil
|
||||
}
|
||||
return "/dev/" + trimmed, nil
|
||||
}
|
||||
|
||||
func normalizeWindowsSerialPath(port string) (string, error) {
|
||||
trimmed := strings.ToUpper(strings.TrimSpace(port))
|
||||
if !windowsSerialPortPattern.MatchString(trimmed) {
|
||||
return "", fmt.Errorf("invalid serial port: expected a COM port such as COM3")
|
||||
}
|
||||
if strings.HasPrefix(trimmed, `\\.\`) {
|
||||
return trimmed, nil
|
||||
}
|
||||
return `\\.\` + trimmed, nil
|
||||
}
|
||||
|
||||
func validateSerialBaud(baud int) error {
|
||||
if baud < 50 || baud > 4000000 {
|
||||
return fmt.Errorf("baud must be between 50 and 4000000")
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "darwin":
|
||||
if _, ok := unixSerialBaudRates[baud]; !ok {
|
||||
return fmt.Errorf("unsupported baud rate on this platform: %d (supported up to 230400)", baud)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialContextErr(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func serialWriteAll(
|
||||
ctx context.Context,
|
||||
data []byte,
|
||||
timeout time.Duration,
|
||||
now func() time.Time,
|
||||
write func([]byte) (int, error),
|
||||
) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
deadline := now().Add(timeout)
|
||||
for total < len(data) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return total, err
|
||||
}
|
||||
if deadline.Sub(now()) <= 0 {
|
||||
return total, fmt.Errorf("timeout while writing serial data")
|
||||
}
|
||||
|
||||
n, err := write(data[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
//go:build darwin
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func serialGetTermios(fd int) (*unix.Termios, error) {
|
||||
return unix.IoctlGetTermios(fd, unix.TIOCGETA)
|
||||
}
|
||||
|
||||
func serialSetSpeed(tio *unix.Termios, speed uint32) error {
|
||||
tio.Ispeed = uint64(speed)
|
||||
tio.Ospeed = uint64(speed)
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialSetTermios(fd int, tio *unix.Termios) error {
|
||||
return unix.IoctlSetTermios(fd, unix.TIOCSETA, tio)
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
//go:build linux
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func serialGetTermios(fd int) (*unix.Termios, error) {
|
||||
return unix.IoctlGetTermios(fd, unix.TCGETS)
|
||||
}
|
||||
|
||||
func serialSetSpeed(tio *unix.Termios, speed uint32) error {
|
||||
tio.Ispeed = speed
|
||||
tio.Ospeed = speed
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialSetTermios(fd int, tio *unix.Termios) error {
|
||||
return unix.IoctlSetTermios(fd, unix.TCSETS, tio)
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
//go:build !linux && !darwin && !windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
return nil, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
return nil, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
return 0, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//go:build !linux && !darwin && !windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSerialListPortsUnsupportedPlatform(t *testing.T) {
|
||||
_, err := serialListPorts()
|
||||
if err == nil {
|
||||
t.Fatal("expected unsupported platform error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not supported") {
|
||||
t.Fatalf("serialListPorts() error = %v, want unsupported platform message", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseSerialConfig(t *testing.T) {
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
cfg, errResult := parseSerialConfig(map[string]any{
|
||||
"port": port,
|
||||
"baud": float64(9600),
|
||||
"data_bits": float64(7),
|
||||
"parity": "even",
|
||||
"stop_bits": float64(2),
|
||||
})
|
||||
if errResult != nil {
|
||||
t.Fatalf("parseSerialConfig() unexpected error = %v", errResult.ForLLM)
|
||||
}
|
||||
|
||||
wantPort := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
wantPort = `\\.\COM3`
|
||||
}
|
||||
if cfg.Port != wantPort || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 {
|
||||
t.Fatalf("parseSerialConfig() = %#v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsInvalidParity(t *testing.T) {
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": port,
|
||||
"parity": "mark",
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected invalid parity to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsUnsupportedUnixBaud(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
||||
t.Skip("Unix baud validation only applies on Unix platforms")
|
||||
}
|
||||
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": "/dev/ttyUSB0",
|
||||
"baud": float64(460800),
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected unsupported Unix baud rate to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialWritePayloadRejectsFractionalBytes(t *testing.T) {
|
||||
_, errResult := parseSerialWritePayload(map[string]any{
|
||||
"data": []any{65.9},
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected fractional byte value to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSerialBaud(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baud int
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "default-supported", baud: 115200},
|
||||
{name: "max-unix-supported", baud: 230400},
|
||||
{name: "too-low", baud: 49, wantErr: true},
|
||||
{name: "too-high", baud: 4000001, wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateSerialBaud(tt.baud)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("validateSerialBaud(%d) error = %v, wantErr %v", tt.baud, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialReadCanceledBeforeOpen(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, err := serialRead(
|
||||
ctx,
|
||||
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
|
||||
1,
|
||||
time.Second,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
||||
t.Fatalf("serialRead() error = %v, want context canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteCanceledBeforeOpen(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, err := serialWrite(
|
||||
ctx,
|
||||
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
|
||||
[]byte("AT"),
|
||||
time.Second,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
||||
t.Fatalf("serialWrite() error = %v, want context canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) {
|
||||
tests := []string{
|
||||
"../../../etc/passwd",
|
||||
"/etc/passwd",
|
||||
`C:\temp\device.txt`,
|
||||
`\\.\C:\temp\device.txt`,
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
|
||||
_, errResult := parseSerialConfig(map[string]any{
|
||||
"port": port,
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatalf("expected unsafe port %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUnixSerialPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
want string
|
||||
}{
|
||||
{port: "ttyUSB0", want: "/dev/ttyUSB0"},
|
||||
{port: "/dev/ttyACM0", want: "/dev/ttyACM0"},
|
||||
{port: "/dev/cu.usbserial-0001", want: "/dev/cu.usbserial-0001"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := normalizeUnixSerialPath(tt.port)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeUnixSerialPath(%q) unexpected error = %v", tt.port, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("normalizeUnixSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUnixSerialPathRejectsInvalidNames(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"ttyUSB0/../../passwd",
|
||||
"/dev/../../etc/passwd",
|
||||
"/tmp/ttyUSB0",
|
||||
"ttyUSB",
|
||||
"COM3",
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) {
|
||||
if _, err := normalizeUnixSerialPath(port); err == nil {
|
||||
t.Fatalf("expected %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWindowsSerialPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
want string
|
||||
}{
|
||||
{port: "COM3", want: `\\.\COM3`},
|
||||
{port: "com12", want: `\\.\COM12`},
|
||||
{port: `\\.\COM7`, want: `\\.\COM7`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, err := normalizeWindowsSerialPath(tt.port)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeWindowsSerialPath(%q) unexpected error = %v", tt.port, err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("normalizeWindowsSerialPath(%q) = %q, want %q", tt.port, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeWindowsSerialPathRejectsInvalidNames(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"COM0",
|
||||
"COM",
|
||||
"/dev/ttyUSB0",
|
||||
`C:\temp\device.txt`,
|
||||
`\\.\C:\temp\device.txt`,
|
||||
`\\server\share\COM3`,
|
||||
}
|
||||
|
||||
for _, port := range tests {
|
||||
t.Run(strings.ReplaceAll(strings.ReplaceAll(port, `\`, "_"), "/", "_"), func(t *testing.T) {
|
||||
if _, err := normalizeWindowsSerialPath(port); err == nil {
|
||||
t.Fatalf("expected %q to be rejected", port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialTimeout(t *testing.T) {
|
||||
timeout, errResult := parseSerialTimeout(map[string]any{
|
||||
"timeout_ms": float64(2500),
|
||||
})
|
||||
if errResult != nil {
|
||||
t.Fatalf("parseSerialTimeout() unexpected error = %v", errResult.ForLLM)
|
||||
}
|
||||
if timeout != 2500*time.Millisecond {
|
||||
t.Fatalf("timeout = %v, want 2500ms", timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialWritePayloadSupportsText(t *testing.T) {
|
||||
data, errResult := parseSerialWritePayload(map[string]any{
|
||||
"text": "AT\r\n",
|
||||
})
|
||||
if errResult != nil {
|
||||
t.Fatalf("parseSerialWritePayload() unexpected error = %v", errResult.ForLLM)
|
||||
}
|
||||
if string(data) != "AT\r\n" {
|
||||
t.Fatalf("payload = %q, want %q", string(data), "AT\r\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialWritePayloadRejectsOutOfRangeByte(t *testing.T) {
|
||||
_, errResult := parseSerialWritePayload(map[string]any{
|
||||
"data": []any{float64(256)},
|
||||
})
|
||||
if errResult == nil {
|
||||
t.Fatal("expected payload validation failure")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
unixSerialNow = time.Now
|
||||
unixSerialOpenPort = openAndConfigureSerialPort
|
||||
unixSerialClosePort = unix.Close
|
||||
unixSerialPollRead = pollRead
|
||||
unixSerialPollWrite = pollWrite
|
||||
)
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
patterns := []string{
|
||||
"/dev/ttyS*",
|
||||
"/dev/ttyUSB*",
|
||||
"/dev/ttyACM*",
|
||||
"/dev/ttyAMA*",
|
||||
"/dev/rfcomm*",
|
||||
"/dev/tty.*",
|
||||
"/dev/cu.*",
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
ports := make([]serialPortInfo, 0)
|
||||
for _, pattern := range patterns {
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, match := range matches {
|
||||
if _, ok := seen[match]; ok {
|
||||
continue
|
||||
}
|
||||
info, err := os.Stat(match)
|
||||
if err != nil || info.IsDir() {
|
||||
continue
|
||||
}
|
||||
seen[match] = struct{}{}
|
||||
ports = append(ports, serialPortInfo{
|
||||
Name: filepath.Base(match),
|
||||
Path: match,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(ports, func(i, j int) bool {
|
||||
return ports[i].Path < ports[j].Path
|
||||
})
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fd, err := unixSerialOpenPort(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer unixSerialClosePort(fd)
|
||||
|
||||
buf := make([]byte, length)
|
||||
total := 0
|
||||
deadline := unixSerialNow().Add(timeout)
|
||||
|
||||
for total < length {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remaining := deadline.Sub(unixSerialNow())
|
||||
if remaining <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := unixSerialPollRead(fd, buf[total:], minSerialPollTimeout(remaining))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
return buf[:total], nil
|
||||
}
|
||||
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
fd, err := unixSerialOpenPort(cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer unixSerialClosePort(fd)
|
||||
|
||||
total := 0
|
||||
deadline := unixSerialNow().Add(timeout)
|
||||
for total < len(data) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
remaining := deadline.Sub(unixSerialNow())
|
||||
if remaining <= 0 {
|
||||
return total, fmt.Errorf("timeout while writing serial data")
|
||||
}
|
||||
|
||||
n, err := unixSerialPollWrite(fd, data[total:], minSerialPollTimeout(remaining))
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func openAndConfigureSerialPort(cfg serialConfig) (int, error) {
|
||||
fd, err := unix.Open(cfg.Port, unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if err := unix.SetNonblock(fd, false); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if err := configureUnixSerialPort(fd, cfg); err != nil {
|
||||
unix.Close(fd)
|
||||
return -1, err
|
||||
}
|
||||
|
||||
return fd, nil
|
||||
}
|
||||
|
||||
func configureUnixSerialPort(fd int, cfg serialConfig) error {
|
||||
tio, err := serialGetTermios(fd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tio.Iflag = 0
|
||||
tio.Oflag = 0
|
||||
tio.Lflag = 0
|
||||
tio.Cflag = unix.CREAD | unix.CLOCAL
|
||||
tio.Cc[unix.VMIN] = 0
|
||||
tio.Cc[unix.VTIME] = 0
|
||||
|
||||
switch cfg.DataBits {
|
||||
case 5:
|
||||
tio.Cflag |= unix.CS5
|
||||
case 6:
|
||||
tio.Cflag |= unix.CS6
|
||||
case 7:
|
||||
tio.Cflag |= unix.CS7
|
||||
default:
|
||||
tio.Cflag |= unix.CS8
|
||||
}
|
||||
|
||||
switch cfg.Parity {
|
||||
case "even":
|
||||
tio.Cflag |= unix.PARENB
|
||||
case "odd":
|
||||
tio.Cflag |= unix.PARENB | unix.PARODD
|
||||
}
|
||||
|
||||
if cfg.StopBits == 2 {
|
||||
tio.Cflag |= unix.CSTOPB
|
||||
}
|
||||
|
||||
speed, err := serialBaudToUnix(cfg.Baud)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := serialSetSpeed(tio, speed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return serialSetTermios(fd, tio)
|
||||
}
|
||||
|
||||
func serialBaudToUnix(baud int) (uint32, error) {
|
||||
switch baud {
|
||||
case 50:
|
||||
return unix.B50, nil
|
||||
case 75:
|
||||
return unix.B75, nil
|
||||
case 110:
|
||||
return unix.B110, nil
|
||||
case 134:
|
||||
return unix.B134, nil
|
||||
case 150:
|
||||
return unix.B150, nil
|
||||
case 200:
|
||||
return unix.B200, nil
|
||||
case 300:
|
||||
return unix.B300, nil
|
||||
case 600:
|
||||
return unix.B600, nil
|
||||
case 1200:
|
||||
return unix.B1200, nil
|
||||
case 1800:
|
||||
return unix.B1800, nil
|
||||
case 2400:
|
||||
return unix.B2400, nil
|
||||
case 4800:
|
||||
return unix.B4800, nil
|
||||
case 9600:
|
||||
return unix.B9600, nil
|
||||
case 19200:
|
||||
return unix.B19200, nil
|
||||
case 38400:
|
||||
return unix.B38400, nil
|
||||
case 57600:
|
||||
return unix.B57600, nil
|
||||
case 115200:
|
||||
return unix.B115200, nil
|
||||
case 230400:
|
||||
return unix.B230400, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported baud rate on this platform: %d", baud)
|
||||
}
|
||||
}
|
||||
|
||||
func pollRead(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
|
||||
n, err := unix.Poll(pfd, durationToPollTimeout(timeout))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return unix.Read(fd, dst)
|
||||
}
|
||||
|
||||
func pollWrite(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}}
|
||||
n, err := unix.Poll(pfd, durationToPollTimeout(timeout))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return unix.Write(fd, src)
|
||||
}
|
||||
|
||||
func durationToPollTimeout(timeout time.Duration) int {
|
||||
if timeout <= 0 {
|
||||
return 0
|
||||
}
|
||||
ms := int(timeout / time.Millisecond)
|
||||
if ms == 0 {
|
||||
return 1
|
||||
}
|
||||
return ms
|
||||
}
|
||||
|
||||
func minSerialPollTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout > serialPollInterval {
|
||||
return serialPollInterval
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func stubUnixSerialIO(t *testing.T, now *time.Time) {
|
||||
t.Helper()
|
||||
|
||||
prevNow := unixSerialNow
|
||||
prevOpen := unixSerialOpenPort
|
||||
prevClose := unixSerialClosePort
|
||||
prevPollRead := unixSerialPollRead
|
||||
prevPollWrite := unixSerialPollWrite
|
||||
|
||||
unixSerialNow = func() time.Time {
|
||||
return *now
|
||||
}
|
||||
unixSerialOpenPort = func(cfg serialConfig) (int, error) {
|
||||
return 42, nil
|
||||
}
|
||||
unixSerialClosePort = func(fd int) error {
|
||||
return nil
|
||||
}
|
||||
unixSerialPollRead = prevPollRead
|
||||
unixSerialPollWrite = prevPollWrite
|
||||
|
||||
t.Cleanup(func() {
|
||||
unixSerialNow = prevNow
|
||||
unixSerialOpenPort = prevOpen
|
||||
unixSerialClosePort = prevClose
|
||||
unixSerialPollRead = prevPollRead
|
||||
unixSerialPollWrite = prevPollWrite
|
||||
})
|
||||
}
|
||||
|
||||
func TestSerialReadWaitsPastEmptyPollsUntilDeadline(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
pollCalls := 0
|
||||
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
if timeout > serialPollInterval {
|
||||
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
|
||||
}
|
||||
now = now.Add(timeout)
|
||||
if pollCalls < 4 {
|
||||
return 0, nil
|
||||
}
|
||||
return copy(dst, []byte("OK")), nil
|
||||
}
|
||||
|
||||
got, err := serialRead(context.Background(), serialConfig{}, 2, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("serialRead() error = %v", err)
|
||||
}
|
||||
if string(got) != "OK" {
|
||||
t.Fatalf("serialRead() = %q, want %q", got, "OK")
|
||||
}
|
||||
if pollCalls != 4 {
|
||||
t.Fatalf("poll calls = %d, want 4", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialReadReturnsPromptlyOnContextCancelBetweenPolls(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pollCalls := 0
|
||||
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
now = now.Add(timeout)
|
||||
cancel()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
_, err := serialRead(ctx, serialConfig{}, 1, time.Second)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("serialRead() error = %v, want context canceled", err)
|
||||
}
|
||||
if pollCalls != 1 {
|
||||
t.Fatalf("poll calls = %d, want 1", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteWaitsPastEmptyPollsUntilReady(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
pollCalls := 0
|
||||
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
if timeout > serialPollInterval {
|
||||
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
|
||||
}
|
||||
now = now.Add(timeout)
|
||||
switch pollCalls {
|
||||
case 1, 2:
|
||||
return 0, nil
|
||||
default:
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
written, err := serialWrite(context.Background(), serialConfig{}, []byte("OK"), 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("serialWrite() error = %v", err)
|
||||
}
|
||||
if written != 2 {
|
||||
t.Fatalf("serialWrite() wrote %d bytes, want 2", written)
|
||||
}
|
||||
if pollCalls != 4 {
|
||||
t.Fatalf("poll calls = %d, want 4", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteTimesOutAfterRepeatedEmptyPolls(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
now = now.Add(timeout)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
written, err := serialWrite(context.Background(), serialConfig{}, []byte("A"), 250*time.Millisecond)
|
||||
if err == nil || err.Error() != "timeout while writing serial data" {
|
||||
t.Fatalf("serialWrite() error = %v, want timeout", err)
|
||||
}
|
||||
if written != 0 {
|
||||
t.Fatalf("serialWrite() wrote %d bytes, want 0", written)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
//go:build windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
procGetCommState = kernel32.NewProc("GetCommState")
|
||||
procSetCommState = kernel32.NewProc("SetCommState")
|
||||
procSetCommTimeouts = kernel32.NewProc("SetCommTimeouts")
|
||||
procPurgeComm = kernel32.NewProc("PurgeComm")
|
||||
)
|
||||
|
||||
const (
|
||||
purgeTxClear = 0x0004
|
||||
purgeRxClear = 0x0008
|
||||
|
||||
dcbFlagBinary = 0x00000001
|
||||
dcbFlagParity = 0x00000002
|
||||
dcbFlagOutxCtsFlow = 0x00000004
|
||||
dcbFlagOutxDsrFlow = 0x00000008
|
||||
dcbFlagDtrControlMask = 0x00000030
|
||||
dcbFlagDsrSensitivity = 0x00000040
|
||||
dcbFlagTXContinueOnXoff = 0x00000080
|
||||
dcbFlagOutX = 0x00000100
|
||||
dcbFlagInX = 0x00000200
|
||||
dcbFlagRtsControlMask = 0x00003000
|
||||
)
|
||||
|
||||
type dcb struct {
|
||||
DCBlength uint32
|
||||
BaudRate uint32
|
||||
Flags uint32
|
||||
Reserved uint16
|
||||
XonLim uint16
|
||||
XoffLim uint16
|
||||
ByteSize byte
|
||||
Parity byte
|
||||
StopBits byte
|
||||
XonChar byte
|
||||
XoffChar byte
|
||||
ErrorChar byte
|
||||
EofChar byte
|
||||
EvtChar byte
|
||||
wReserved1 uint16
|
||||
}
|
||||
|
||||
type commTimeouts struct {
|
||||
ReadIntervalTimeout uint32
|
||||
ReadTotalTimeoutMultiplier uint32
|
||||
ReadTotalTimeoutConstant uint32
|
||||
WriteTotalTimeoutMultiplier uint32
|
||||
WriteTotalTimeoutConstant uint32
|
||||
}
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, `HARDWARE\DEVICEMAP\SERIALCOMM`, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
names, err := key.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ports := make([]serialPortInfo, 0, len(names))
|
||||
seen := make(map[string]struct{})
|
||||
for _, name := range names {
|
||||
value, _, err := key.GetStringValue(name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
portName := strings.TrimSpace(value)
|
||||
if portName == "" {
|
||||
continue
|
||||
}
|
||||
normalized := strings.ToUpper(portName)
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
ports = append(ports, serialPortInfo{
|
||||
Name: normalized,
|
||||
Path: normalized,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(ports, func(i, j int) bool {
|
||||
return ports[i].Path < ports[j].Path
|
||||
})
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer windows.CloseHandle(handle)
|
||||
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := make([]byte, length)
|
||||
var read uint32
|
||||
// Synchronous serial I/O on Windows cannot be interrupted once the syscall starts.
|
||||
// COMMTIMEOUTS bounds how long turn cancellation may take to surface.
|
||||
if err := windows.ReadFile(handle, buf, &read, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf[:read], nil
|
||||
}
|
||||
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer windows.CloseHandle(handle)
|
||||
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return serialWriteAll(ctx, data, timeout, time.Now, func(chunk []byte) (int, error) {
|
||||
var written uint32
|
||||
// Like ReadFile above, this synchronous WriteFile call relies on COMMTIMEOUTS
|
||||
// rather than context preemption once the syscall is in flight.
|
||||
if err := windows.WriteFile(handle, chunk, &written, nil); err != nil {
|
||||
return int(written), err
|
||||
}
|
||||
return int(written), nil
|
||||
})
|
||||
}
|
||||
|
||||
func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
windows.StringToUTF16Ptr(cfg.Port),
|
||||
windows.GENERIC_READ|windows.GENERIC_WRITE,
|
||||
0,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := configureWindowsSerialPort(handle, cfg, timeout); err != nil {
|
||||
windows.CloseHandle(handle)
|
||||
return 0, err
|
||||
}
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout time.Duration) error {
|
||||
state := dcb{DCBlength: uint32(unsafe.Sizeof(dcb{}))}
|
||||
r1, _, err := procGetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state)))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
state.BaudRate = uint32(cfg.Baud)
|
||||
state.ByteSize = byte(cfg.DataBits)
|
||||
state.Flags = sanitizeWindowsSerialFlags(state.Flags)
|
||||
state.Flags |= dcbFlagBinary
|
||||
|
||||
switch cfg.Parity {
|
||||
case "even":
|
||||
state.Parity = 2
|
||||
state.Flags |= dcbFlagParity
|
||||
case "odd":
|
||||
state.Parity = 1
|
||||
state.Flags |= dcbFlagParity
|
||||
default:
|
||||
state.Parity = 0
|
||||
state.Flags &^= dcbFlagParity
|
||||
}
|
||||
|
||||
switch cfg.StopBits {
|
||||
case 2:
|
||||
state.StopBits = 2
|
||||
default:
|
||||
state.StopBits = 0
|
||||
}
|
||||
|
||||
r1, _, err = procSetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state)))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
timeoutMS := uint32(timeout / time.Millisecond)
|
||||
if timeoutMS == 0 {
|
||||
timeoutMS = 1
|
||||
}
|
||||
timeouts := commTimeouts{
|
||||
ReadIntervalTimeout: timeoutMS,
|
||||
ReadTotalTimeoutConstant: timeoutMS,
|
||||
WriteTotalTimeoutConstant: timeoutMS,
|
||||
ReadTotalTimeoutMultiplier: 0,
|
||||
WriteTotalTimeoutMultiplier: 0,
|
||||
}
|
||||
r1, _, err = procSetCommTimeouts.Call(uintptr(handle), uintptr(unsafe.Pointer(&timeouts)))
|
||||
if r1 == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear))
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitizeWindowsSerialFlags(flags uint32) uint32 {
|
||||
flags &^= dcbFlagOutxCtsFlow |
|
||||
dcbFlagOutxDsrFlow |
|
||||
dcbFlagDtrControlMask |
|
||||
dcbFlagDsrSensitivity |
|
||||
dcbFlagTXContinueOnXoff |
|
||||
dcbFlagOutX |
|
||||
dcbFlagInX |
|
||||
dcbFlagRtsControlMask
|
||||
return flags
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
//go:build windows
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeWindowsSerialFlags(t *testing.T) {
|
||||
flags := uint32(
|
||||
dcbFlagBinary |
|
||||
dcbFlagParity |
|
||||
dcbFlagOutxCtsFlow |
|
||||
dcbFlagOutxDsrFlow |
|
||||
dcbFlagDtrControlMask |
|
||||
dcbFlagDsrSensitivity |
|
||||
dcbFlagTXContinueOnXoff |
|
||||
dcbFlagOutX |
|
||||
dcbFlagInX |
|
||||
dcbFlagRtsControlMask,
|
||||
)
|
||||
|
||||
got := sanitizeWindowsSerialFlags(flags)
|
||||
|
||||
if got&dcbFlagBinary == 0 {
|
||||
t.Fatal("sanitizeWindowsSerialFlags() should preserve fBinary")
|
||||
}
|
||||
if got&dcbFlagParity == 0 {
|
||||
t.Fatal("sanitizeWindowsSerialFlags() should preserve fParity")
|
||||
}
|
||||
if got&(dcbFlagOutxCtsFlow|
|
||||
dcbFlagOutxDsrFlow|
|
||||
dcbFlagDtrControlMask|
|
||||
dcbFlagDsrSensitivity|
|
||||
dcbFlagTXContinueOnXoff|
|
||||
dcbFlagOutX|
|
||||
dcbFlagInX|
|
||||
dcbFlagRtsControlMask) != 0 {
|
||||
t.Fatalf("sanitizeWindowsSerialFlags() = %#x, want flow-control bits cleared", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSerialWriteAllRetriesPartialWritesUntilComplete(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
calls := 0
|
||||
|
||||
written, err := serialWriteAll(context.Background(), []byte("PING"), time.Second, func() time.Time {
|
||||
return now
|
||||
}, func(chunk []byte) (int, error) {
|
||||
calls++
|
||||
now = now.Add(100 * time.Millisecond)
|
||||
switch calls {
|
||||
case 1:
|
||||
if string(chunk) != "PING" {
|
||||
t.Fatalf("first chunk = %q, want %q", chunk, "PING")
|
||||
}
|
||||
return 2, nil
|
||||
case 2:
|
||||
if string(chunk) != "NG" {
|
||||
t.Fatalf("second chunk = %q, want %q", chunk, "NG")
|
||||
}
|
||||
return 2, nil
|
||||
default:
|
||||
t.Fatalf("unexpected extra write call %d", calls)
|
||||
return 0, nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("serialWriteAll() error = %v", err)
|
||||
}
|
||||
if written != 4 {
|
||||
t.Fatalf("serialWriteAll() wrote %d bytes, want 4", written)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteAllTimesOutAfterZeroByteWrites(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
calls := 0
|
||||
|
||||
written, err := serialWriteAll(context.Background(), []byte("A"), 250*time.Millisecond, func() time.Time {
|
||||
return now
|
||||
}, func(chunk []byte) (int, error) {
|
||||
calls++
|
||||
now = now.Add(100 * time.Millisecond)
|
||||
return 0, nil
|
||||
})
|
||||
if err == nil || err.Error() != "timeout while writing serial data" {
|
||||
t.Fatalf("serialWriteAll() error = %v, want timeout", err)
|
||||
}
|
||||
if written != 0 {
|
||||
t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written)
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Fatalf("write calls = %d, want 3", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteAllReturnsContextCancellationAfterRetryBoundary(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
calls := 0
|
||||
|
||||
written, err := serialWriteAll(ctx, []byte("A"), time.Second, func() time.Time {
|
||||
return now
|
||||
}, func(chunk []byte) (int, error) {
|
||||
calls++
|
||||
now = now.Add(100 * time.Millisecond)
|
||||
cancel()
|
||||
return 0, nil
|
||||
})
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("serialWriteAll() error = %v, want context canceled", err)
|
||||
}
|
||||
if written != 0 {
|
||||
t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("write calls = %d, want 1", calls)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,9 @@ package tools
|
||||
import hardwaretools "github.com/sipeed/picoclaw/pkg/tools/hardware"
|
||||
|
||||
type (
|
||||
I2CTool = hardwaretools.I2CTool
|
||||
SPITool = hardwaretools.SPITool
|
||||
I2CTool = hardwaretools.I2CTool
|
||||
SerialTool = hardwaretools.SerialTool
|
||||
SPITool = hardwaretools.SPITool
|
||||
)
|
||||
|
||||
func NewI2CTool() *I2CTool {
|
||||
@@ -14,3 +15,7 @@ func NewI2CTool() *I2CTool {
|
||||
func NewSPITool() *SPITool {
|
||||
return hardwaretools.NewSPITool()
|
||||
}
|
||||
|
||||
func NewSerialTool() *SerialTool {
|
||||
return hardwaretools.NewSerialTool()
|
||||
}
|
||||
|
||||
@@ -171,6 +171,12 @@ var toolCatalog = []toolCatalogEntry{
|
||||
Category: "hardware",
|
||||
ConfigKey: "spi",
|
||||
},
|
||||
{
|
||||
Name: "serial",
|
||||
Description: "Interact with serial ports exposed on the host.",
|
||||
Category: "hardware",
|
||||
ConfigKey: "serial",
|
||||
},
|
||||
{
|
||||
Name: "tool_search_tool_regex",
|
||||
Description: "Discover hidden MCP tools by regex search when tool discovery is enabled.",
|
||||
@@ -265,6 +271,8 @@ func buildToolSupport(cfg *config.Config) []toolSupportItem {
|
||||
status, reasonCode = resolveWebSearchToolSupport(cfg)
|
||||
case "i2c", "spi":
|
||||
status, reasonCode = resolveHardwareToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey))
|
||||
case "serial":
|
||||
status, reasonCode = resolveSerialToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey))
|
||||
default:
|
||||
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
|
||||
status = "enabled"
|
||||
@@ -293,6 +301,18 @@ func resolveHardwareToolSupport(enabled bool) (string, string) {
|
||||
return "enabled", ""
|
||||
}
|
||||
|
||||
func resolveSerialToolSupport(enabled bool) (string, string) {
|
||||
if !enabled {
|
||||
return "disabled", ""
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "linux", "darwin", "windows":
|
||||
return "enabled", ""
|
||||
default:
|
||||
return "blocked", "requires_serial_platform"
|
||||
}
|
||||
}
|
||||
|
||||
func resolveDiscoveryToolSupport(cfg *config.Config, methodEnabled bool) (string, string) {
|
||||
if !cfg.Tools.IsToolEnabled("mcp") {
|
||||
return "disabled", ""
|
||||
@@ -362,6 +382,8 @@ func applyToolState(cfg *config.Config, toolName string, enabled bool) error {
|
||||
cfg.Tools.I2C.Enabled = enabled
|
||||
case "spi":
|
||||
cfg.Tools.SPI.Enabled = enabled
|
||||
case "serial":
|
||||
cfg.Tools.Serial.Enabled = enabled
|
||||
case "tool_search_tool_regex":
|
||||
cfg.Tools.MCP.Discovery.UseRegex = enabled
|
||||
if enabled {
|
||||
|
||||
@@ -92,9 +92,36 @@ func TestHandleListTools(t *testing.T) {
|
||||
if gotTools["i2c"].Status != "disabled" {
|
||||
t.Fatalf("i2c status = %q, want disabled on linux when config is off", gotTools["i2c"].Status)
|
||||
}
|
||||
if gotTools["serial"].Status != "disabled" {
|
||||
t.Fatalf("serial status = %q, want disabled when config is off", gotTools["serial"].Status)
|
||||
}
|
||||
|
||||
cfg.Tools.Serial.Enabled = true
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/tools", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
gotTools = make(map[string]toolSupportItem, len(resp.Tools))
|
||||
for _, tool := range resp.Tools {
|
||||
gotTools[tool.Name] = tool
|
||||
}
|
||||
if gotTools["serial"].Status != "enabled" {
|
||||
t.Fatalf("serial = %#v, want enabled on linux when config is on", gotTools["serial"])
|
||||
}
|
||||
} else {
|
||||
cfg.Tools.I2C.Enabled = true
|
||||
cfg.Tools.SPI.Enabled = true
|
||||
cfg.Tools.Serial.Enabled = true
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
@@ -120,6 +147,16 @@ func TestHandleListTools(t *testing.T) {
|
||||
if gotTools["spi"].Status != "blocked" || gotTools["spi"].ReasonCode != "requires_linux" {
|
||||
t.Fatalf("spi = %#v, want blocked/requires_linux", gotTools["spi"])
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "windows":
|
||||
if gotTools["serial"].Status != "enabled" {
|
||||
t.Fatalf("serial = %#v, want enabled on supported host", gotTools["serial"])
|
||||
}
|
||||
default:
|
||||
if gotTools["serial"].Status != "blocked" || gotTools["serial"].ReasonCode != "requires_serial_platform" {
|
||||
t.Fatalf("serial = %#v, want blocked/requires_serial_platform", gotTools["serial"])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,6 +232,26 @@ func TestHandleUpdateToolState(t *testing.T) {
|
||||
if !updated.Tools.Cron.Enabled {
|
||||
t.Fatalf("cron should be enabled: %#v", updated.Tools.Cron)
|
||||
}
|
||||
|
||||
rec4 := httptest.NewRecorder()
|
||||
req4 := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/tools/serial/state",
|
||||
bytes.NewBufferString(`{"enabled":true}`),
|
||||
)
|
||||
req4.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec4, req4)
|
||||
if rec4.Code != http.StatusOK {
|
||||
t.Fatalf("serial status = %d, want %d, body=%s", rec4.Code, http.StatusOK, rec4.Body.String())
|
||||
}
|
||||
|
||||
updated, err = config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig(updated serial) error = %v", err)
|
||||
}
|
||||
if !updated.Tools.Serial.Enabled {
|
||||
t.Fatalf("serial should be enabled: %#v", updated.Tools.Serial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListTools_ReportsWebSearchEnabledWhenToolIsOn(t *testing.T) {
|
||||
|
||||
@@ -603,6 +603,7 @@
|
||||
},
|
||||
"reasons": {
|
||||
"requires_linux": "This tool only works on Linux hosts with the required device files exposed.",
|
||||
"requires_serial_platform": "This tool currently supports Linux, macOS, and Windows hosts with accessible serial ports.",
|
||||
"requires_skills": "Enable `tools.skills` before this skill-registry tool can be used.",
|
||||
"requires_subagent": "Enable `tools.subagent` before the spawn tool can delegate work.",
|
||||
"requires_mcp_discovery": "Enable `tools.mcp.discovery` before MCP discovery tools become available.",
|
||||
|
||||
@@ -603,6 +603,7 @@
|
||||
},
|
||||
"reasons": {
|
||||
"requires_linux": "该工具仅在 Linux 主机上可用,并且需要暴露对应的设备文件。",
|
||||
"requires_serial_platform": "该工具当前支持 Linux、macOS 和 Windows,且要求主机可访问对应串口。",
|
||||
"requires_skills": "需要先启用 `tools.skills`,该技能注册表工具才能使用。",
|
||||
"requires_subagent": "需要先启用 `tools.subagent`,`spawn` 才能委派任务。",
|
||||
"requires_mcp_discovery": "需要先启用 `tools.mcp.discovery`,MCP 发现工具才会可用。",
|
||||
|
||||
Reference in New Issue
Block a user