hardware/serial: improve unix cancellation and timeout polling

This commit is contained in:
SiYue-ZO
2026-04-28 12:57:09 +08:00
parent 338fa258b3
commit 1f0a5f4eda
6 changed files with 260 additions and 24 deletions
+16 -6
View File
@@ -19,6 +19,7 @@ const (
defaultSerialTimeoutMS = 1000
maxSerialPayloadBytes = 4096
maxSerialReadBytes = 4096
serialPollInterval = 100 * time.Millisecond
)
var (
@@ -123,9 +124,9 @@ func (t *SerialTool) Execute(ctx context.Context, args map[string]any) *ToolResu
case "list":
return t.list()
case "read":
return t.read(args)
return t.read(ctx, args)
case "write":
return t.write(args)
return t.write(ctx, args)
default:
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, read, write)", action))
}
@@ -147,7 +148,7 @@ func (t *SerialTool) list() *ToolResult {
return SilentResult(string(result))
}
func (t *SerialTool) read(args map[string]any) *ToolResult {
func (t *SerialTool) read(ctx context.Context, args map[string]any) *ToolResult {
cfg, errResult := parseSerialConfig(args)
if errResult != nil {
return errResult
@@ -166,7 +167,7 @@ func (t *SerialTool) read(args map[string]any) *ToolResult {
return errResult
}
data, err := serialRead(cfg, length, timeout)
data, err := serialRead(ctx, cfg, length, timeout)
if err != nil {
return ErrorResult(fmt.Sprintf("serial read failed on %s: %v", cfg.Port, err))
}
@@ -174,7 +175,7 @@ func (t *SerialTool) read(args map[string]any) *ToolResult {
return SilentResult(formatSerialPayload("read", cfg, data, timeout))
}
func (t *SerialTool) write(args map[string]any) *ToolResult {
func (t *SerialTool) write(ctx context.Context, args map[string]any) *ToolResult {
confirm, _ := args["confirm"].(bool)
if !confirm {
return ErrorResult(
@@ -195,7 +196,7 @@ func (t *SerialTool) write(args map[string]any) *ToolResult {
return errResult
}
written, err := serialWrite(cfg, payload, timeout)
written, err := serialWrite(ctx, cfg, payload, timeout)
if err != nil {
return ErrorResult(fmt.Sprintf("serial write failed on %s: %v", cfg.Port, err))
}
@@ -403,3 +404,12 @@ func validateSerialBaud(baud int) error {
return nil
}
func serialContextErr(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}
+3 -2
View File
@@ -3,6 +3,7 @@
package hardwaretools
import (
"context"
"fmt"
"time"
)
@@ -11,10 +12,10 @@ func serialListPorts() ([]serialPortInfo, error) {
return nil, fmt.Errorf("serial is not supported on this platform")
}
func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
return nil, fmt.Errorf("serial is not supported on this platform")
}
func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
return 0, fmt.Errorf("serial is not supported on this platform")
}
+36
View File
@@ -1,6 +1,7 @@
package hardwaretools
import (
"context"
"runtime"
"strings"
"testing"
@@ -93,6 +94,41 @@ func TestValidateSerialBaud(t *testing.T) {
}
}
func TestSerialReadCanceledBeforeOpen(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
port := "/dev/ttyUSB0"
if runtime.GOOS == "windows" {
port = "COM3"
}
_, err := serialRead(ctx, serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1}, 1, time.Second)
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
t.Fatalf("serialRead() error = %v, want context canceled", err)
}
}
func TestSerialWriteCanceledBeforeOpen(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
port := "/dev/ttyUSB0"
if runtime.GOOS == "windows" {
port = "COM3"
}
_, err := serialWrite(
ctx,
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
[]byte("AT"),
time.Second,
)
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
t.Fatalf("serialWrite() error = %v, want context canceled", err)
}
}
func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) {
tests := []string{
"../../../etc/passwd",
+46 -14
View File
@@ -3,6 +3,7 @@
package hardwaretools
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -12,6 +13,14 @@ import (
"golang.org/x/sys/unix"
)
var (
unixSerialNow = time.Now
unixSerialOpenPort = openAndConfigureSerialPort
unixSerialClosePort = unix.Close
unixSerialPollRead = pollRead
unixSerialPollWrite = pollWrite
)
func serialListPorts() ([]serialPortInfo, error) {
patterns := []string{
"/dev/ttyS*",
@@ -52,29 +61,37 @@ func serialListPorts() ([]serialPortInfo, error) {
return ports, nil
}
func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
fd, err := openAndConfigureSerialPort(cfg)
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
if err := serialContextErr(ctx); err != nil {
return nil, err
}
fd, err := unixSerialOpenPort(cfg)
if err != nil {
return nil, err
}
defer unix.Close(fd)
defer unixSerialClosePort(fd)
buf := make([]byte, length)
total := 0
deadline := time.Now().Add(timeout)
deadline := unixSerialNow().Add(timeout)
for total < length {
remaining := time.Until(deadline)
if err := serialContextErr(ctx); err != nil {
return nil, err
}
remaining := deadline.Sub(unixSerialNow())
if remaining <= 0 {
break
}
n, err := pollRead(fd, buf[total:], remaining)
n, err := unixSerialPollRead(fd, buf[total:], minSerialPollTimeout(remaining))
if err != nil {
return nil, err
}
if n == 0 {
break
continue
}
total += n
}
@@ -82,27 +99,35 @@ func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, er
return buf[:total], nil
}
func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
fd, err := openAndConfigureSerialPort(cfg)
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
if err := serialContextErr(ctx); err != nil {
return 0, err
}
fd, err := unixSerialOpenPort(cfg)
if err != nil {
return 0, err
}
defer unix.Close(fd)
defer unixSerialClosePort(fd)
total := 0
deadline := time.Now().Add(timeout)
deadline := unixSerialNow().Add(timeout)
for total < len(data) {
remaining := time.Until(deadline)
if err := serialContextErr(ctx); err != nil {
return total, err
}
remaining := deadline.Sub(unixSerialNow())
if remaining <= 0 {
return total, fmt.Errorf("timeout while writing serial data")
}
n, err := pollWrite(fd, data[total:], remaining)
n, err := unixSerialPollWrite(fd, data[total:], minSerialPollTimeout(remaining))
if err != nil {
return total, err
}
if n == 0 {
return total, fmt.Errorf("serial port accepted zero bytes")
continue
}
total += n
}
@@ -252,3 +277,10 @@ func durationToPollTimeout(timeout time.Duration) int {
}
return ms
}
func minSerialPollTimeout(timeout time.Duration) time.Duration {
if timeout > serialPollInterval {
return serialPollInterval
}
return timeout
}
+140
View File
@@ -0,0 +1,140 @@
//go:build linux || darwin
package hardwaretools
import (
"context"
"errors"
"testing"
"time"
)
func stubUnixSerialIO(t *testing.T, now *time.Time) {
t.Helper()
prevNow := unixSerialNow
prevOpen := unixSerialOpenPort
prevClose := unixSerialClosePort
prevPollRead := unixSerialPollRead
prevPollWrite := unixSerialPollWrite
unixSerialNow = func() time.Time {
return *now
}
unixSerialOpenPort = func(cfg serialConfig) (int, error) {
return 42, nil
}
unixSerialClosePort = func(fd int) error {
return nil
}
unixSerialPollRead = prevPollRead
unixSerialPollWrite = prevPollWrite
t.Cleanup(func() {
unixSerialNow = prevNow
unixSerialOpenPort = prevOpen
unixSerialClosePort = prevClose
unixSerialPollRead = prevPollRead
unixSerialPollWrite = prevPollWrite
})
}
func TestSerialReadWaitsPastEmptyPollsUntilDeadline(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
pollCalls := 0
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
pollCalls++
if timeout > serialPollInterval {
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
}
now = now.Add(timeout)
if pollCalls < 4 {
return 0, nil
}
return copy(dst, []byte("OK")), nil
}
got, err := serialRead(context.Background(), serialConfig{}, 2, 500*time.Millisecond)
if err != nil {
t.Fatalf("serialRead() error = %v", err)
}
if string(got) != "OK" {
t.Fatalf("serialRead() = %q, want %q", got, "OK")
}
if pollCalls != 4 {
t.Fatalf("poll calls = %d, want 4", pollCalls)
}
}
func TestSerialReadReturnsPromptlyOnContextCancelBetweenPolls(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
ctx, cancel := context.WithCancel(context.Background())
pollCalls := 0
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
pollCalls++
now = now.Add(timeout)
cancel()
return 0, nil
}
_, err := serialRead(ctx, serialConfig{}, 1, time.Second)
if !errors.Is(err, context.Canceled) {
t.Fatalf("serialRead() error = %v, want context canceled", err)
}
if pollCalls != 1 {
t.Fatalf("poll calls = %d, want 1", pollCalls)
}
}
func TestSerialWriteWaitsPastEmptyPollsUntilReady(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
pollCalls := 0
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
pollCalls++
if timeout > serialPollInterval {
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
}
now = now.Add(timeout)
switch pollCalls {
case 1, 2:
return 0, nil
default:
return 1, nil
}
}
written, err := serialWrite(context.Background(), serialConfig{}, []byte("OK"), 500*time.Millisecond)
if err != nil {
t.Fatalf("serialWrite() error = %v", err)
}
if written != 2 {
t.Fatalf("serialWrite() wrote %d bytes, want 2", written)
}
if pollCalls != 4 {
t.Fatalf("poll calls = %d, want 4", pollCalls)
}
}
func TestSerialWriteTimesOutAfterRepeatedEmptyPolls(t *testing.T) {
now := time.Unix(0, 0)
stubUnixSerialIO(t, &now)
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
now = now.Add(timeout)
return 0, nil
}
written, err := serialWrite(context.Background(), serialConfig{}, []byte("A"), 250*time.Millisecond)
if err == nil || err.Error() != "timeout while writing serial data" {
t.Fatalf("serialWrite() error = %v, want timeout", err)
}
if written != 0 {
t.Fatalf("serialWrite() wrote %d bytes, want 0", written)
}
}
+19 -2
View File
@@ -3,6 +3,7 @@
package hardwaretools
import (
"context"
"fmt"
"sort"
"strings"
@@ -95,13 +96,21 @@ func serialListPorts() ([]serialPortInfo, error) {
return ports, nil
}
func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
if err := serialContextErr(ctx); err != nil {
return nil, err
}
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
if err != nil {
return nil, err
}
defer windows.CloseHandle(handle)
if err := serialContextErr(ctx); err != nil {
return nil, err
}
buf := make([]byte, length)
var read uint32
if err := windows.ReadFile(handle, buf, &read, nil); err != nil {
@@ -110,13 +119,21 @@ func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, er
return buf[:read], nil
}
func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
if err := serialContextErr(ctx); err != nil {
return 0, err
}
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
if err != nil {
return 0, err
}
defer windows.CloseHandle(handle)
if err := serialContextErr(ctx); err != nil {
return 0, err
}
var written uint32
if err := windows.WriteFile(handle, data, &written, nil); err != nil {
return int(written), err