From 0276554d9c90d9782c999b0b61db0f3cea10584d Mon Sep 17 00:00:00 2001 From: perhapzz Date: Fri, 20 Mar 2026 10:41:08 +0000 Subject: [PATCH] test(fileutil,health): add unit tests for WriteFileAtomic and health server Add comprehensive test coverage for two previously untested packages: pkg/fileutil (9 tests): - Basic write and read-back - File permissions (0600) - Overwrite existing files - Empty data handling - Nested directory auto-creation - No temp files left after success - Large file (1MB) handling - Concurrent write safety - Invalid path error handling pkg/health (15 tests): - Health endpoint returns 200 with status, uptime, pid - Ready endpoint returns 503 when not ready - Ready endpoint returns 200 when ready - Ready fails when any registered check fails - Ready passes with all checks passing - Reload rejects non-POST methods - Reload returns 503 when no reload func set - Reload calls registered function on success - Reload returns 500 on function error - SetReady toggle behavior - Multiple health checks interaction - RegisterOnMux works with custom ServeMux - NewServer defaults - StartContext graceful shutdown on cancel - statusString helper --- pkg/fileutil/file_test.go | 176 ++++++++++++++++++++ pkg/health/server_test.go | 337 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 513 insertions(+) create mode 100644 pkg/fileutil/file_test.go create mode 100644 pkg/health/server_test.go diff --git a/pkg/fileutil/file_test.go b/pkg/fileutil/file_test.go new file mode 100644 index 000000000..b0494d0d3 --- /dev/null +++ b/pkg/fileutil/file_test.go @@ -0,0 +1,176 @@ +package fileutil + +import ( + "os" + "path/filepath" + "sync" + "testing" +) + +func TestWriteFileAtomic_Basic(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + data := []byte("hello picoclaw") + + err := WriteFileAtomic(path, data, 0o644) + if err != nil { + t.Fatalf("WriteFileAtomic failed: %v", err) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if string(got) != string(data) { + t.Errorf("got %q, want %q", got, data) + } +} + +func TestWriteFileAtomic_Permissions(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "secret.txt") + + err := WriteFileAtomic(path, []byte("secret"), 0o600) + if err != nil { + t.Fatalf("WriteFileAtomic failed: %v", err) + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + // On Unix, check file mode (ignoring directory bits) + if got := info.Mode().Perm(); got != 0o600 { + t.Errorf("permissions = %o, want %o", got, 0o600) + } +} + +func TestWriteFileAtomic_Overwrite(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "overwrite.txt") + + // Write initial content + if err := WriteFileAtomic(path, []byte("old"), 0o644); err != nil { + t.Fatalf("first write failed: %v", err) + } + + // Overwrite + if err := WriteFileAtomic(path, []byte("new"), 0o644); err != nil { + t.Fatalf("second write failed: %v", err) + } + + got, _ := os.ReadFile(path) + if string(got) != "new" { + t.Errorf("got %q after overwrite, want %q", got, "new") + } +} + +func TestWriteFileAtomic_EmptyData(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.txt") + + err := WriteFileAtomic(path, []byte{}, 0o644) + if err != nil { + t.Fatalf("WriteFileAtomic with empty data failed: %v", err) + } + + got, _ := os.ReadFile(path) + if len(got) != 0 { + t.Errorf("expected empty file, got %d bytes", len(got)) + } +} + +func TestWriteFileAtomic_CreatesParentDirs(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "a", "b", "c", "deep.txt") + + err := WriteFileAtomic(path, []byte("deep"), 0o644) + if err != nil { + t.Fatalf("WriteFileAtomic with nested dirs failed: %v", err) + } + + got, _ := os.ReadFile(path) + if string(got) != "deep" { + t.Errorf("got %q, want %q", got, "deep") + } +} + +func TestWriteFileAtomic_NoTempFileOnSuccess(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "clean.txt") + + if err := WriteFileAtomic(path, []byte("data"), 0o644); err != nil { + t.Fatalf("WriteFileAtomic failed: %v", err) + } + + // Verify no temp files remain + entries, _ := os.ReadDir(dir) + for _, e := range entries { + if e.Name() != "clean.txt" { + t.Errorf("unexpected file remaining: %s", e.Name()) + } + } +} + +func TestWriteFileAtomic_LargeFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "large.bin") + + // 1MB of data + data := make([]byte, 1<<20) + for i := range data { + data[i] = byte(i % 256) + } + + if err := WriteFileAtomic(path, data, 0o644); err != nil { + t.Fatalf("WriteFileAtomic with large file failed: %v", err) + } + + got, _ := os.ReadFile(path) + if len(got) != len(data) { + t.Errorf("file size = %d, want %d", len(got), len(data)) + } +} + +func TestWriteFileAtomic_Concurrent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "concurrent.txt") + + var wg sync.WaitGroup + errs := make(chan error, 10) + + for i := 0; i < 10; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + data := []byte(string(rune('A' + n))) + if err := WriteFileAtomic(path, data, 0o644); err != nil { + errs <- err + } + }(i) + } + + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent write error: %v", err) + } + + // File should exist and contain exactly 1 byte (last writer wins) + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile after concurrent writes failed: %v", err) + } + if len(got) != 1 { + t.Errorf("expected 1 byte after concurrent writes, got %d", len(got)) + } +} + +func TestWriteFileAtomic_InvalidPath(t *testing.T) { + // /dev/null/impossible is not a valid path on any OS + err := WriteFileAtomic("/dev/null/impossible/file.txt", []byte("data"), 0o644) + if err == nil { + t.Error("expected error for invalid path, got nil") + } +} diff --git a/pkg/health/server_test.go b/pkg/health/server_test.go new file mode 100644 index 000000000..3e71be62a --- /dev/null +++ b/pkg/health/server_test.go @@ -0,0 +1,337 @@ +package health + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func newTestServer() *Server { + s := &Server{ + ready: false, + checks: make(map[string]Check), + startTime: time.Now(), + } + return s +} + +func TestHealthHandler_ReturnsOK(t *testing.T) { + s := newTestServer() + req := httptest.NewRequest(http.MethodGet, "/health", nil) + w := httptest.NewRecorder() + + s.healthHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("health status = %d, want %d", w.Code, http.StatusOK) + } + + var resp StatusResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Status != "ok" { + t.Errorf("status = %q, want %q", resp.Status, "ok") + } + if resp.Pid == 0 { + t.Error("pid should not be 0") + } + if resp.Uptime == "" { + t.Error("uptime should not be empty") + } +} + +func TestReadyHandler_NotReady(t *testing.T) { + s := newTestServer() + // s.ready defaults to false + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + + s.readyHandler(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("ready status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } + + var resp StatusResponse + json.NewDecoder(w.Body).Decode(&resp) + if resp.Status != "not ready" { + t.Errorf("status = %q, want %q", resp.Status, "not ready") + } +} + +func TestReadyHandler_Ready(t *testing.T) { + s := newTestServer() + s.SetReady(true) + + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + + s.readyHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("ready status = %d, want %d", w.Code, http.StatusOK) + } + + var resp StatusResponse + json.NewDecoder(w.Body).Decode(&resp) + if resp.Status != "ready" { + t.Errorf("status = %q, want %q", resp.Status, "ready") + } +} + +func TestReadyHandler_FailedCheck(t *testing.T) { + s := newTestServer() + s.SetReady(true) + + // Register a failing check + s.RegisterCheck("database", func() (bool, string) { + return false, "connection refused" + }) + + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + + s.readyHandler(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("ready with failed check = %d, want %d", w.Code, http.StatusServiceUnavailable) + } + + var resp StatusResponse + json.NewDecoder(w.Body).Decode(&resp) + if resp.Status != "not ready" { + t.Errorf("status = %q, want %q", resp.Status, "not ready") + } + check, ok := resp.Checks["database"] + if !ok { + t.Fatal("missing database check in response") + } + if check.Status != "fail" { + t.Errorf("check status = %q, want %q", check.Status, "fail") + } + if check.Message != "connection refused" { + t.Errorf("check message = %q, want %q", check.Message, "connection refused") + } +} + +func TestReadyHandler_PassingCheck(t *testing.T) { + s := newTestServer() + s.SetReady(true) + + s.RegisterCheck("redis", func() (bool, string) { + return true, "connected" + }) + + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + + s.readyHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("ready with passing check = %d, want %d", w.Code, http.StatusOK) + } + + var resp StatusResponse + json.NewDecoder(w.Body).Decode(&resp) + if resp.Checks["redis"].Status != "ok" { + t.Errorf("redis check status = %q, want %q", resp.Checks["redis"].Status, "ok") + } +} + +func TestReloadHandler_MethodNotAllowed(t *testing.T) { + s := newTestServer() + + req := httptest.NewRequest(http.MethodGet, "/reload", nil) + w := httptest.NewRecorder() + + s.reloadHandler(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("reload GET status = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } +} + +func TestReloadHandler_NoReloadFunc(t *testing.T) { + s := newTestServer() + + req := httptest.NewRequest(http.MethodPost, "/reload", nil) + w := httptest.NewRecorder() + + s.reloadHandler(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("reload without func = %d, want %d", w.Code, http.StatusServiceUnavailable) + } +} + +func TestReloadHandler_Success(t *testing.T) { + s := newTestServer() + called := false + s.SetReloadFunc(func() error { + called = true + return nil + }) + + req := httptest.NewRequest(http.MethodPost, "/reload", nil) + w := httptest.NewRecorder() + + s.reloadHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("reload status = %d, want %d", w.Code, http.StatusOK) + } + if !called { + t.Error("reload function was not called") + } +} + +func TestReloadHandler_Error(t *testing.T) { + s := newTestServer() + s.SetReloadFunc(func() error { + return fmt.Errorf("config parse error") + }) + + req := httptest.NewRequest(http.MethodPost, "/reload", nil) + w := httptest.NewRecorder() + + s.reloadHandler(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("reload error status = %d, want %d", w.Code, http.StatusInternalServerError) + } +} + +func TestSetReady_Toggle(t *testing.T) { + s := newTestServer() + + s.SetReady(true) + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + s.readyHandler(w, req) + if w.Code != http.StatusOK { + t.Errorf("after SetReady(true): status = %d, want %d", w.Code, http.StatusOK) + } + + s.SetReady(false) + w = httptest.NewRecorder() + s.readyHandler(w, httptest.NewRequest(http.MethodGet, "/ready", nil)) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("after SetReady(false): status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } +} + +func TestRegisterCheck_MultipleChecks(t *testing.T) { + s := newTestServer() + s.SetReady(true) + + s.RegisterCheck("db", func() (bool, string) { + return true, "ok" + }) + s.RegisterCheck("cache", func() (bool, string) { + return true, "ok" + }) + s.RegisterCheck("queue", func() (bool, string) { + return false, "timeout" + }) + + req := httptest.NewRequest(http.MethodGet, "/ready", nil) + w := httptest.NewRecorder() + s.readyHandler(w, req) + + // Should be not ready because queue check fails + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d (queue check failed)", w.Code, http.StatusServiceUnavailable) + } + + var resp StatusResponse + json.NewDecoder(w.Body).Decode(&resp) + if len(resp.Checks) != 3 { + t.Errorf("checks count = %d, want 3", len(resp.Checks)) + } +} + +func TestRegisterOnMux(t *testing.T) { + s := newTestServer() + s.SetReady(true) + + mux := http.NewServeMux() + s.RegisterOnMux(mux) + + // Test /health on custom mux + req := httptest.NewRequest(http.MethodGet, "/health", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("/health on custom mux = %d, want %d", w.Code, http.StatusOK) + } + + // Test /ready on custom mux + req = httptest.NewRequest(http.MethodGet, "/ready", nil) + w = httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("/ready on custom mux = %d, want %d", w.Code, http.StatusOK) + } +} + +func TestNewServer(t *testing.T) { + s := NewServer("127.0.0.1", 0) + if s == nil { + t.Fatal("NewServer returned nil") + } + if s.ready { + t.Error("new server should not be ready by default") + } + if s.checks == nil { + t.Error("checks map should be initialized") + } +} + +func TestStartContext_Cancellation(t *testing.T) { + s := NewServer("127.0.0.1", 0) + + ctx, cancel := context.WithCancel(context.Background()) + + errCh := make(chan error, 1) + go func() { + errCh <- s.StartContext(ctx) + }() + + // Give server time to start + time.Sleep(50 * time.Millisecond) + + // Cancel context should trigger shutdown + cancel() + + select { + case err := <-errCh: + if err != nil { + t.Errorf("StartContext returned unexpected error: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("StartContext did not return after context cancellation") + } +} + +func TestStatusString(t *testing.T) { + tests := []struct { + input bool + want string + }{ + {true, "ok"}, + {false, "fail"}, + } + for _, tt := range tests { + got := statusString(tt.input) + if got != tt.want { + t.Errorf("statusString(%v) = %q, want %q", tt.input, got, tt.want) + } + } +}