mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
hardware/serial: improve unix cancellation and timeout polling
This commit is contained in:
@@ -19,6 +19,7 @@ const (
|
||||
defaultSerialTimeoutMS = 1000
|
||||
maxSerialPayloadBytes = 4096
|
||||
maxSerialReadBytes = 4096
|
||||
serialPollInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -123,9 +124,9 @@ func (t *SerialTool) Execute(ctx context.Context, args map[string]any) *ToolResu
|
||||
case "list":
|
||||
return t.list()
|
||||
case "read":
|
||||
return t.read(args)
|
||||
return t.read(ctx, args)
|
||||
case "write":
|
||||
return t.write(args)
|
||||
return t.write(ctx, args)
|
||||
default:
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, read, write)", action))
|
||||
}
|
||||
@@ -147,7 +148,7 @@ func (t *SerialTool) list() *ToolResult {
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
func (t *SerialTool) read(args map[string]any) *ToolResult {
|
||||
func (t *SerialTool) read(ctx context.Context, args map[string]any) *ToolResult {
|
||||
cfg, errResult := parseSerialConfig(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
@@ -166,7 +167,7 @@ func (t *SerialTool) read(args map[string]any) *ToolResult {
|
||||
return errResult
|
||||
}
|
||||
|
||||
data, err := serialRead(cfg, length, timeout)
|
||||
data, err := serialRead(ctx, cfg, length, timeout)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("serial read failed on %s: %v", cfg.Port, err))
|
||||
}
|
||||
@@ -174,7 +175,7 @@ func (t *SerialTool) read(args map[string]any) *ToolResult {
|
||||
return SilentResult(formatSerialPayload("read", cfg, data, timeout))
|
||||
}
|
||||
|
||||
func (t *SerialTool) write(args map[string]any) *ToolResult {
|
||||
func (t *SerialTool) write(ctx context.Context, args map[string]any) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult(
|
||||
@@ -195,7 +196,7 @@ func (t *SerialTool) write(args map[string]any) *ToolResult {
|
||||
return errResult
|
||||
}
|
||||
|
||||
written, err := serialWrite(cfg, payload, timeout)
|
||||
written, err := serialWrite(ctx, cfg, payload, timeout)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("serial write failed on %s: %v", cfg.Port, err))
|
||||
}
|
||||
@@ -403,3 +404,12 @@ func validateSerialBaud(baud int) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func serialContextErr(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
@@ -11,10 +12,10 @@ func serialListPorts() ([]serialPortInfo, error) {
|
||||
return nil, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
return nil, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
return 0, fmt.Errorf("serial is not supported on this platform")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -93,6 +94,41 @@ func TestValidateSerialBaud(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialReadCanceledBeforeOpen(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, err := serialRead(ctx, serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1}, 1, time.Second)
|
||||
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
||||
t.Fatalf("serialRead() error = %v, want context canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteCanceledBeforeOpen(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
port := "/dev/ttyUSB0"
|
||||
if runtime.GOOS == "windows" {
|
||||
port = "COM3"
|
||||
}
|
||||
|
||||
_, err := serialWrite(
|
||||
ctx,
|
||||
serialConfig{Port: port, Baud: 115200, DataBits: 8, Parity: "none", StopBits: 1},
|
||||
[]byte("AT"),
|
||||
time.Second,
|
||||
)
|
||||
if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) {
|
||||
t.Fatalf("serialWrite() error = %v, want context canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSerialConfigRejectsUnsafePortPaths(t *testing.T) {
|
||||
tests := []string{
|
||||
"../../../etc/passwd",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -12,6 +13,14 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var (
|
||||
unixSerialNow = time.Now
|
||||
unixSerialOpenPort = openAndConfigureSerialPort
|
||||
unixSerialClosePort = unix.Close
|
||||
unixSerialPollRead = pollRead
|
||||
unixSerialPollWrite = pollWrite
|
||||
)
|
||||
|
||||
func serialListPorts() ([]serialPortInfo, error) {
|
||||
patterns := []string{
|
||||
"/dev/ttyS*",
|
||||
@@ -52,29 +61,37 @@ func serialListPorts() ([]serialPortInfo, error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
fd, err := openAndConfigureSerialPort(cfg)
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fd, err := unixSerialOpenPort(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
defer unixSerialClosePort(fd)
|
||||
|
||||
buf := make([]byte, length)
|
||||
total := 0
|
||||
deadline := time.Now().Add(timeout)
|
||||
deadline := unixSerialNow().Add(timeout)
|
||||
|
||||
for total < length {
|
||||
remaining := time.Until(deadline)
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remaining := deadline.Sub(unixSerialNow())
|
||||
if remaining <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := pollRead(fd, buf[total:], remaining)
|
||||
n, err := unixSerialPollRead(fd, buf[total:], minSerialPollTimeout(remaining))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n == 0 {
|
||||
break
|
||||
continue
|
||||
}
|
||||
total += n
|
||||
}
|
||||
@@ -82,27 +99,35 @@ func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, er
|
||||
return buf[:total], nil
|
||||
}
|
||||
|
||||
func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
fd, err := openAndConfigureSerialPort(cfg)
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
fd, err := unixSerialOpenPort(cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer unix.Close(fd)
|
||||
defer unixSerialClosePort(fd)
|
||||
|
||||
total := 0
|
||||
deadline := time.Now().Add(timeout)
|
||||
deadline := unixSerialNow().Add(timeout)
|
||||
for total < len(data) {
|
||||
remaining := time.Until(deadline)
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
remaining := deadline.Sub(unixSerialNow())
|
||||
if remaining <= 0 {
|
||||
return total, fmt.Errorf("timeout while writing serial data")
|
||||
}
|
||||
|
||||
n, err := pollWrite(fd, data[total:], remaining)
|
||||
n, err := unixSerialPollWrite(fd, data[total:], minSerialPollTimeout(remaining))
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
if n == 0 {
|
||||
return total, fmt.Errorf("serial port accepted zero bytes")
|
||||
continue
|
||||
}
|
||||
total += n
|
||||
}
|
||||
@@ -252,3 +277,10 @@ func durationToPollTimeout(timeout time.Duration) int {
|
||||
}
|
||||
return ms
|
||||
}
|
||||
|
||||
func minSerialPollTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout > serialPollInterval {
|
||||
return serialPollInterval
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func stubUnixSerialIO(t *testing.T, now *time.Time) {
|
||||
t.Helper()
|
||||
|
||||
prevNow := unixSerialNow
|
||||
prevOpen := unixSerialOpenPort
|
||||
prevClose := unixSerialClosePort
|
||||
prevPollRead := unixSerialPollRead
|
||||
prevPollWrite := unixSerialPollWrite
|
||||
|
||||
unixSerialNow = func() time.Time {
|
||||
return *now
|
||||
}
|
||||
unixSerialOpenPort = func(cfg serialConfig) (int, error) {
|
||||
return 42, nil
|
||||
}
|
||||
unixSerialClosePort = func(fd int) error {
|
||||
return nil
|
||||
}
|
||||
unixSerialPollRead = prevPollRead
|
||||
unixSerialPollWrite = prevPollWrite
|
||||
|
||||
t.Cleanup(func() {
|
||||
unixSerialNow = prevNow
|
||||
unixSerialOpenPort = prevOpen
|
||||
unixSerialClosePort = prevClose
|
||||
unixSerialPollRead = prevPollRead
|
||||
unixSerialPollWrite = prevPollWrite
|
||||
})
|
||||
}
|
||||
|
||||
func TestSerialReadWaitsPastEmptyPollsUntilDeadline(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
pollCalls := 0
|
||||
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
if timeout > serialPollInterval {
|
||||
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
|
||||
}
|
||||
now = now.Add(timeout)
|
||||
if pollCalls < 4 {
|
||||
return 0, nil
|
||||
}
|
||||
return copy(dst, []byte("OK")), nil
|
||||
}
|
||||
|
||||
got, err := serialRead(context.Background(), serialConfig{}, 2, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("serialRead() error = %v", err)
|
||||
}
|
||||
if string(got) != "OK" {
|
||||
t.Fatalf("serialRead() = %q, want %q", got, "OK")
|
||||
}
|
||||
if pollCalls != 4 {
|
||||
t.Fatalf("poll calls = %d, want 4", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialReadReturnsPromptlyOnContextCancelBetweenPolls(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pollCalls := 0
|
||||
unixSerialPollRead = func(fd int, dst []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
now = now.Add(timeout)
|
||||
cancel()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
_, err := serialRead(ctx, serialConfig{}, 1, time.Second)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("serialRead() error = %v, want context canceled", err)
|
||||
}
|
||||
if pollCalls != 1 {
|
||||
t.Fatalf("poll calls = %d, want 1", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteWaitsPastEmptyPollsUntilReady(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
pollCalls := 0
|
||||
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
pollCalls++
|
||||
if timeout > serialPollInterval {
|
||||
t.Fatalf("poll timeout = %v, want <= %v", timeout, serialPollInterval)
|
||||
}
|
||||
now = now.Add(timeout)
|
||||
switch pollCalls {
|
||||
case 1, 2:
|
||||
return 0, nil
|
||||
default:
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
written, err := serialWrite(context.Background(), serialConfig{}, []byte("OK"), 500*time.Millisecond)
|
||||
if err != nil {
|
||||
t.Fatalf("serialWrite() error = %v", err)
|
||||
}
|
||||
if written != 2 {
|
||||
t.Fatalf("serialWrite() wrote %d bytes, want 2", written)
|
||||
}
|
||||
if pollCalls != 4 {
|
||||
t.Fatalf("poll calls = %d, want 4", pollCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerialWriteTimesOutAfterRepeatedEmptyPolls(t *testing.T) {
|
||||
now := time.Unix(0, 0)
|
||||
stubUnixSerialIO(t, &now)
|
||||
|
||||
unixSerialPollWrite = func(fd int, src []byte, timeout time.Duration) (int, error) {
|
||||
now = now.Add(timeout)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
written, err := serialWrite(context.Background(), serialConfig{}, []byte("A"), 250*time.Millisecond)
|
||||
if err == nil || err.Error() != "timeout while writing serial data" {
|
||||
t.Fatalf("serialWrite() error = %v, want timeout", err)
|
||||
}
|
||||
if written != 0 {
|
||||
t.Fatalf("serialWrite() wrote %d bytes, want 0", written)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
package hardwaretools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -95,13 +96,21 @@ func serialListPorts() ([]serialPortInfo, error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time.Duration) ([]byte, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer windows.CloseHandle(handle)
|
||||
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := make([]byte, length)
|
||||
var read uint32
|
||||
if err := windows.ReadFile(handle, buf, &read, nil); err != nil {
|
||||
@@ -110,13 +119,21 @@ func serialRead(cfg serialConfig, length int, timeout time.Duration) ([]byte, er
|
||||
return buf[:read], nil
|
||||
}
|
||||
|
||||
func serialWrite(cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout time.Duration) (int, error) {
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
handle, err := openAndConfigureWindowsSerial(cfg, timeout)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer windows.CloseHandle(handle)
|
||||
|
||||
if err := serialContextErr(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var written uint32
|
||||
if err := windows.WriteFile(handle, data, &written, nil); err != nil {
|
||||
return int(written), err
|
||||
|
||||
Reference in New Issue
Block a user