From 0f52076762692521013c36d4e77da9a3cf463465 Mon Sep 17 00:00:00 2001 From: SiYue-ZO <2835601846@qq.com> Date: Sun, 26 Apr 2026 12:43:50 +0800 Subject: [PATCH] feat(tools): add cross-platform serial hardware tool --- pkg/tools/facade_compat_test.go | 3 + pkg/tools/hardware/serial.go | 332 +++++++++++++++++++++++++++ pkg/tools/hardware/serial_darwin.go | 19 ++ pkg/tools/hardware/serial_linux.go | 19 ++ pkg/tools/hardware/serial_other.go | 20 ++ pkg/tools/hardware/serial_test.go | 66 ++++++ pkg/tools/hardware/serial_unix.go | 263 +++++++++++++++++++++ pkg/tools/hardware/serial_windows.go | 214 +++++++++++++++++ pkg/tools/hardware_facade.go | 9 +- 9 files changed, 943 insertions(+), 2 deletions(-) create mode 100644 pkg/tools/hardware/serial.go create mode 100644 pkg/tools/hardware/serial_darwin.go create mode 100644 pkg/tools/hardware/serial_linux.go create mode 100644 pkg/tools/hardware/serial_other.go create mode 100644 pkg/tools/hardware/serial_test.go create mode 100644 pkg/tools/hardware/serial_unix.go create mode 100644 pkg/tools/hardware/serial_windows.go diff --git a/pkg/tools/facade_compat_test.go b/pkg/tools/facade_compat_test.go index 672554209..378462512 100644 --- a/pkg/tools/facade_compat_test.go +++ b/pkg/tools/facade_compat_test.go @@ -9,6 +9,9 @@ func TestFacadeConstructorsRemainAvailable(t *testing.T) { if NewSPITool() == nil { t.Fatal("NewSPITool should return a tool") } + if NewSerialTool() == nil { + t.Fatal("NewSerialTool should return a tool") + } if NewMessageTool() == nil { t.Fatal("NewMessageTool should return a tool") } diff --git a/pkg/tools/hardware/serial.go b/pkg/tools/hardware/serial.go new file mode 100644 index 000000000..3f55da362 --- /dev/null +++ b/pkg/tools/hardware/serial.go @@ -0,0 +1,332 @@ +package hardwaretools + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + "unicode/utf8" +) + +const ( + defaultSerialBaud = 115200 + defaultSerialDataBits = 8 + defaultSerialStopBits = 1 + defaultSerialTimeoutMS = 1000 + maxSerialPayloadBytes = 4096 + maxSerialReadBytes = 4096 +) + +type SerialTool struct{} + +type serialPortInfo struct { + Name string `json:"name"` + Path string `json:"path"` +} + +type serialConfig struct { + Port string + Baud int + DataBits int + Parity string + StopBits int +} + +func NewSerialTool() *SerialTool { + return &SerialTool{} +} + +func (t *SerialTool) Name() string { + return "serial" +} + +func (t *SerialTool) Description() string { + return "Interact with host serial ports. Actions: list (enumerate ports), read (receive bytes), write (send bytes with explicit confirmation). Supports Linux, macOS, and Windows." +} + +func (t *SerialTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"list", "read", "write"}, + "description": "Action to perform: list available serial ports, read bytes from a port, or write bytes to a port.", + }, + "port": map[string]any{ + "type": "string", + "description": "Serial port path or name, for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3. Required for read/write.", + }, + "baud": map[string]any{ + "type": "integer", + "description": "Baud rate. Default: 115200.", + }, + "data_bits": map[string]any{ + "type": "integer", + "description": "Data bits. Supported values: 5, 6, 7, 8. Default: 8.", + }, + "parity": map[string]any{ + "type": "string", + "enum": []string{"none", "even", "odd"}, + "description": "Parity mode. Default: none.", + }, + "stop_bits": map[string]any{ + "type": "integer", + "description": "Stop bits. Supported values: 1, 2. Default: 1.", + }, + "timeout_ms": map[string]any{ + "type": "integer", + "description": "Read/write timeout in milliseconds. Default: 1000.", + }, + "length": map[string]any{ + "type": "integer", + "description": "Number of bytes to read. Required for read. Range: 1-4096.", + }, + "data": map[string]any{ + "type": "array", + "items": map[string]any{"type": "integer"}, + "description": "Bytes to write, each in range 0-255. Required for write unless text is provided.", + }, + "text": map[string]any{ + "type": "string", + "description": "UTF-8 text to write. Required for write if data is omitted.", + }, + "confirm": map[string]any{ + "type": "boolean", + "description": "Must be true for write operations.", + }, + }, + "required": []string{"action"}, + } +} + +func (t *SerialTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + action, ok := args["action"].(string) + if !ok || strings.TrimSpace(action) == "" { + return ErrorResult("action is required") + } + + switch action { + case "list": + return t.list() + case "read": + return t.read(args) + case "write": + return t.write(args) + default: + return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, read, write)", action)) + } +} + +func (t *SerialTool) list() *ToolResult { + ports, err := serialListPorts() + if err != nil { + return ErrorResult(fmt.Sprintf("failed to list serial ports: %v", err)) + } + if len(ports) == 0 { + return SilentResult("No serial ports found on this host.") + } + + result, _ := json.MarshalIndent(map[string]any{ + "ports": ports, + "count": len(ports), + }, "", " ") + return SilentResult(string(result)) +} + +func (t *SerialTool) read(args map[string]any) *ToolResult { + cfg, errResult := parseSerialConfig(args) + if errResult != nil { + return errResult + } + + length := 0 + if v, ok := args["length"].(float64); ok { + length = int(v) + } + if length < 1 || length > maxSerialReadBytes { + return ErrorResult(fmt.Sprintf("length is required for read (1-%d)", maxSerialReadBytes)) + } + + timeout, errResult := parseSerialTimeout(args) + if errResult != nil { + return errResult + } + + data, err := serialRead(cfg, length, timeout) + if err != nil { + return ErrorResult(fmt.Sprintf("serial read failed on %s: %v", cfg.Port, err)) + } + + return SilentResult(formatSerialPayload("read", cfg, data, timeout)) +} + +func (t *SerialTool) write(args map[string]any) *ToolResult { + confirm, _ := args["confirm"].(bool) + if !confirm { + return ErrorResult( + "write operations require confirm: true. Please confirm with the user before sending bytes to a serial device.", + ) + } + + cfg, errResult := parseSerialConfig(args) + if errResult != nil { + return errResult + } + timeout, errResult := parseSerialTimeout(args) + if errResult != nil { + return errResult + } + payload, errResult := parseSerialWritePayload(args) + if errResult != nil { + return errResult + } + + written, err := serialWrite(cfg, payload, timeout) + if err != nil { + return ErrorResult(fmt.Sprintf("serial write failed on %s: %v", cfg.Port, err)) + } + + result, _ := json.MarshalIndent(map[string]any{ + "action": "write", + "port": cfg.Port, + "baud": cfg.Baud, + "data_bits": cfg.DataBits, + "parity": cfg.Parity, + "stop_bits": cfg.StopBits, + "timeout_ms": timeout.Milliseconds(), + "written": written, + "payload": serialPayloadSummary(payload), + }, "", " ") + return SilentResult(string(result)) +} + +func parseSerialConfig(args map[string]any) (serialConfig, *ToolResult) { + port, ok := args["port"].(string) + port = strings.TrimSpace(port) + if !ok || port == "" { + return serialConfig{}, ErrorResult("port is required (for example /dev/ttyUSB0, /dev/cu.usbserial-0001, or COM3)") + } + + cfg := serialConfig{ + Port: port, + Baud: defaultSerialBaud, + DataBits: defaultSerialDataBits, + Parity: "none", + StopBits: defaultSerialStopBits, + } + + if v, ok := args["baud"].(float64); ok { + cfg.Baud = int(v) + } + if cfg.Baud < 50 || cfg.Baud > 4000000 { + return serialConfig{}, ErrorResult("baud must be between 50 and 4000000") + } + + if v, ok := args["data_bits"].(float64); ok { + cfg.DataBits = int(v) + } + switch cfg.DataBits { + case 5, 6, 7, 8: + default: + return serialConfig{}, ErrorResult("data_bits must be one of 5, 6, 7, or 8") + } + + if v, ok := args["parity"].(string); ok && strings.TrimSpace(v) != "" { + cfg.Parity = strings.ToLower(strings.TrimSpace(v)) + } + switch cfg.Parity { + case "none", "even", "odd": + default: + return serialConfig{}, ErrorResult(`parity must be one of "none", "even", or "odd"`) + } + + if v, ok := args["stop_bits"].(float64); ok { + cfg.StopBits = int(v) + } + if cfg.StopBits != 1 && cfg.StopBits != 2 { + return serialConfig{}, ErrorResult("stop_bits must be 1 or 2") + } + + return cfg, nil +} + +func parseSerialTimeout(args map[string]any) (time.Duration, *ToolResult) { + timeoutMS := defaultSerialTimeoutMS + if v, ok := args["timeout_ms"].(float64); ok { + timeoutMS = int(v) + } + if timeoutMS < 1 || timeoutMS > 60000 { + return 0, ErrorResult("timeout_ms must be between 1 and 60000") + } + return time.Duration(timeoutMS) * time.Millisecond, nil +} + +func parseSerialWritePayload(args map[string]any) ([]byte, *ToolResult) { + if text, ok := args["text"].(string); ok && text != "" { + if !utf8.ValidString(text) { + return nil, ErrorResult("text must be valid UTF-8") + } + if len(text) > maxSerialPayloadBytes { + return nil, ErrorResult(fmt.Sprintf("text payload too large: maximum %d bytes", maxSerialPayloadBytes)) + } + return []byte(text), nil + } + + dataRaw, ok := args["data"].([]any) + if !ok || len(dataRaw) == 0 { + return nil, ErrorResult("write requires either text or data") + } + if len(dataRaw) > maxSerialPayloadBytes { + return nil, ErrorResult(fmt.Sprintf("data too long: maximum %d bytes", maxSerialPayloadBytes)) + } + + data := make([]byte, len(dataRaw)) + for i, v := range dataRaw { + f, ok := v.(float64) + if !ok { + return nil, ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) + } + b := int(f) + if b < 0 || b > 255 { + return nil, ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) + } + data[i] = byte(b) + } + + return data, nil +} + +func formatSerialPayload(action string, cfg serialConfig, data []byte, timeout time.Duration) string { + result, _ := json.MarshalIndent(map[string]any{ + "action": action, + "port": cfg.Port, + "baud": cfg.Baud, + "data_bits": cfg.DataBits, + "parity": cfg.Parity, + "stop_bits": cfg.StopBits, + "timeout_ms": timeout.Milliseconds(), + "payload": serialPayloadSummary(data), + }, "", " ") + return string(result) +} + +func serialPayloadSummary(data []byte) map[string]any { + hexValues := make([]string, len(data)) + intValues := make([]int, len(data)) + for i, b := range data { + hexValues[i] = fmt.Sprintf("0x%02x", b) + intValues[i] = int(b) + } + + summary := map[string]any{ + "length": len(data), + "bytes": intValues, + "hex": hexValues, + } + if utf8.Valid(data) { + summary["text"] = string(data) + } + return summary +} diff --git a/pkg/tools/hardware/serial_darwin.go b/pkg/tools/hardware/serial_darwin.go new file mode 100644 index 000000000..f6cf30db4 --- /dev/null +++ b/pkg/tools/hardware/serial_darwin.go @@ -0,0 +1,19 @@ +//go:build darwin + +package hardwaretools + +import "golang.org/x/sys/unix" + +func serialGetTermios(fd int) (*unix.Termios, error) { + return unix.IoctlGetTermios(fd, unix.TIOCGETA) +} + +func serialSetSpeed(tio *unix.Termios, speed uint32) error { + tio.Ispeed = speed + tio.Ospeed = speed + return nil +} + +func serialSetTermios(fd int, tio *unix.Termios) error { + return unix.IoctlSetTermios(fd, unix.TIOCSETA, tio) +} diff --git a/pkg/tools/hardware/serial_linux.go b/pkg/tools/hardware/serial_linux.go new file mode 100644 index 000000000..bad3e4cb8 --- /dev/null +++ b/pkg/tools/hardware/serial_linux.go @@ -0,0 +1,19 @@ +//go:build linux + +package hardwaretools + +import "golang.org/x/sys/unix" + +func serialGetTermios(fd int) (*unix.Termios, error) { + return unix.IoctlGetTermios(fd, unix.TCGETS) +} + +func serialSetSpeed(tio *unix.Termios, speed uint32) error { + tio.Ispeed = speed + tio.Ospeed = speed + return nil +} + +func serialSetTermios(fd int, tio *unix.Termios) error { + return unix.IoctlSetTermios(fd, unix.TCSETS, tio) +} diff --git a/pkg/tools/hardware/serial_other.go b/pkg/tools/hardware/serial_other.go new file mode 100644 index 000000000..f57ce2fa3 --- /dev/null +++ b/pkg/tools/hardware/serial_other.go @@ -0,0 +1,20 @@ +//go:build !linux && !darwin && !windows + +package hardwaretools + +import ( + "fmt" + "time" +) + +func serialListPorts() ([]serialPortInfo, error) { + return nil, nil +} + +func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) { + return nil, fmt.Errorf("serial is not supported on this platform") +} + +func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) { + return 0, fmt.Errorf("serial is not supported on this platform") +} diff --git a/pkg/tools/hardware/serial_test.go b/pkg/tools/hardware/serial_test.go new file mode 100644 index 000000000..68657dffd --- /dev/null +++ b/pkg/tools/hardware/serial_test.go @@ -0,0 +1,66 @@ +package hardwaretools + +import ( + "testing" + "time" +) + +func TestParseSerialConfig(t *testing.T) { + cfg, errResult := parseSerialConfig(map[string]any{ + "port": "COM3", + "baud": float64(9600), + "data_bits": float64(7), + "parity": "even", + "stop_bits": float64(2), + }) + if errResult != nil { + t.Fatalf("parseSerialConfig() unexpected error = %v", errResult.ForLLM) + } + + if cfg.Port != "COM3" || cfg.Baud != 9600 || cfg.DataBits != 7 || cfg.Parity != "even" || cfg.StopBits != 2 { + t.Fatalf("parseSerialConfig() = %#v", cfg) + } +} + +func TestParseSerialConfigRejectsInvalidParity(t *testing.T) { + _, errResult := parseSerialConfig(map[string]any{ + "port": "/dev/ttyUSB0", + "parity": "mark", + }) + if errResult == nil { + t.Fatal("expected invalid parity to fail") + } +} + +func TestParseSerialTimeout(t *testing.T) { + timeout, errResult := parseSerialTimeout(map[string]any{ + "timeout_ms": float64(2500), + }) + if errResult != nil { + t.Fatalf("parseSerialTimeout() unexpected error = %v", errResult.ForLLM) + } + if timeout != 2500*time.Millisecond { + t.Fatalf("timeout = %v, want 2500ms", timeout) + } +} + +func TestParseSerialWritePayloadSupportsText(t *testing.T) { + data, errResult := parseSerialWritePayload(map[string]any{ + "text": "AT\r\n", + }) + if errResult != nil { + t.Fatalf("parseSerialWritePayload() unexpected error = %v", errResult.ForLLM) + } + if string(data) != "AT\r\n" { + t.Fatalf("payload = %q, want %q", string(data), "AT\r\n") + } +} + +func TestParseSerialWritePayloadRejectsOutOfRangeByte(t *testing.T) { + _, errResult := parseSerialWritePayload(map[string]any{ + "data": []any{float64(256)}, + }) + if errResult == nil { + t.Fatal("expected payload validation failure") + } +} diff --git a/pkg/tools/hardware/serial_unix.go b/pkg/tools/hardware/serial_unix.go new file mode 100644 index 000000000..caef7e0e8 --- /dev/null +++ b/pkg/tools/hardware/serial_unix.go @@ -0,0 +1,263 @@ +//go:build linux || darwin + +package hardwaretools + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "golang.org/x/sys/unix" +) + +func serialListPorts() ([]serialPortInfo, error) { + patterns := []string{ + "/dev/ttyS*", + "/dev/ttyUSB*", + "/dev/ttyACM*", + "/dev/ttyAMA*", + "/dev/rfcomm*", + "/dev/tty.*", + "/dev/cu.*", + } + + seen := make(map[string]struct{}) + ports := make([]serialPortInfo, 0) + for _, pattern := range patterns { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + for _, match := range matches { + if _, ok := seen[match]; ok { + continue + } + info, err := os.Stat(match) + if err != nil || info.IsDir() { + continue + } + seen[match] = struct{}{} + ports = append(ports, serialPortInfo{ + Name: filepath.Base(match), + Path: match, + }) + } + } + + sort.Slice(ports, func(i, j int) bool { + return ports[i].Path < ports[j].Path + }) + return ports, nil +} + +func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) { + fd, err := openAndConfigureSerialPort(cfg) + if err != nil { + return nil, err + } + defer unix.Close(fd) + + buf := make([]byte, length) + total := 0 + deadline := time.Now().Add(timeout) + + for total < length { + remaining := time.Until(deadline) + if remaining <= 0 { + break + } + + n, err := pollRead(fd, buf[total:], remaining) + if err != nil { + return nil, err + } + if n == 0 { + break + } + total += n + } + + return buf[:total], nil +} + +func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) { + fd, err := openAndConfigureSerialPort(cfg) + if err != nil { + return 0, err + } + defer unix.Close(fd) + + total := 0 + deadline := time.Now().Add(timeout) + for total < len(data) { + remaining := time.Until(deadline) + if remaining <= 0 { + return total, fmt.Errorf("timeout while writing serial data") + } + + n, err := pollWrite(fd, data[total:], remaining) + if err != nil { + return total, err + } + if n == 0 { + return total, fmt.Errorf("serial port accepted zero bytes") + } + total += n + } + + return total, nil +} + +func openAndConfigureSerialPort(cfg serialConfig) (int, error) { + fd, err := unix.Open(normalizeUnixSerialPath(cfg.Port), unix.O_RDWR|unix.O_NOCTTY|unix.O_NONBLOCK, 0) + if err != nil { + return -1, err + } + + if err := unix.SetNonblock(fd, false); err != nil { + unix.Close(fd) + return -1, err + } + + if err := configureUnixSerialPort(fd, cfg); err != nil { + unix.Close(fd) + return -1, err + } + + return fd, nil +} + +func configureUnixSerialPort(fd int, cfg serialConfig) error { + tio, err := serialGetTermios(fd) + if err != nil { + return err + } + + tio.Iflag = 0 + tio.Oflag = 0 + tio.Lflag = 0 + tio.Cflag = unix.CREAD | unix.CLOCAL + tio.Cc[unix.VMIN] = 0 + tio.Cc[unix.VTIME] = 0 + + switch cfg.DataBits { + case 5: + tio.Cflag |= unix.CS5 + case 6: + tio.Cflag |= unix.CS6 + case 7: + tio.Cflag |= unix.CS7 + default: + tio.Cflag |= unix.CS8 + } + + switch cfg.Parity { + case "even": + tio.Cflag |= unix.PARENB + case "odd": + tio.Cflag |= unix.PARENB | unix.PARODD + } + + if cfg.StopBits == 2 { + tio.Cflag |= unix.CSTOPB + } + + speed, err := serialBaudToUnix(cfg.Baud) + if err != nil { + return err + } + if err := serialSetSpeed(tio, speed); err != nil { + return err + } + + return serialSetTermios(fd, tio) +} + +func serialBaudToUnix(baud int) (uint32, error) { + switch baud { + case 50: + return unix.B50, nil + case 75: + return unix.B75, nil + case 110: + return unix.B110, nil + case 134: + return unix.B134, nil + case 150: + return unix.B150, nil + case 200: + return unix.B200, nil + case 300: + return unix.B300, nil + case 600: + return unix.B600, nil + case 1200: + return unix.B1200, nil + case 1800: + return unix.B1800, nil + case 2400: + return unix.B2400, nil + case 4800: + return unix.B4800, nil + case 9600: + return unix.B9600, nil + case 19200: + return unix.B19200, nil + case 38400: + return unix.B38400, nil + case 57600: + return unix.B57600, nil + case 115200: + return unix.B115200, nil + case 230400: + return unix.B230400, nil + default: + return 0, fmt.Errorf("unsupported baud rate on this platform: %d", baud) + } +} + +func pollRead(fd int, dst []byte, timeout time.Duration) (int, error) { + pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}} + n, err := unix.Poll(pfd, durationToPollTimeout(timeout)) + if err != nil { + return 0, err + } + if n == 0 { + return 0, nil + } + return unix.Read(fd, dst) +} + +func pollWrite(fd int, src []byte, timeout time.Duration) (int, error) { + pfd := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLOUT}} + n, err := unix.Poll(pfd, durationToPollTimeout(timeout)) + if err != nil { + return 0, err + } + if n == 0 { + return 0, nil + } + return unix.Write(fd, src) +} + +func durationToPollTimeout(timeout time.Duration) int { + if timeout <= 0 { + return 0 + } + ms := int(timeout / time.Millisecond) + if ms == 0 { + return 1 + } + return ms +} + +func normalizeUnixSerialPath(port string) string { + trimmed := strings.TrimSpace(port) + if strings.HasPrefix(trimmed, "/dev/") { + return trimmed + } + return "/dev/" + trimmed +} diff --git a/pkg/tools/hardware/serial_windows.go b/pkg/tools/hardware/serial_windows.go new file mode 100644 index 000000000..c4d119ac4 --- /dev/null +++ b/pkg/tools/hardware/serial_windows.go @@ -0,0 +1,214 @@ +//go:build windows + +package hardwaretools + +import ( + "fmt" + "path/filepath" + "sort" + "strings" + "time" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procGetCommState = kernel32.NewProc("GetCommState") + procSetCommState = kernel32.NewProc("SetCommState") + procSetCommTimeouts = kernel32.NewProc("SetCommTimeouts") + procPurgeComm = kernel32.NewProc("PurgeComm") +) + +const ( + purgeTxClear = 0x0004 + purgeRxClear = 0x0008 +) + +type dcb struct { + DCBlength uint32 + BaudRate uint32 + Flags uint32 + Reserved uint16 + XonLim uint16 + XoffLim uint16 + ByteSize byte + Parity byte + StopBits byte + XonChar byte + XoffChar byte + ErrorChar byte + EofChar byte + EvtChar byte + wReserved1 uint16 +} + +type commTimeouts struct { + ReadIntervalTimeout uint32 + ReadTotalTimeoutMultiplier uint32 + ReadTotalTimeoutConstant uint32 + WriteTotalTimeoutMultiplier uint32 + WriteTotalTimeoutConstant uint32 +} + +func serialListPorts() ([]serialPortInfo, error) { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, `HARDWARE\DEVICEMAP\SERIALCOMM`, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return nil, nil + } + return nil, err + } + defer key.Close() + + names, err := key.ReadValueNames(-1) + if err != nil { + return nil, err + } + + ports := make([]serialPortInfo, 0, len(names)) + seen := make(map[string]struct{}) + for _, name := range names { + value, _, err := key.GetStringValue(name) + if err != nil { + continue + } + portName := strings.TrimSpace(value) + if portName == "" { + continue + } + normalized := strings.ToUpper(portName) + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + ports = append(ports, serialPortInfo{ + Name: normalized, + Path: normalized, + }) + } + + sort.Slice(ports, func(i, j int) bool { + return ports[i].Path < ports[j].Path + }) + return ports, nil +} + +func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) { + handle, err := openAndConfigureWindowsSerial(cfg, timeout) + if err != nil { + return nil, err + } + defer windows.CloseHandle(handle) + + buf := make([]byte, length) + var read uint32 + if err := windows.ReadFile(handle, buf, &read, nil); err != nil { + return nil, err + } + return buf[:read], nil +} + +func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) { + handle, err := openAndConfigureWindowsSerial(cfg, timeout) + if err != nil { + return 0, err + } + defer windows.CloseHandle(handle) + + var written uint32 + if err := windows.WriteFile(handle, data, &written, nil); err != nil { + return int(written), err + } + return int(written), nil +} + +func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) { + path := normalizeWindowsSerialPath(cfg.Port) + handle, err := windows.CreateFile( + windows.StringToUTF16Ptr(path), + windows.GENERIC_READ|windows.GENERIC_WRITE, + 0, + nil, + windows.OPEN_EXISTING, + 0, + 0, + ) + if err != nil { + return 0, err + } + + if err := configureWindowsSerialPort(handle, cfg, timeout); err != nil { + windows.CloseHandle(handle) + return 0, err + } + return handle, nil +} + +func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout time.Duration) error { + state := dcb{DCBlength: uint32(unsafe.Sizeof(dcb{}))} + r1, _, err := procGetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state))) + if r1 == 0 { + return err + } + + state.BaudRate = uint32(cfg.Baud) + state.ByteSize = byte(cfg.DataBits) + state.Flags |= 0x00000001 // fBinary + + switch cfg.Parity { + case "even": + state.Parity = 2 + state.Flags |= 0x00000002 // fParity + case "odd": + state.Parity = 1 + state.Flags |= 0x00000002 // fParity + default: + state.Parity = 0 + state.Flags &^= 0x00000002 + } + + switch cfg.StopBits { + case 2: + state.StopBits = 2 + default: + state.StopBits = 0 + } + + r1, _, err = procSetCommState.Call(uintptr(handle), uintptr(unsafe.Pointer(&state))) + if r1 == 0 { + return err + } + + timeoutMS := uint32(timeout / time.Millisecond) + if timeoutMS == 0 { + timeoutMS = 1 + } + timeouts := commTimeouts{ + ReadIntervalTimeout: timeoutMS, + ReadTotalTimeoutConstant: timeoutMS, + WriteTotalTimeoutConstant: timeoutMS, + ReadTotalTimeoutMultiplier: 0, + WriteTotalTimeoutMultiplier: 0, + } + r1, _, err = procSetCommTimeouts.Call(uintptr(handle), uintptr(unsafe.Pointer(&timeouts))) + if r1 == 0 { + return err + } + + procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear)) + return nil +} + +func normalizeWindowsSerialPath(port string) string { + trimmed := strings.ToUpper(strings.TrimSpace(port)) + if strings.HasPrefix(trimmed, `\\.\`) { + return trimmed + } + if filepath.VolumeName(trimmed) != "" { + return trimmed + } + return `\\.\` + trimmed +} diff --git a/pkg/tools/hardware_facade.go b/pkg/tools/hardware_facade.go index f55d152cf..b505c5a48 100644 --- a/pkg/tools/hardware_facade.go +++ b/pkg/tools/hardware_facade.go @@ -3,8 +3,9 @@ package tools import hardwaretools "github.com/sipeed/picoclaw/pkg/tools/hardware" type ( - I2CTool = hardwaretools.I2CTool - SPITool = hardwaretools.SPITool + I2CTool = hardwaretools.I2CTool + SerialTool = hardwaretools.SerialTool + SPITool = hardwaretools.SPITool ) func NewI2CTool() *I2CTool { @@ -14,3 +15,7 @@ func NewI2CTool() *I2CTool { func NewSPITool() *SPITool { return hardwaretools.NewSPITool() } + +func NewSerialTool() *SerialTool { + return hardwaretools.NewSerialTool() +}