diff --git a/Makefile b/Makefile index acb258370..0f6a036f4 100644 --- a/Makefile +++ b/Makefile @@ -171,6 +171,18 @@ ifeq ($(OS),Windows_NT) EXT=.exe endif +ifneq ($(strip $(GOOS)),) + PLATFORM:=$(GOOS) +endif + +ifneq ($(strip $(GOARCH)),) + ARCH:=$(GOARCH) +endif + +ifeq ($(PLATFORM),windows) + EXT=.exe +endif + BINARY_PATH=$(BUILD_DIR)/$(BINARY_NAME)-$(PLATFORM)-$(ARCH) # Default target @@ -181,10 +193,11 @@ generate: @echo "Run generate..." ifeq ($(OS),Windows_NT) @$(POWERSHELL) "if (Test-Path -LiteralPath './$(CMD_DIR)/workspace') { Remove-Item -LiteralPath './$(CMD_DIR)/workspace' -Recurse -Force }" + @$(POWERSHELL) "$$env:GOOS=''; $$env:GOARCH=''; $(GO) generate ./..." else @rm -r ./$(CMD_DIR)/workspace 2>/dev/null || true + @GOOS=$$($(GO) env GOHOSTOS) GOARCH=$$($(GO) env GOHOSTARCH) $(GO) generate ./... endif - @$(GO) generate ./... @echo "Run generate complete" ## build: Build the picoclaw binary for current platform @@ -196,7 +209,7 @@ ifeq ($(OS),Windows_NT) @$(POWERSHELL) "Copy-Item -LiteralPath '$(BINARY_PATH)$(EXT)' -Destination '$(BUILD_DIR)/$(BINARY_NAME)$(EXT)' -Force" else @mkdir -p $(BUILD_DIR) - @GOARCH=${ARCH} $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH)$(EXT) ./$(CMD_DIR) + @GOOS=$(PLATFORM) GOARCH=$(ARCH) $(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH)$(EXT) ./$(CMD_DIR) @echo "Build complete: $(BINARY_PATH)$(EXT)" @$(LNCMD) $(BINARY_NAME)-$(PLATFORM)-$(ARCH)$(EXT) $(BUILD_DIR)/$(BINARY_NAME)$(EXT) endif @@ -211,7 +224,7 @@ ifeq ($(OS),Windows_NT) @$(POWERSHELL) "Copy-Item -LiteralPath '$(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH)$(EXT)' -Destination '$(BUILD_DIR)/picoclaw-launcher$(EXT)' -Force" else @mkdir -p $(BUILD_DIR) - @GOARCH=${ARCH} $(MAKE) -C web build \ + @GOOS=$(PLATFORM) GOARCH=$(ARCH) $(MAKE) -C web build \ OUTPUT="$(CURDIR)/$(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH)$(EXT)" \ WEB_GO='$(WEB_GO)' \ GO_BUILD_TAGS='$(GO_BUILD_TAGS)' \ diff --git a/cmd/picoclaw/internal/onboard/command.go b/cmd/picoclaw/internal/onboard/command.go index 3f0ff0d8d..bf8f4104f 100644 --- a/cmd/picoclaw/internal/onboard/command.go +++ b/cmd/picoclaw/internal/onboard/command.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" ) -//go:generate go run ../../../../scripts/copydir.go "${DOLLAR}{codespace}/workspace" ./workspace +//go:generate go run ../../../../scripts/copydir.go ../../../../workspace ./workspace //go:embed workspace var embeddedFiles embed.FS diff --git a/config/config.example.json b/config/config.example.json index 30460c231..4205b8e8a 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -437,6 +437,9 @@ "enabled": true, "mode": "bytes" }, + "serial": { + "enabled": false + }, "send_tts": { "enabled": false }, diff --git a/docs/design/current-hardware-support-and-serial.zh.md b/docs/design/current-hardware-support-and-serial.zh.md new file mode 100644 index 000000000..bf91355c2 --- /dev/null +++ b/docs/design/current-hardware-support-and-serial.zh.md @@ -0,0 +1,91 @@ +# 当前硬件支持现状与串口 Tool 方案 + +## 现状结论 + +当前项目已有的硬件相关能力主要分为两条线: + +1. 设备事件监控 + - `pkg/devices` 已实现设备事件服务。 + - 当前只有 Linux USB 热插拔事件源 `pkg/devices/sources/usb_linux.go`。 + - 能力定位是“发现和通知”,不是“总线读写控制”。 + +2. 硬件控制 Tool + - `pkg/tools/hardware/i2c*.go`:I2C Tool,支持 `detect`、`scan`、`read`、`write`。 + - `pkg/tools/hardware/spi*.go`:SPI Tool,支持 `list`、`transfer`、`read`。 + - 这两类 Tool 当前都只在 Linux 主机上启用,直接依赖 `/dev/i2c-*` 与 `/dev/spidev*`。 + +因此,项目在“硬件支持能力”上已经具备: + +- Linux USB 设备插拔感知 +- Linux I2C 总线控制 +- Linux SPI 总线控制 + +但还缺少: + +- 串口/UART 控制 +- macOS / Windows 下可直接使用的硬件控制 Tool +- 面向统一硬件抽象的跨总线能力模型 + +## 本次新增 + +本次新增内建 `serial` Tool,并接入现有 Tool 体系: + +- 配置项:`tools.serial.enabled` +- Tool 注册:`pkg/agent/agent_init.go` +- Web 工具页:`/api/tools` 能展示与切换 `serial` +- 前端状态文案:新增 `requires_serial_platform` + +## Serial Tool 设计 + +`serial` 采用无状态调用模型,每次请求都自行打开和关闭端口,避免在 agent 回合之间维护串口会话状态。 + +支持动作: + +- `list`:枚举主机串口 +- `read`:从串口读取指定长度字节 +- `write`:向串口写入字节或文本 + +公共参数: + +- `port` +- `baud` +- `data_bits` +- `parity` +- `stop_bits` +- `timeout_ms` + +当前波特率实现边界: + +- Windows 允许配置工具层接受的范围 `50-4000000` +- Linux / macOS 当前仅支持标准 termios 波特率,实际支持到 `230400` +- 因此 `baud` 的跨平台可移植取值应优先使用 `230400` 及以下的常见标准速率 + +安全约束: + +- `write` 必须显式传 `confirm: true` +- 单次读写负载限制为 `4096` 字节 +- `port` 只接受白名单串口名: + - Linux / macOS 仅允许 `/dev/tty*`、`/dev/cu.*` 及对应简写设备名 + - Windows 仅允许 `COM\d+` 或 `\\.\COM\d+` + - 明确拒绝 `..`、普通文件绝对路径、盘符路径等非串口设备路径,避免路径穿越或误打开任意文件 + +## 跨平台实现边界 + +- Linux / macOS: + - 基于 `golang.org/x/sys/unix` 和 termios 配置串口参数。 + - 当前仅接入标准 termios 波特率映射,最高到 `230400`,尚未扩展 `460800`、`921600`、`1000000`、`2000000` 等更高速率。 + - 通过 `/dev/...` 枚举和访问设备。 + +- Windows: + - 基于 `kernel32` 串口 API 配置 `DCB` 和 `COMMTIMEOUTS`。 + - 当前读写仍使用同步 `ReadFile` / `WriteFile`;一旦 syscall 已进入执行,turn context cancellation 不能立即打断,只能等待 `COMMTIMEOUTS` 触发后返回。 + - 通过注册表 `HARDWARE\\DEVICEMAP\\SERIALCOMM` 枚举端口。 + +- 其他平台: + - `serial` Tool 显式返回 unsupported,不做静默降级。 + +## 后续建议 + +1. 如果需要持续交互式串口会话,建议再增加 session 型 Tool,而不是让 LLM 反复做短连接轮询。 +2. 如果后续要支持 CAN、GPIO、PWM,建议抽出统一的硬件 capability 描述层,而不是继续只靠 Tool 名称区分。 +3. 若需要生产级稳定性,建议补真实串口回环测试,至少覆盖 Linux PTY 和 Windows COM 模拟场景。 diff --git a/pkg/agent/agent_init.go b/pkg/agent/agent_init.go index 611d634e8..335fd8537 100644 --- a/pkg/agent/agent_init.go +++ b/pkg/agent/agent_init.go @@ -128,6 +128,9 @@ func registerSharedTools( if cfg.Tools.IsToolEnabled("spi") { agent.Tools.Register(tools.NewSPITool()) } + if cfg.Tools.IsToolEnabled("serial") { + agent.Tools.Register(tools.NewSerialTool()) + } // Message tool if cfg.Tools.IsToolEnabled("message") { diff --git a/pkg/config/config.go b/pkg/config/config.go index 16497b4ac..dc9e88949 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -823,6 +823,7 @@ type ToolsConfig struct { ListDir ToolConfig `json:"list_dir" yaml:"-" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"` Message ToolConfig `json:"message" yaml:"-" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` ReadFile ReadFileToolConfig `json:"read_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"` + Serial ToolConfig `json:"serial" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SERIAL_"` SendFile ToolConfig `json:"send_file" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"` SendTTS ToolConfig `json:"send_tts" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SEND_TTS_"` Spawn ToolConfig `json:"spawn" yaml:"-" envPrefix:"PICOCLAW_TOOLS_SPAWN_"` @@ -1548,6 +1549,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool { return t.Message.Enabled case "read_file": return t.ReadFile.Enabled + case "serial": + return t.Serial.Enabled case "spawn": return t.Spawn.Enabled case "spawn_status": diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index f3aaca7ab..be8c32495 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -435,6 +435,9 @@ func DefaultConfig() *Config { Mode: ReadFileModeBytes, MaxReadFileSize: 64 * 1024, // 64KB }, + Serial: ToolConfig{ + Enabled: false, // Hardware tool - requires host serial ports + }, Spawn: ToolConfig{ Enabled: true, }, diff --git a/pkg/tools/facade_compat_test.go b/pkg/tools/facade_compat_test.go index 672554209..378462512 100644 --- a/pkg/tools/facade_compat_test.go +++ b/pkg/tools/facade_compat_test.go @@ -9,6 +9,9 @@ func TestFacadeConstructorsRemainAvailable(t *testing.T) { if NewSPITool() == nil { t.Fatal("NewSPITool should return a tool") } + if NewSerialTool() == nil { + t.Fatal("NewSerialTool should return a tool") + } if NewMessageTool() == nil { t.Fatal("NewMessageTool should return a tool") } diff --git a/pkg/tools/hardware/serial.go b/pkg/tools/hardware/serial.go new file mode 100644 index 000000000..7e197a909 --- /dev/null +++ b/pkg/tools/hardware/serial.go @@ -0,0 +1,453 @@ +package hardwaretools + +import ( + "context" + "encoding/json" + "fmt" + "math" + "regexp" + "runtime" + "strings" + "time" + "unicode/utf8" +) + +const ( + defaultSerialBaud = 115200 + defaultSerialDataBits = 8 + defaultSerialStopBits = 1 + defaultSerialTimeoutMS = 1000 + maxSerialPayloadBytes = 4096 + maxSerialReadBytes = 4096 + serialPollInterval = 100 * time.Millisecond +) + +var ( + unixSerialPortPattern = regexp.MustCompile( + `^(?:/dev/)?(?:ttyS\d+|ttyUSB\d+|ttyACM\d+|ttyAMA\d+|rfcomm\d+|tty\.[A-Za-z0-9._-]+|cu\.[A-Za-z0-9._-]+)$`, + ) + windowsSerialPortPattern = regexp.MustCompile(`^(?:\\\\\.\\)?COM[1-9]\d*$`) + unixSerialBaudRates = map[int]struct{}{ + 50: {}, 75: {}, 110: {}, 134: {}, 150: {}, 200: {}, 300: {}, 600: {}, 1200: {}, 1800: {}, + 2400: {}, 4800: {}, 9600: {}, 19200: {}, 38400: {}, 57600: {}, 115200: {}, 230400: {}, + } +) + +type SerialTool struct{} + +type serialPortInfo struct { + Name string `json:"name"` + Path string `json:"path"` +} + +type serialConfig struct { + Port string + Baud int + DataBits int + Parity string + StopBits int +} + +func NewSerialTool() *SerialTool { + return &SerialTool{} +} + +func (t *SerialTool) Name() string { + return "serial" +} + +func (t *SerialTool) Description() string { + return "Interact with host serial ports. Actions: list (enumerate ports), read (receive bytes), write (send bytes with explicit confirmation). Supports Linux, macOS, and Windows." +} + +func (t *SerialTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"list", "read", "write"}, + "description": "Action to perform: list available serial ports, read bytes from a port, or write bytes to a port.", + }, + "port": map[string]any{ + "type": "string", + "description": "Serial port path or name, for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3. Required for read/write.", + }, + "baud": map[string]any{ + "type": "integer", + "description": "Baud rate. Default: 115200. Linux/macOS currently support standard termios rates up to 230400; Windows accepts configured rates up to 4000000.", + }, + "data_bits": map[string]any{ + "type": "integer", + "description": "Data bits. Supported values: 5, 6, 7, 8. Default: 8.", + }, + "parity": map[string]any{ + "type": "string", + "enum": []string{"none", "even", "odd"}, + "description": "Parity mode. Default: none.", + }, + "stop_bits": map[string]any{ + "type": "integer", + "description": "Stop bits. Supported values: 1, 2. Default: 1.", + }, + "timeout_ms": map[string]any{ + "type": "integer", + "description": "Read/write timeout in milliseconds. Default: 1000.", + }, + "length": map[string]any{ + "type": "integer", + "description": "Number of bytes to read. Required for read. Range: 1-4096.", + }, + "data": map[string]any{ + "type": "array", + "items": map[string]any{"type": "integer"}, + "description": "Bytes to write, each in range 0-255. Required for write unless text is provided.", + }, + "text": map[string]any{ + "type": "string", + "description": "UTF-8 text to write. Required for write if data is omitted.", + }, + "confirm": map[string]any{ + "type": "boolean", + "description": "Must be true for write operations.", + }, + }, + "required": []string{"action"}, + } +} + +func (t *SerialTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + action, ok := args["action"].(string) + if !ok || strings.TrimSpace(action) == "" { + return ErrorResult("action is required") + } + + switch action { + case "list": + return t.list() + case "read": + return t.read(ctx, args) + case "write": + return t.write(ctx, args) + default: + return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, read, write)", action)) + } +} + +func (t *SerialTool) list() *ToolResult { + ports, err := serialListPorts() + if err != nil { + return ErrorResult(fmt.Sprintf("failed to list serial ports: %v", err)) + } + if len(ports) == 0 { + return SilentResult("No serial ports found on this host.") + } + + result, _ := json.MarshalIndent(map[string]any{ + "ports": ports, + "count": len(ports), + }, "", " ") + return SilentResult(string(result)) +} + +func (t *SerialTool) read(ctx context.Context, args map[string]any) *ToolResult { + cfg, errResult := parseSerialConfig(args) + if errResult != nil { + return errResult + } + + length := 0 + if v, ok := args["length"].(float64); ok { + length = int(v) + } + if length < 1 || length > maxSerialReadBytes { + return ErrorResult(fmt.Sprintf("length is required for read (1-%d)", maxSerialReadBytes)) + } + + timeout, errResult := parseSerialTimeout(args) + if errResult != nil { + return errResult + } + + data, err := serialRead(ctx, cfg, length, timeout) + if err != nil { + return ErrorResult(fmt.Sprintf("serial read failed on %s: %v", cfg.Port, err)) + } + + return SilentResult(formatSerialPayload("read", cfg, data, timeout)) +} + +func (t *SerialTool) write(ctx context.Context, args map[string]any) *ToolResult { + confirm, _ := args["confirm"].(bool) + if !confirm { + return ErrorResult( + "write operations require confirm: true. Please confirm with the user before sending bytes to a serial device.", + ) + } + + cfg, errResult := parseSerialConfig(args) + if errResult != nil { + return errResult + } + timeout, errResult := parseSerialTimeout(args) + if errResult != nil { + return errResult + } + payload, errResult := parseSerialWritePayload(args) + if errResult != nil { + return errResult + } + + written, err := serialWrite(ctx, cfg, payload, timeout) + if err != nil { + return ErrorResult(fmt.Sprintf("serial write failed on %s: %v", cfg.Port, err)) + } + + result, _ := json.MarshalIndent(map[string]any{ + "action": "write", + "port": cfg.Port, + "baud": cfg.Baud, + "data_bits": cfg.DataBits, + "parity": cfg.Parity, + "stop_bits": cfg.StopBits, + "timeout_ms": timeout.Milliseconds(), + "written": written, + "payload": serialPayloadSummary(payload), + }, "", " ") + return SilentResult(string(result)) +} + +func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) { + port, ok := args["port"].(string) + port = strings.TrimSpace(port) + if !ok || port == "" { + return serialConfig{}, ErrorResult( + "port is required (for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3)", + ) + } + + normalizedPort, err := normalizeSerialPort(port) + if err != nil { + return serialConfig{}, ErrorResult(err.Error()) + } + + cfg := serialConfig{ + Port: normalizedPort, + Baud: defaultSerialBaud, + DataBits: defaultSerialDataBits, + Parity: "none", + StopBits: defaultSerialStopBits, + } + + if v, ok := args["baud"].(float64); ok { + cfg.Baud = int(v) + } + if err := validateSerialBaud(cfg.Baud); err != nil { + return serialConfig{}, ErrorResult(err.Error()) + } + + if v, ok := args["data_bits"].(float64); ok { + cfg.DataBits = int(v) + } + switch cfg.DataBits { + case 5, 6, 7, 8: + default: + return serialConfig{}, ErrorResult("data_bits must be one of 5, 6, 7, or 8") + } + + if v, ok := args["parity"].(string); ok && strings.TrimSpace(v) != "" { + cfg.Parity = strings.ToLower(strings.TrimSpace(v)) + } + switch cfg.Parity { + case "none", "even", "odd": + default: + return serialConfig{}, ErrorResult(`parity must be one of "none", "even", or "odd"`) + } + + if v, ok := args["stop_bits"].(float64); ok { + cfg.StopBits = int(v) + } + if cfg.StopBits != 1 && cfg.StopBits != 2 { + return serialConfig{}, ErrorResult("stop_bits must be 1 or 2") + } + + return cfg, nil +} + +func parseSerialTimeout(args map[string]any) (time.Duration, *ToolResult) { + timeoutMS := defaultSerialTimeoutMS + if v, ok := args["timeout_ms"].(float64); ok { + timeoutMS = int(v) + } + if timeoutMS < 1 || timeoutMS > 60000 { + return 0, ErrorResult("timeout_ms must be between 1 and 60000") + } + return time.Duration(timeoutMS) * time.Millisecond, nil +} + +func parseSerialWritePayload(args map[string]any) ([]byte, *ToolResult) { + if text, ok := args["text"].(string); ok && text != "" { + if !utf8.ValidString(text) { + return nil, ErrorResult("text must be valid UTF-8") + } + if len(text) > maxSerialPayloadBytes { + return nil, ErrorResult(fmt.Sprintf("text payload too large: maximum %d bytes", maxSerialPayloadBytes)) + } + return []byte(text), nil + } + + dataRaw, ok := args["data"].([]any) + if !ok || len(dataRaw) == 0 { + return nil, ErrorResult("write requires either text or data") + } + if len(dataRaw) > maxSerialPayloadBytes { + return nil, ErrorResult(fmt.Sprintf("data too long: maximum %d bytes", maxSerialPayloadBytes)) + } + + data := make([]byte, len(dataRaw)) + for i, v := range dataRaw { + f, ok := v.(float64) + if !ok { + return nil, ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) + } + if f != math.Trunc(f) { + return nil, ErrorResult(fmt.Sprintf("data[%d] is not an integer byte value", i)) + } + b := int(f) + if b < 0 || b > 255 { + return nil, ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) + } + data[i] = byte(b) + } + + return data, nil +} + +func formatSerialPayload(action string, cfg serialConfig, data []byte, timeout time.Duration) string { + result, _ := json.MarshalIndent(map[string]any{ + "action": action, + "port": cfg.Port, + "baud": cfg.Baud, + "data_bits": cfg.DataBits, + "parity": cfg.Parity, + "stop_bits": cfg.StopBits, + "timeout_ms": timeout.Milliseconds(), + "payload": serialPayloadSummary(data), + }, "", " ") + return string(result) +} + +func serialPayloadSummary(data []byte) map[string]any { + hexValues := make([]string, len(data)) + intValues := make([]int, len(data)) + for i, b := range data { + hexValues[i] = fmt.Sprintf("0x%02x", b) + intValues[i] = int(b) + } + + summary := map[string]any{ + "length": len(data), + "bytes": intValues, + "hex": hexValues, + } + if utf8.Valid(data) { + summary["text"] = string(data) + } + return summary +} + +func normalizeSerialPort(port string) (string, error) { + switch runtime.GOOS { + case "windows": + return normalizeWindowsSerialPath(port) + case "linux", "darwin": + return normalizeUnixSerialPath(port) + default: + if normalized, err := normalizeUnixSerialPath(port); err == nil { + return normalized, nil + } + return normalizeWindowsSerialPath(port) + } +} + +func normalizeUnixSerialPath(port string) (string, error) { + trimmed := strings.TrimSpace(port) + if !unixSerialPortPattern.MatchString(trimmed) { + return "", fmt.Errorf( + "invalid serial port: expected a safe Unix device name such as /dev/ttyUSB0 or /dev/cu.usbserial-0001", + ) + } + if strings.HasPrefix(trimmed, "/dev/") { + return trimmed, nil + } + return "/dev/" + trimmed, nil +} + +func normalizeWindowsSerialPath(port string) (string, error) { + trimmed := strings.ToUpper(strings.TrimSpace(port)) + if !windowsSerialPortPattern.MatchString(trimmed) { + return "", fmt.Errorf("invalid serial port: expected a COM port such as COM3") + } + if strings.HasPrefix(trimmed, `\\.\`) { + return trimmed, nil + } + return `\\.\` + trimmed, nil +} + +func validateSerialBaud(baud int) error { + if baud < 50 || baud > 4000000 { + return fmt.Errorf("baud must be between 50 and 4000000") + } + + switch runtime.GOOS { + case "linux", "darwin": + if _, ok := unixSerialBaudRates[baud]; !ok { + return fmt.Errorf("unsupported baud rate on this platform: %d (supported up to 230400)", baud) + } + } + + return nil +} + +func serialContextErr(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} + +func serialWriteAll( + ctx context.Context, + data []byte, + timeout time.Duration, + now func() time.Time, + write func([]byte) (int, error), +) (int, error) { + if err := serialContextErr(ctx); err != nil { + return 0, err + } + + total := 0 + deadline := now().Add(timeout) + for total < len(data) { + if err := serialContextErr(ctx); err != nil { + return total, err + } + if deadline.Sub(now()) <= 0 { + return total, fmt.Errorf("timeout while writing serial data") + } + + n, err := write(data[total:]) + total += n + if err != nil { + return total, err + } + if n == 0 { + continue + } + } + + return total, nil +} diff --git a/pkg/tools/hardware/serial_darwin.go b/pkg/tools/hardware/serial_darwin.go new file mode 100644 index 000000000..bc019029e --- /dev/null +++ b/pkg/tools/hardware/serial_darwin.go @@ -0,0 +1,19 @@ +//go:build darwin + +package hardwaretools + +import "golang.org/x/sys/unix" + +func serialGetTermios(fd int) (*unix.Termios, error) { + return unix.IoctlGetTermios(fd, unix.TIOCGETA) +} + +func serialSetSpeed(tio *unix.Termios, speed uint32) error { + tio.Ispeed = uint64(speed) + tio.Ospeed = uint64(speed) + return nil +} + +func serialSetTermios(fd int, tio *unix.Termios) error { + return unix.IoctlSetTermios(fd, unix.TIOCSETA, tio) +} diff --git a/pkg/tools/hardware/serial_linux.go b/pkg/tools/hardware/serial_linux.go new file mode 100644 index 000000000..bad3e4cb8 --- /dev/null +++ b/pkg/tools/hardware/serial_linux.go @@ -0,0 +1,19 @@ +//go:build linux + +package hardwaretools + +import "golang.org/x/sys/unix" + +func serialGetTermios(fd int) (*unix.Termios, error) { + return unix.IoctlGetTermios(fd, unix.TCGETS) +} + +func serialSetSpeed(tio *unix.Termios, speed uint32) error { + tio.Ispeed = speed + tio.Ospeed = speed + return nil +} + +func serialSetTermios(fd int, tio *unix.Termios) error { + return unix.IoctlSetTermios(fd, unix.TCSETS, tio) +} diff --git a/pkg/tools/hardware/serial_other.go b/pkg/tools/hardware/serial_other.go new file mode 100644 index 000000000..ec72a2d2a --- /dev/null +++ b/pkg/tools/hardware/serial_other.go @@ -0,0 +1,21 @@ +//go:build !linux && !darwin && !windows + +package hardwaretools + +import ( + "context" + "fmt" + "time" +) + +func serialListPorts() ([]serialPortInfo, error) { + return nil, fmt.Errorf("serial is not supported on this platform") +} + +func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) { + return nil, fmt.Errorf("serial is not supported on this platform") +} + +func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) { + return 0, fmt.Errorf("serial is not supported on this platform") +} diff --git a/pkg/tools/hardware/serial_other_test.go b/pkg/tools/hardware/serial_other_test.go new file mode 100644 index 000000000..ef04c4062 --- /dev/null +++ b/pkg/tools/hardware/serial_other_test.go @@ -0,0 +1,18 @@ +//go:build !linux && !darwin && !windows + +package hardwaretools + +import ( + "strings" + "testing" +) + +func TestSerialListPortsUnsupportedPlatform(t *testing.T) { + _, err := serialListPorts() + if err == nil { + t.Fatal("expected unsupported platform error") + } + if !strings.Contains(err.Error(), "not supported") { + t.Fatalf("serialListPorts() error = %v, want unsupported platform message", err) + } +} diff --git a/pkg/tools/hardware/serial_test.go b/pkg/tools/hardware/serial_test.go new file mode 100644 index 000000000..6b2e9765d --- /dev/null +++ b/pkg/tools/hardware/serial_test.go @@ -0,0 +1,269 @@ +package hardwaretools + +import ( + "context" + "runtime" + "strings" + "testing" + "time" +) + +func TestParseSerialConfig(t *testing.T) { + port := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + port = "COM3" + } + + cfg, errResult := parseSerialConfig(map[string]any{ + "port": port, + "baud": float64(9600), + "data_bits": float64(7), + "parity": "even", + "stop_bits": float64(2), + }) + if errResult != nil { + t.Fatalf("parseSerialConfig() unexpected error = %v", errResult.ForLLM) + } + + wantPort := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + wantPort = `\\.\COM3` + } + if cfg.Port != wantPort || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 { + t.Fatalf("parseSerialConfig() = %#v", cfg) + } +} + +func TestParseSerialConfigRejectsInvalidParity(t *testing.T) { + port := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + port = "COM3" + } + + _, errResult := parseSerialConfig(map[string]any{ + "port": port, + "parity": "mark", + }) + if errResult == nil { + t.Fatal("expected invalid parity to fail") + } +} + +func TestParseSerialConfigRejectsUnsupportedUnixBaud(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skip("Unix baud validation only applies on Unix platforms") + } + + _, errResult := parseSerialConfig(map[string]any{ + "port": "/dev/ttyUSB0", + "baud": float64(460800), + }) + if errResult == nil { + t.Fatal("expected unsupported Unix baud rate to fail") + } +} + +func TestParseSerialWritePayloadRejectsFractionalBytes(t *testing.T) { + _, errResult := parseSerialWritePayload(map[string]any{ + "data": []any{65.9}, + }) + if errResult == nil { + t.Fatal("expected fractional byte value to fail") + } +} + +func TestValidateSerialBaud(t *testing.T) { + tests := []struct { + name string + baud int + wantErr bool + }{ + {name: "default-supported", baud: 115200}, + {name: "max-unix-supported", baud: 230400}, + {name: "too-low", baud: 49, wantErr: true}, + {name: "too-high", baud: 4000001, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSerialBaud(tt.baud) + if (err != nil) != tt.wantErr { + t.Fatalf("validateSerialBaud(%d) error = %v, wantErr %v", tt.baud, err, tt.wantErr) + } + }) + } +} + +func TestSerialReadCanceledBeforeOpen(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + port := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + port = "COM3" + } + + _, err := serialRead( + ctx, + serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1}, + 1, + time.Second, + ) + if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) { + t.Fatalf("serialRead() error = %v, want context canceled", err) + } +} + +func TestSerialWriteCanceledBeforeOpen(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + port := "/dev/ttyUSB0" + if runtime.GOOS == "windows" { + port = "COM3" + } + + _, err := serialWrite( + ctx, + serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1}, + []byte("AT"), + time.Second, + ) + if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) { + t.Fatalf("serialWrite() error = %v, want context canceled", err) + } +} + +func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) { + tests := []string{ + "../../../etc/passwd", + "/etc/passwd", + `C:\temp\device.txt`, + `\\.\C:\temp\device.txt`, + } + + for _, port := range tests { + t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) { + _, errResult := parseSerialConfig(map[string]any{ + "port": port, + }) + if errResult == nil { + t.Fatalf("expected unsafe port %q to be rejected", port) + } + }) + } +} + +func TestNormalizeUnixSerialPath(t *testing.T) { + tests := []struct { + port string + want string + }{ + {port: "ttyUSB0", want: "/dev/ttyUSB0"}, + {port: "/dev/ttyACM0", want: "/dev/ttyACM0"}, + {port: "/dev/cu.usbserial-0001", want: "/dev/cu.usbserial-0001"}, + } + + for _, tt := range tests { + got, err := normalizeUnixSerialPath(tt.port) + if err != nil { + t.Fatalf("normalizeUnixSerialPath(%q) unexpected error = %v", tt.port, err) + } + if got != tt.want { + t.Fatalf("normalizeUnixSerialPath(%q) = %q, want %q", tt.port, got, tt.want) + } + } +} + +func TestNormalizeUnixSerialPathRejectsInvalidNames(t *testing.T) { + tests := []string{ + "", + "ttyUSB0/../../passwd", + "/dev/../../etc/passwd", + "/tmp/ttyUSB0", + "ttyUSB", + "COM3", + } + + for _, port := range tests { + t.Run(strings.ReplaceAll(port, "/", "_"), func(t *testing.T) { + if _, err := normalizeUnixSerialPath(port); err == nil { + t.Fatalf("expected %q to be rejected", port) + } + }) + } +} + +func TestNormalizeWindowsSerialPath(t *testing.T) { + tests := []struct { + port string + want string + }{ + {port: "COM3", want: `\\.\COM3`}, + {port: "com12", want: `\\.\COM12`}, + {port: `\\.\COM7`, want: `\\.\COM7`}, + } + + for _, tt := range tests { + got, err := normalizeWindowsSerialPath(tt.port) + if err != nil { + t.Fatalf("normalizeWindowsSerialPath(%q) unexpected error = %v", tt.port, err) + } + if got != tt.want { + t.Fatalf("normalizeWindowsSerialPath(%q) = %q, want %q", tt.port, got, tt.want) + } + } +} + +func TestNormalizeWindowsSerialPathRejectsInvalidNames(t *testing.T) { + tests := []string{ + "", + "COM0", + "COM", + "/dev/ttyUSB0", + `C:\temp\device.txt`, + `\\.\C:\temp\device.txt`, + `\\server\share\COM3`, + } + + for _, port := range tests { + t.Run(strings.ReplaceAll(strings.ReplaceAll(port, `\`, "_"), "/", "_"), func(t *testing.T) { + if _, err := normalizeWindowsSerialPath(port); err == nil { + t.Fatalf("expected %q to be rejected", port) + } + }) + } +} + +func TestParseSerialTimeout(t *testing.T) { + timeout, errResult := parseSerialTimeout(map[string]any{ + "timeout_ms": float64(2500), + }) + if errResult != nil { + t.Fatalf("parseSerialTimeout() unexpected error = %v", errResult.ForLLM) + } + if timeout != 2500*time.Millisecond { + t.Fatalf("timeout = %v, want 2500ms", timeout) + } +} + +func TestParseSerialWritePayloadSupportsText(t *testing.T) { + data, errResult := parseSerialWritePayload(map[string]any{ + "text": "AT\r\n", + }) + if errResult != nil { + t.Fatalf("parseSerialWritePayload() unexpected error = %v", errResult.ForLLM) + } + if string(data) != "AT\r\n" { + t.Fatalf("payload = %q, want %q", string(data), "AT\r\n") + } +} + +func TestParseSerialWritePayloadRejectsOutOfRangeByte(t *testing.T) { + _, errResult := parseSerialWritePayload(map[string]any{ + "data": []any{float64(256)}, + }) + if errResult == nil { + t.Fatal("expected payload validation failure") + } +} diff --git a/pkg/tools/hardware/serial_unix.go b/pkg/tools/hardware/serial_unix.go new file mode 100644 index 000000000..548b8573b --- /dev/null +++ b/pkg/tools/hardware/serial_unix.go @@ -0,0 +1,286 @@ +//go:build linux || darwin + +package hardwaretools + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "time" + + "golang.org/x/sys/unix" +) + +var ( + unixSerialNow = time.Now + unixSerialOpenPort = openAndConfigureSerialPort + unixSerialClosePort = unix.Close + unixSerialPollRead = pollRead + unixSerialPollWrite = pollWrite +) + +func serialListPorts() ([]serialPortInfo, error) { + patterns := []string{ + "/dev/ttyS*", + "/dev/ttyUSB*", + "/dev/ttyACM*", + "/dev/ttyAMA*", + "/dev/rfcomm*", + "/dev/tty.*", + "/dev/cu.*", + } + + seen := make(map[string]struct{}) + ports := make([]serialPortInfo, 0) + for _, pattern := range patterns { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + for _, match := range matches { + if _, ok := seen[match]; ok { + continue + } + info, err := os.Stat(match) + if err != nil || info.IsDir() { + continue + } + seen[match] = struct{}{} + ports = append(ports, serialPortInfo{ + Name: filepath.Base(match), + Path: match, + }) + } + } + + sort.Slice(ports, func(i, j int) bool { + return ports[i].Path < ports[j].Path + }) + return ports, nil +} + +func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) { + if err := serialContextErr(ctx); err != nil { + return nil, err + } + + fd, err := unixSerialOpenPort(cfg) + if err != nil { + return nil, err + } + defer unixSerialClosePort(fd) + + buf := make([]byte, length) + total := 0 + deadline := unixSerialNow().Add(timeout) + + for total < length { + if err := serialContextErr(ctx); err != nil { + return nil, err + } + + remaining := deadline.Sub(unixSerialNow()) + if remaining <= 0 { + break + } + + n, err := unixSerialPollRead(fd, buf[total:], minSerialPollTimeout(remaining)) + if err != nil { + return nil, err + } + if n == 0 { + continue + } + total += n + } + + return buf[:total], nil +} + +func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) { + if err := serialContextErr(ctx); err != nil { + return 0, err + } + + fd, err := unixSerialOpenPort(cfg) + if err != nil { + return 0, err + } + defer unixSerialClosePort(fd) + + total := 0 + deadline := unixSerialNow().Add(timeout) + for total < len(data) { + if err := serialContextErr(ctx); err != nil { + return total, err + } + + remaining := deadline.Sub(unixSerialNow()) + if remaining <= 0 { + return total, fmt.Errorf("timeout while writing serial data") + } + + n, err := unixSerialPollWrite(fd, data[total:], minSerialPollTimeout(remaining)) + if err != nil { + return total, err + } + if n == 0 { + continue + } + total += n + } + + return total, nil +} + +func openAndConfigureSerialPort(cfg serialConfig) (int, error) { + fd, err := unix.Open(cfg.Port, unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0) + if err != nil { + return -1, err + } + + if err := unix.SetNonblock(fd, false); err != nil { + unix.Close(fd) + return -1, err + } + + if err := configureUnixSerialPort(fd, cfg); err != nil { + unix.Close(fd) + return -1, err + } + + return fd, nil +} + +func configureUnixSerialPort(fd int, cfg serialConfig) error { + tio, err := serialGetTermios(fd) + if err != nil { + return err + } + + tio.Iflag = 0 + tio.Oflag = 0 + tio.Lflag = 0 + tio.Cflag = unix.CREAD | unix.CLOCAL + tio.Cc[unix.VMIN] = 0 + tio.Cc[unix.VTIME] = 0 + + switch cfg.DataBits { + case 5: + tio.Cflag |= unix.CS5 + case 6: + tio.Cflag |= unix.CS6 + case 7: + tio.Cflag |= unix.CS7 + default: + tio.Cflag |= unix.CS8 + } + + switch cfg.Parity { + case "even": + tio.Cflag |= unix.PARENB + case "odd": + tio.Cflag |= unix.PARENB | unix.PARODD + } + + if cfg.StopBits == 2 { + tio.Cflag |= unix.CSTOPB + } + + speed, err := serialBaudToUnix(cfg.Baud) + if err != nil { + return err + } + if err := serialSetSpeed(tio, speed); err != nil { + return err + } + + return serialSetTermios(fd, tio) +} + +func serialBaudToUnix(baud int) (uint32, error) { + switch baud { + case 50: + return unix.B50, nil + case 75: + return unix.B75, nil + case 110: + return unix.B110, nil + case 134: + return unix.B134, nil + case 150: + return unix.B150, nil + case 200: + return unix.B200, nil + case 300: + return unix.B300, nil + case 600: + return unix.B600, nil + case 1200: + return unix.B1200, nil + case 1800: + return unix.B1800, nil + case 2400: + return unix.B2400, nil + case 4800: + return unix.B4800, nil + case 9600: + return unix.B9600, nil + case 19200: + return unix.B19200, nil + case 38400: + return unix.B38400, nil + case 57600: + return unix.B57600, nil + case 115200: + return unix.B115200, nil + case 230400: + return unix.B230400, nil + default: + return 0, fmt.Errorf("unsupported baud rate on this platform: %d", baud) + } +} + +func pollRead(fd int, dst []byte, timeout time.Duration) (int, error) { + pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}} + n, err := unix.Poll(pfd, durationToPollTimeout(timeout)) + if err != nil { + return 0, err + } + if n == 0 { + return 0, nil + } + return unix.Read(fd, dst) +} + +func pollWrite(fd int, src []byte, timeout time.Duration) (int, error) { + pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}} + n, err := unix.Poll(pfd, durationToPollTimeout(timeout)) + if err != nil { + return 0, err + } + if n == 0 { + return 0, nil + } + return unix.Write(fd, src) +} + +func durationToPollTimeout(timeout time.Duration) int { + if timeout <= 0 { + return 0 + } + ms := int(timeout / time.Millisecond) + if ms == 0 { + return 1 + } + return ms +} + +func minSerialPollTimeout(timeout time.Duration) time.Duration { + if timeout > serialPollInterval { + return serialPollInterval + } + return timeout +} diff --git a/pkg/tools/hardware/serial_unix_test.go b/pkg/tools/hardware/serial_unix_test.go new file mode 100644 index 000000000..fac2efe7f --- /dev/null +++ b/pkg/tools/hardware/serial_unix_test.go @@ -0,0 +1,140 @@ +//go:build linux || darwin + +package hardwaretools + +import ( + "context" + "errors" + "testing" + "time" +) + +func stubUnixSerialIO(t *testing.T, now *time.Time) { + t.Helper() + + prevNow := unixSerialNow + prevOpen := unixSerialOpenPort + prevClose := unixSerialClosePort + prevPollRead := unixSerialPollRead + prevPollWrite := unixSerialPollWrite + + unixSerialNow = func() time.Time { + return *now + } + unixSerialOpenPort = func(cfg serialConfig) (int, error) { + return 42, nil + } + unixSerialClosePort = func(fd int) error { + return nil + } + unixSerialPollRead = prevPollRead + unixSerialPollWrite = prevPollWrite + + t.Cleanup(func() { + unixSerialNow = prevNow + unixSerialOpenPort = prevOpen + unixSerialClosePort = prevClose + unixSerialPollRead = prevPollRead + unixSerialPollWrite = prevPollWrite + }) +} + +func TestSerialReadWaitsPastEmptyPollsUntilDeadline(t *testing.T) { + now := time.Unix(0, 0) + stubUnixSerialIO(t, &now) + + pollCalls := 0 + unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) { + pollCalls++ + if timeout > serialPollInterval { + t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval) + } + now = now.Add(timeout) + if pollCalls < 4 { + return 0, nil + } + return copy(dst, []byte("OK")), nil + } + + got, err := serialRead(context.Background(), serialConfig{}, 2, 500*time.Millisecond) + if err != nil { + t.Fatalf("serialRead() error = %v", err) + } + if string(got) != "OK" { + t.Fatalf("serialRead() = %q, want %q", got, "OK") + } + if pollCalls != 4 { + t.Fatalf("poll calls = %d, want 4", pollCalls) + } +} + +func TestSerialReadReturnsPromptlyOnContextCancelBetweenPolls(t *testing.T) { + now := time.Unix(0, 0) + stubUnixSerialIO(t, &now) + + ctx, cancel := context.WithCancel(context.Background()) + pollCalls := 0 + unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) { + pollCalls++ + now = now.Add(timeout) + cancel() + return 0, nil + } + + _, err := serialRead(ctx, serialConfig{}, 1, time.Second) + if !errors.Is(err, context.Canceled) { + t.Fatalf("serialRead() error = %v, want context canceled", err) + } + if pollCalls != 1 { + t.Fatalf("poll calls = %d, want 1", pollCalls) + } +} + +func TestSerialWriteWaitsPastEmptyPollsUntilReady(t *testing.T) { + now := time.Unix(0, 0) + stubUnixSerialIO(t, &now) + + pollCalls := 0 + unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) { + pollCalls++ + if timeout > serialPollInterval { + t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval) + } + now = now.Add(timeout) + switch pollCalls { + case 1, 2: + return 0, nil + default: + return 1, nil + } + } + + written, err := serialWrite(context.Background(), serialConfig{}, []byte("OK"), 500*time.Millisecond) + if err != nil { + t.Fatalf("serialWrite() error = %v", err) + } + if written != 2 { + t.Fatalf("serialWrite() wrote %d bytes, want 2", written) + } + if pollCalls != 4 { + t.Fatalf("poll calls = %d, want 4", pollCalls) + } +} + +func TestSerialWriteTimesOutAfterRepeatedEmptyPolls(t *testing.T) { + now := time.Unix(0, 0) + stubUnixSerialIO(t, &now) + + unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) { + now = now.Add(timeout) + return 0, nil + } + + written, err := serialWrite(context.Background(), serialConfig{}, []byte("A"), 250*time.Millisecond) + if err == nil || err.Error() != "timeout while writing serial data" { + t.Fatalf("serialWrite() error = %v, want timeout", err) + } + if written != 0 { + t.Fatalf("serialWrite() wrote %d bytes, want 0", written) + } +} diff --git a/pkg/tools/hardware/serial_windows.go b/pkg/tools/hardware/serial_windows.go new file mode 100644 index 000000000..36f8b4271 --- /dev/null +++ b/pkg/tools/hardware/serial_windows.go @@ -0,0 +1,248 @@ +//go:build windows + +package hardwaretools + +import ( + "context" + "fmt" + "sort" + "strings" + "time" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procGetCommState = kernel32.NewProc("GetCommState") + procSetCommState = kernel32.NewProc("SetCommState") + procSetCommTimeouts = kernel32.NewProc("SetCommTimeouts") + procPurgeComm = kernel32.NewProc("PurgeComm") +) + +const ( + purgeTxClear = 0x0004 + purgeRxClear = 0x0008 + + dcbFlagBinary = 0x00000001 + dcbFlagParity = 0x00000002 + dcbFlagOutxCtsFlow = 0x00000004 + dcbFlagOutxDsrFlow = 0x00000008 + dcbFlagDtrControlMask = 0x00000030 + dcbFlagDsrSensitivity = 0x00000040 + dcbFlagTXContinueOnXoff = 0x00000080 + dcbFlagOutX = 0x00000100 + dcbFlagInX = 0x00000200 + dcbFlagRtsControlMask = 0x00003000 +) + +type dcb struct { + DCBlength uint32 + BaudRate uint32 + Flags uint32 + Reserved uint16 + XonLim uint16 + XoffLim uint16 + ByteSize byte + Parity byte + StopBits byte + XonChar byte + XoffChar byte + ErrorChar byte + EofChar byte + EvtChar byte + wReserved1 uint16 +} + +type commTimeouts struct { + ReadIntervalTimeout uint32 + ReadTotalTimeoutMultiplier uint32 + ReadTotalTimeoutConstant uint32 + WriteTotalTimeoutMultiplier uint32 + WriteTotalTimeoutConstant uint32 +} + +func serialListPorts() ([]serialPortInfo, error) { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, `HARDWARE\DEVICEMAP\SERIALCOMM`, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return nil, nil + } + return nil, err + } + defer key.Close() + + names, err := key.ReadValueNames(-1) + if err != nil { + return nil, err + } + + ports := make([]serialPortInfo, 0, len(names)) + seen := make(map[string]struct{}) + for _, name := range names { + value, _, err := key.GetStringValue(name) + if err != nil { + continue + } + portName := strings.TrimSpace(value) + if portName == "" { + continue + } + normalized := strings.ToUpper(portName) + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + ports = append(ports, serialPortInfo{ + Name: normalized, + Path: normalized, + }) + } + + sort.Slice(ports, func(i, j int) bool { + return ports[i].Path < ports[j].Path + }) + return ports, nil +} + +func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) { + if err := serialContextErr(ctx); err != nil { + return nil, err + } + + handle, err := openAndConfigureWindowsSerial(cfg, timeout) + if err != nil { + return nil, err + } + defer windows.CloseHandle(handle) + + if err := serialContextErr(ctx); err != nil { + return nil, err + } + + buf := make([]byte, length) + var read uint32 + // Synchronous serial I/O on Windows cannot be interrupted once the syscall starts. + // COMMTIMEOUTS bounds how long turn cancellation may take to surface. + if err := windows.ReadFile(handle, buf, &read, nil); err != nil { + return nil, err + } + return buf[:read], nil +} + +func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) { + if err := serialContextErr(ctx); err != nil { + return 0, err + } + + handle, err := openAndConfigureWindowsSerial(cfg, timeout) + if err != nil { + return 0, err + } + defer windows.CloseHandle(handle) + + if err := serialContextErr(ctx); err != nil { + return 0, err + } + + return serialWriteAll(ctx, data, timeout, time.Now, func(chunk []byte) (int, error) { + var written uint32 + // Like ReadFile above, this synchronous WriteFile call relies on COMMTIMEOUTS + // rather than context preemption once the syscall is in flight. + if err := windows.WriteFile(handle, chunk, &written, nil); err != nil { + return int(written), err + } + return int(written), nil + }) +} + +func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) { + handle, err := windows.CreateFile( + windows.StringToUTF16Ptr(cfg.Port), + windows.GENERIC_READ|windows.GENERIC_WRITE, + 0, + nil, + windows.OPEN_EXISTING, + 0, + 0, + ) + if err != nil { + return 0, err + } + + if err := configureWindowsSerialPort(handle, cfg, timeout); err != nil { + windows.CloseHandle(handle) + return 0, err + } + return handle, nil +} + +func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout time.Duration) error { + state := dcb{DCBlength: uint32(unsafe.Sizeof(dcb{}))} + r1, _, err := procGetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state))) + if r1 == 0 { + return err + } + + state.BaudRate = uint32(cfg.Baud) + state.ByteSize = byte(cfg.DataBits) + state.Flags = sanitizeWindowsSerialFlags(state.Flags) + state.Flags |= dcbFlagBinary + + switch cfg.Parity { + case "even": + state.Parity = 2 + state.Flags |= dcbFlagParity + case "odd": + state.Parity = 1 + state.Flags |= dcbFlagParity + default: + state.Parity = 0 + state.Flags &^= dcbFlagParity + } + + switch cfg.StopBits { + case 2: + state.StopBits = 2 + default: + state.StopBits = 0 + } + + r1, _, err = procSetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state))) + if r1 == 0 { + return err + } + + timeoutMS := uint32(timeout / time.Millisecond) + if timeoutMS == 0 { + timeoutMS = 1 + } + timeouts := commTimeouts{ + ReadIntervalTimeout: timeoutMS, + ReadTotalTimeoutConstant: timeoutMS, + WriteTotalTimeoutConstant: timeoutMS, + ReadTotalTimeoutMultiplier: 0, + WriteTotalTimeoutMultiplier: 0, + } + r1, _, err = procSetCommTimeouts.Call(uintptr(handle), uintptr(unsafe.Pointer(&timeouts))) + if r1 == 0 { + return err + } + + procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear)) + return nil +} + +func sanitizeWindowsSerialFlags(flags uint32) uint32 { + flags &^= dcbFlagOutxCtsFlow | + dcbFlagOutxDsrFlow | + dcbFlagDtrControlMask | + dcbFlagDsrSensitivity | + dcbFlagTXContinueOnXoff | + dcbFlagOutX | + dcbFlagInX | + dcbFlagRtsControlMask + return flags +} diff --git a/pkg/tools/hardware/serial_windows_test.go b/pkg/tools/hardware/serial_windows_test.go new file mode 100644 index 000000000..ecb0addbd --- /dev/null +++ b/pkg/tools/hardware/serial_windows_test.go @@ -0,0 +1,39 @@ +//go:build windows + +package hardwaretools + +import "testing" + +func TestSanitizeWindowsSerialFlags(t *testing.T) { + flags := uint32( + dcbFlagBinary | + dcbFlagParity | + dcbFlagOutxCtsFlow | + dcbFlagOutxDsrFlow | + dcbFlagDtrControlMask | + dcbFlagDsrSensitivity | + dcbFlagTXContinueOnXoff | + dcbFlagOutX | + dcbFlagInX | + dcbFlagRtsControlMask, + ) + + got := sanitizeWindowsSerialFlags(flags) + + if got&dcbFlagBinary == 0 { + t.Fatal("sanitizeWindowsSerialFlags() should preserve fBinary") + } + if got&dcbFlagParity == 0 { + t.Fatal("sanitizeWindowsSerialFlags() should preserve fParity") + } + if got&(dcbFlagOutxCtsFlow| + dcbFlagOutxDsrFlow| + dcbFlagDtrControlMask| + dcbFlagDsrSensitivity| + dcbFlagTXContinueOnXoff| + dcbFlagOutX| + dcbFlagInX| + dcbFlagRtsControlMask) != 0 { + t.Fatalf("sanitizeWindowsSerialFlags() = %#x, want flow-control bits cleared", got) + } +} diff --git a/pkg/tools/hardware/serial_write_common_test.go b/pkg/tools/hardware/serial_write_common_test.go new file mode 100644 index 000000000..398c1fde5 --- /dev/null +++ b/pkg/tools/hardware/serial_write_common_test.go @@ -0,0 +1,87 @@ +package hardwaretools + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestSerialWriteAllRetriesPartialWritesUntilComplete(t *testing.T) { + now := time.Unix(0, 0) + calls := 0 + + written, err := serialWriteAll(context.Background(), []byte("PING"), time.Second, func() time.Time { + return now + }, func(chunk []byte) (int, error) { + calls++ + now = now.Add(100 * time.Millisecond) + switch calls { + case 1: + if string(chunk) != "PING" { + t.Fatalf("first chunk = %q, want %q", chunk, "PING") + } + return 2, nil + case 2: + if string(chunk) != "NG" { + t.Fatalf("second chunk = %q, want %q", chunk, "NG") + } + return 2, nil + default: + t.Fatalf("unexpected extra write call %d", calls) + return 0, nil + } + }) + if err != nil { + t.Fatalf("serialWriteAll() error = %v", err) + } + if written != 4 { + t.Fatalf("serialWriteAll() wrote %d bytes, want 4", written) + } +} + +func TestSerialWriteAllTimesOutAfterZeroByteWrites(t *testing.T) { + now := time.Unix(0, 0) + calls := 0 + + written, err := serialWriteAll(context.Background(), []byte("A"), 250*time.Millisecond, func() time.Time { + return now + }, func(chunk []byte) (int, error) { + calls++ + now = now.Add(100 * time.Millisecond) + return 0, nil + }) + if err == nil || err.Error() != "timeout while writing serial data" { + t.Fatalf("serialWriteAll() error = %v, want timeout", err) + } + if written != 0 { + t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written) + } + if calls != 3 { + t.Fatalf("write calls = %d, want 3", calls) + } +} + +func TestSerialWriteAllReturnsContextCancellationAfterRetryBoundary(t *testing.T) { + now := time.Unix(0, 0) + ctx, cancel := context.WithCancel(context.Background()) + calls := 0 + + written, err := serialWriteAll(ctx, []byte("A"), time.Second, func() time.Time { + return now + }, func(chunk []byte) (int, error) { + calls++ + now = now.Add(100 * time.Millisecond) + cancel() + return 0, nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("serialWriteAll() error = %v, want context canceled", err) + } + if written != 0 { + t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written) + } + if calls != 1 { + t.Fatalf("write calls = %d, want 1", calls) + } +} diff --git a/pkg/tools/hardware_facade.go b/pkg/tools/hardware_facade.go index f55d152cf..b505c5a48 100644 --- a/pkg/tools/hardware_facade.go +++ b/pkg/tools/hardware_facade.go @@ -3,8 +3,9 @@ package tools import hardwaretools "github.com/sipeed/picoclaw/pkg/tools/hardware" type ( - I2CTool = hardwaretools.I2CTool - SPITool = hardwaretools.SPITool + I2CTool = hardwaretools.I2CTool + SerialTool = hardwaretools.SerialTool + SPITool = hardwaretools.SPITool ) func NewI2CTool() *I2CTool { @@ -14,3 +15,7 @@ func NewI2CTool() *I2CTool { func NewSPITool() *SPITool { return hardwaretools.NewSPITool() } + +func NewSerialTool() *SerialTool { + return hardwaretools.NewSerialTool() +} diff --git a/web/backend/api/tools.go b/web/backend/api/tools.go index c6c2deaae..3476e3c53 100644 --- a/web/backend/api/tools.go +++ b/web/backend/api/tools.go @@ -171,6 +171,12 @@ var toolCatalog = []toolCatalogEntry{ Category: "hardware", ConfigKey: "spi", }, + { + Name: "serial", + Description: "Interact with serial ports exposed on the host.", + Category: "hardware", + ConfigKey: "serial", + }, { Name: "tool_search_tool_regex", Description: "Discover hidden MCP tools by regex search when tool discovery is enabled.", @@ -265,6 +271,8 @@ func buildToolSupport(cfg *config.Config) []toolSupportItem { status, reasonCode = resolveWebSearchToolSupport(cfg) case "i2c", "spi": status, reasonCode = resolveHardwareToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey)) + case "serial": + status, reasonCode = resolveSerialToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey)) default: if cfg.Tools.IsToolEnabled(entry.ConfigKey) { status = "enabled" @@ -293,6 +301,18 @@ func resolveHardwareToolSupport(enabled bool) (string, string) { return "enabled", "" } +func resolveSerialToolSupport(enabled bool) (string, string) { + if !enabled { + return "disabled", "" + } + switch runtime.GOOS { + case "linux", "darwin", "windows": + return "enabled", "" + default: + return "blocked", "requires_serial_platform" + } +} + func resolveDiscoveryToolSupport(cfg *config.Config, methodEnabled bool) (string, string) { if !cfg.Tools.IsToolEnabled("mcp") { return "disabled", "" @@ -362,6 +382,8 @@ func applyToolState(cfg *config.Config, toolName string, enabled bool) error { cfg.Tools.I2C.Enabled = enabled case "spi": cfg.Tools.SPI.Enabled = enabled + case "serial": + cfg.Tools.Serial.Enabled = enabled case "tool_search_tool_regex": cfg.Tools.MCP.Discovery.UseRegex = enabled if enabled { diff --git a/web/backend/api/tools_test.go b/web/backend/api/tools_test.go index c98067e41..a09a49fd6 100644 --- a/web/backend/api/tools_test.go +++ b/web/backend/api/tools_test.go @@ -92,9 +92,36 @@ func TestHandleListTools(t *testing.T) { if gotTools["i2c"].Status != "disabled" { t.Fatalf("i2c status = %q, want disabled on linux when config is off", gotTools["i2c"].Status) } + if gotTools["serial"].Status != "disabled" { + t.Fatalf("serial status = %q, want disabled when config is off", gotTools["serial"].Status) + } + + cfg.Tools.Serial.Enabled = true + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/tools", nil) + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + gotTools = make(map[string]toolSupportItem, len(resp.Tools)) + for _, tool := range resp.Tools { + gotTools[tool.Name] = tool + } + if gotTools["serial"].Status != "enabled" { + t.Fatalf("serial = %#v, want enabled on linux when config is on", gotTools["serial"]) + } } else { cfg.Tools.I2C.Enabled = true cfg.Tools.SPI.Enabled = true + cfg.Tools.Serial.Enabled = true if err := config.SaveConfig(configPath, cfg); err != nil { t.Fatalf("SaveConfig() error = %v", err) } @@ -120,6 +147,16 @@ func TestHandleListTools(t *testing.T) { if gotTools["spi"].Status != "blocked" || gotTools["spi"].ReasonCode != "requires_linux" { t.Fatalf("spi = %#v, want blocked/requires_linux", gotTools["spi"]) } + switch runtime.GOOS { + case "darwin", "windows": + if gotTools["serial"].Status != "enabled" { + t.Fatalf("serial = %#v, want enabled on supported host", gotTools["serial"]) + } + default: + if gotTools["serial"].Status != "blocked" || gotTools["serial"].ReasonCode != "requires_serial_platform" { + t.Fatalf("serial = %#v, want blocked/requires_serial_platform", gotTools["serial"]) + } + } } } @@ -195,6 +232,26 @@ func TestHandleUpdateToolState(t *testing.T) { if !updated.Tools.Cron.Enabled { t.Fatalf("cron should be enabled: %#v", updated.Tools.Cron) } + + rec4 := httptest.NewRecorder() + req4 := httptest.NewRequest( + http.MethodPut, + "/api/tools/serial/state", + bytes.NewBufferString(`{"enabled":true}`), + ) + req4.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec4, req4) + if rec4.Code != http.StatusOK { + t.Fatalf("serial status = %d, want %d, body=%s", rec4.Code, http.StatusOK, rec4.Body.String()) + } + + updated, err = config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig(updated serial) error = %v", err) + } + if !updated.Tools.Serial.Enabled { + t.Fatalf("serial should be enabled: %#v", updated.Tools.Serial) + } } func TestHandleListTools_ReportsWebSearchEnabledWhenToolIsOn(t *testing.T) { diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index 75a17e791..4e7a0c818 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -603,6 +603,7 @@ }, "reasons": { "requires_linux": "This tool only works on Linux hosts with the required device files exposed.", + "requires_serial_platform": "This tool currently supports Linux, macOS, and Windows hosts with accessible serial ports.", "requires_skills": "Enable `tools.skills` before this skill-registry tool can be used.", "requires_subagent": "Enable `tools.subagent` before the spawn tool can delegate work.", "requires_mcp_discovery": "Enable `tools.mcp.discovery` before MCP discovery tools become available.", diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index 0a140605a..fa7d56418 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -603,6 +603,7 @@ }, "reasons": { "requires_linux": "该工具仅在 Linux 主机上可用,并且需要暴露对应的设备文件。", + "requires_serial_platform": "该工具当前支持 Linux、macOS 和 Windows,且要求主机可访问对应串口。", "requires_skills": "需要先启用 `tools.skills`,该技能注册表工具才能使用。", "requires_subagent": "需要先启用 `tools.subagent`,`spawn` 才能委派任务。", "requires_mcp_discovery": "需要先启用 `tools.mcp.discovery`,MCP 发现工具才会可用。",