From 64e48163d043ecdc1bd90eaff6632fb3a68fdf67 Mon Sep 17 00:00:00 2001 From: SiYue-ZO <2835601846@qq.com> Date: Tue, 28 Apr 2026 12:57:25 +0800 Subject: [PATCH] hardware/serial: improve windows I/O handling --- pkg/tools/hardware/serial.go | 34 ++++++++ pkg/tools/hardware/serial_windows.go | 48 ++++++++-- pkg/tools/hardware/serial_windows_test.go | 39 +++++++++ .../hardware/serial_write_common_test.go | 87 +++++++++++++++++++ 4 files changed, 199 insertions(+), 9 deletions(-) create mode 100644 pkg/tools/hardware/serial_windows_test.go create mode 100644 pkg/tools/hardware/serial_write_common_test.go diff --git a/pkg/tools/hardware/serial.go b/pkg/tools/hardware/serial.go index 3a2ead824..52e47fc7a 100644 --- a/pkg/tools/hardware/serial.go +++ b/pkg/tools/hardware/serial.go @@ -413,3 +413,37 @@ func serialContextErr(ctx context.Context) error { return nil } } + +func serialWriteAll( + ctx context.Context, + data []byte, + timeout time.Duration, + now func() time.Time, + write func([]byte) (int, error), +) (int, error) { + if err := serialContextErr(ctx); err != nil { + return 0, err + } + + total := 0 + deadline := now().Add(timeout) + for total < len(data) { + if err := serialContextErr(ctx); err != nil { + return total, err + } + if deadline.Sub(now()) <= 0 { + return total, fmt.Errorf("timeout while writing serial data") + } + + n, err := write(data[total:]) + total += n + if err != nil { + return total, err + } + if n == 0 { + continue + } + } + + return total, nil +} diff --git a/pkg/tools/hardware/serial_windows.go b/pkg/tools/hardware/serial_windows.go index 1dcbfa7a9..36f8b4271 100644 --- a/pkg/tools/hardware/serial_windows.go +++ b/pkg/tools/hardware/serial_windows.go @@ -25,6 +25,17 @@ var ( const ( purgeTxClear = 0x0004 purgeRxClear = 0x0008 + + dcbFlagBinary = 0x00000001 + dcbFlagParity = 0x00000002 + dcbFlagOutxCtsFlow = 0x00000004 + dcbFlagOutxDsrFlow = 0x00000008 + dcbFlagDtrControlMask = 0x00000030 + dcbFlagDsrSensitivity = 0x00000040 + dcbFlagTXContinueOnXoff = 0x00000080 + dcbFlagOutX = 0x00000100 + dcbFlagInX = 0x00000200 + dcbFlagRtsControlMask = 0x00003000 ) type dcb struct { @@ -113,6 +124,8 @@ func serialRead(ctx context.Context, cfg serialConfig, length int, timeout time. buf := make([]byte, length) var read uint32 + // Synchronous serial I/O on Windows cannot be interrupted once the syscall starts. + // COMMTIMEOUTS bounds how long turn cancellation may take to surface. if err := windows.ReadFile(handle, buf, &read, nil); err != nil { return nil, err } @@ -134,11 +147,15 @@ func serialWrite(ctx context.Context, cfg serialConfig, data []byte, timeout tim return 0, err } - var written uint32 - if err := windows.WriteFile(handle, data, &written, nil); err != nil { - return int(written), err - } - return int(written), nil + return serialWriteAll(ctx, data, timeout, time.Now, func(chunk []byte) (int, error) { + var written uint32 + // Like ReadFile above, this synchronous WriteFile call relies on COMMTIMEOUTS + // rather than context preemption once the syscall is in flight. + if err := windows.WriteFile(handle, chunk, &written, nil); err != nil { + return int(written), err + } + return int(written), nil + }) } func openAndConfigureWindowsSerial(cfg serialConfig, timeout time.Duration) (windows.Handle, error) { @@ -171,18 +188,19 @@ func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout state.BaudRate = uint32(cfg.Baud) state.ByteSize = byte(cfg.DataBits) - state.Flags |= 0x00000001 // fBinary + state.Flags = sanitizeWindowsSerialFlags(state.Flags) + state.Flags |= dcbFlagBinary switch cfg.Parity { case "even": state.Parity = 2 - state.Flags |= 0x00000002 // fParity + state.Flags |= dcbFlagParity case "odd": state.Parity = 1 - state.Flags |= 0x00000002 // fParity + state.Flags |= dcbFlagParity default: state.Parity = 0 - state.Flags &^= 0x00000002 + state.Flags &^= dcbFlagParity } switch cfg.StopBits { @@ -216,3 +234,15 @@ func configureWindowsSerialPort(handle windows.Handle, cfg serialConfig, timeout procPurgeComm.Call(uintptr(handle), uintptr(purgeRxClear|purgeTxClear)) return nil } + +func sanitizeWindowsSerialFlags(flags uint32) uint32 { + flags &^= dcbFlagOutxCtsFlow | + dcbFlagOutxDsrFlow | + dcbFlagDtrControlMask | + dcbFlagDsrSensitivity | + dcbFlagTXContinueOnXoff | + dcbFlagOutX | + dcbFlagInX | + dcbFlagRtsControlMask + return flags +} diff --git a/pkg/tools/hardware/serial_windows_test.go b/pkg/tools/hardware/serial_windows_test.go new file mode 100644 index 000000000..ecb0addbd --- /dev/null +++ b/pkg/tools/hardware/serial_windows_test.go @@ -0,0 +1,39 @@ +//go:build windows + +package hardwaretools + +import "testing" + +func TestSanitizeWindowsSerialFlags(t *testing.T) { + flags := uint32( + dcbFlagBinary | + dcbFlagParity | + dcbFlagOutxCtsFlow | + dcbFlagOutxDsrFlow | + dcbFlagDtrControlMask | + dcbFlagDsrSensitivity | + dcbFlagTXContinueOnXoff | + dcbFlagOutX | + dcbFlagInX | + dcbFlagRtsControlMask, + ) + + got := sanitizeWindowsSerialFlags(flags) + + if got&dcbFlagBinary == 0 { + t.Fatal("sanitizeWindowsSerialFlags() should preserve fBinary") + } + if got&dcbFlagParity == 0 { + t.Fatal("sanitizeWindowsSerialFlags() should preserve fParity") + } + if got&(dcbFlagOutxCtsFlow| + dcbFlagOutxDsrFlow| + dcbFlagDtrControlMask| + dcbFlagDsrSensitivity| + dcbFlagTXContinueOnXoff| + dcbFlagOutX| + dcbFlagInX| + dcbFlagRtsControlMask) != 0 { + t.Fatalf("sanitizeWindowsSerialFlags() = %#x, want flow-control bits cleared", got) + } +} diff --git a/pkg/tools/hardware/serial_write_common_test.go b/pkg/tools/hardware/serial_write_common_test.go new file mode 100644 index 000000000..398c1fde5 --- /dev/null +++ b/pkg/tools/hardware/serial_write_common_test.go @@ -0,0 +1,87 @@ +package hardwaretools + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestSerialWriteAllRetriesPartialWritesUntilComplete(t *testing.T) { + now := time.Unix(0, 0) + calls := 0 + + written, err := serialWriteAll(context.Background(), []byte("PING"), time.Second, func() time.Time { + return now + }, func(chunk []byte) (int, error) { + calls++ + now = now.Add(100 * time.Millisecond) + switch calls { + case 1: + if string(chunk) != "PING" { + t.Fatalf("first chunk = %q, want %q", chunk, "PING") + } + return 2, nil + case 2: + if string(chunk) != "NG" { + t.Fatalf("second chunk = %q, want %q", chunk, "NG") + } + return 2, nil + default: + t.Fatalf("unexpected extra write call %d", calls) + return 0, nil + } + }) + if err != nil { + t.Fatalf("serialWriteAll() error = %v", err) + } + if written != 4 { + t.Fatalf("serialWriteAll() wrote %d bytes, want 4", written) + } +} + +func TestSerialWriteAllTimesOutAfterZeroByteWrites(t *testing.T) { + now := time.Unix(0, 0) + calls := 0 + + written, err := serialWriteAll(context.Background(), []byte("A"), 250*time.Millisecond, func() time.Time { + return now + }, func(chunk []byte) (int, error) { + calls++ + now = now.Add(100 * time.Millisecond) + return 0, nil + }) + if err == nil || err.Error() != "timeout while writing serial data" { + t.Fatalf("serialWriteAll() error = %v, want timeout", err) + } + if written != 0 { + t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written) + } + if calls != 3 { + t.Fatalf("write calls = %d, want 3", calls) + } +} + +func TestSerialWriteAllReturnsContextCancellationAfterRetryBoundary(t *testing.T) { + now := time.Unix(0, 0) + ctx, cancel := context.WithCancel(context.Background()) + calls := 0 + + written, err := serialWriteAll(ctx, []byte("A"), time.Second, func() time.Time { + return now + }, func(chunk []byte) (int, error) { + calls++ + now = now.Add(100 * time.Millisecond) + cancel() + return 0, nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("serialWriteAll() error = %v, want context canceled", err) + } + if written != 0 { + t.Fatalf("serialWriteAll() wrote %d bytes, want 0", written) + } + if calls != 1 { + t.Fatalf("write calls = %d, want 1", calls) + } +}