diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-03 20:09:53 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-03 20:09:53 +0200 |
| commit | e17ca48fb49742b2aaccdf850d526e0cb8ad8ad6 (patch) | |
| tree | 9c24f5ecbf9b043818fdc7ef174f538e34deef7f /internal | |
| parent | 81e89f8e67a2eef0c8815f9d3ac79ee67530e9d3 (diff) | |
refactor(context): propagate context through db open and io scans (task 333)
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/cli/cli.go | 8 | ||||
| -rw-r--r-- | internal/goprecords/aggregate.go | 9 | ||||
| -rw-r--r-- | internal/goprecords/aggregate_test.go | 56 | ||||
| -rw-r--r-- | internal/goprecords/db.go | 4 | ||||
| -rw-r--r-- | internal/goprecords/db_test.go | 84 | ||||
| -rw-r--r-- | internal/goprecords/integration_test_runner.go | 2 | ||||
| -rw-r--r-- | internal/storage/db.go | 14 |
7 files changed, 96 insertions, 81 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 1085ea3..5df3823 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -47,12 +47,12 @@ func runImport(args []string) error { fs.Usage() return fmt.Errorf("missing -stats-dir") } - db, err := goprecords.OpenDB(*dbPath) + ctx := context.Background() + db, err := goprecords.OpenDB(ctx, *dbPath) if err != nil { return fmt.Errorf("open db: %w", err) } defer db.Close() - ctx := context.Background() if err := goprecords.CreateSchema(ctx, db); err != nil { return fmt.Errorf("schema: %w", err) } @@ -70,12 +70,12 @@ func runQuery(args []string) error { if err := fs.Parse(args); err != nil { return err } - db, err := goprecords.OpenDB(*dbPath) + ctx := context.Background() + db, err := goprecords.OpenDB(ctx, *dbPath) if err != nil { return fmt.Errorf("open db: %w", err) } defer db.Close() - ctx := context.Background() aggregates, err := goprecords.LoadAggregates(ctx, db) if err != nil { return fmt.Errorf("load: %w", err) diff --git a/internal/goprecords/aggregate.go b/internal/goprecords/aggregate.go index 7f7528c..c143426 100644 --- a/internal/goprecords/aggregate.go +++ b/internal/goprecords/aggregate.go @@ -55,7 +55,7 @@ func (ag *Aggregator) Aggregate(ctx context.Context) (*Aggregates, error) { if _, exists := out.Host[host]; exists { return nil, fmt.Errorf("record file for %s already processed - duplicate inputs?", host) } - lastKernel, err := lastKernelFromFile(path) + lastKernel, err := lastKernelFromFile(ctx, path) if err != nil { return nil, fmt.Errorf("last kernel %s: %w", path, err) } @@ -96,7 +96,7 @@ func processRecordsFile(ctx context.Context, path, host string, out *Aggregates) return nil } -func lastKernelFromFile(path string) (string, error) { +func lastKernelFromFile(ctx context.Context, path string) (string, error) { f, err := os.Open(path) if err != nil { return "", err @@ -106,6 +106,11 @@ func lastKernelFromFile(path string) (string, error) { var lastOS string sc := bufio.NewScanner(f) for sc.Scan() { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } rec, ok := parseRecordLine(sc.Text()) if !ok { continue diff --git a/internal/goprecords/aggregate_test.go b/internal/goprecords/aggregate_test.go index ec19f07..dfc9c91 100644 --- a/internal/goprecords/aggregate_test.go +++ b/internal/goprecords/aggregate_test.go @@ -17,7 +17,7 @@ func TestNewAggregator(t *testing.T) { func TestAggregateInvalidDir(t *testing.T) { agg := NewAggregator("/nonexistent/path") ctx := context.Background() - + _, err := agg.Aggregate(ctx) if err == nil { t.Error("expected error for non-existent directory") @@ -29,27 +29,27 @@ func TestAggregateFixtures(t *testing.T) { if _, err := os.Stat(fixturesPath); err != nil { fixturesPath = "../../../fixtures" } - + if _, err := os.Stat(fixturesPath); err != nil { t.Skipf("skipping test, fixtures directory not found") } - + agg := NewAggregator(fixturesPath) ctx := context.Background() - + aggregates, err := agg.Aggregate(ctx) if err != nil { t.Fatalf("failed to aggregate fixtures: %v", err) } - + if aggregates == nil { t.Error("expected non-nil aggregates") } - + if len(aggregates.Host) == 0 { t.Error("expected hosts in aggregates") } - + if len(aggregates.Kernel) == 0 { t.Error("expected kernels in aggregates") } @@ -60,19 +60,19 @@ func TestAggregateFixturesContent(t *testing.T) { if _, err := os.Stat(fixturesPath); err != nil { fixturesPath = "../../../fixtures" } - + if _, err := os.Stat(fixturesPath); err != nil { t.Skipf("skipping test, fixtures directory not found") } - + agg := NewAggregator(fixturesPath) ctx := context.Background() - + aggregates, err := agg.Aggregate(ctx) if err != nil { t.Fatalf("failed to aggregate fixtures: %v", err) } - + // Check a specific host if host, ok := aggregates.Host["earth"]; ok { if host.Boots == 0 { @@ -91,17 +91,17 @@ func TestAggregateFixturesContent(t *testing.T) { func TestGetOrNewAggregate(t *testing.T) { m := make(map[string]*Aggregate) - + agg1 := getOrNewAggregate(m, "kernel1") if agg1.Name != "kernel1" { t.Errorf("expected name kernel1, got %q", agg1.Name) } - + agg2 := getOrNewAggregate(m, "kernel1") if agg2 != agg1 { t.Error("expected same aggregate on second call") } - + if len(m) != 1 { t.Errorf("expected 1 entry in map, got %d", len(m)) } @@ -113,23 +113,23 @@ func TestLastKernelFromFile(t *testing.T) { if _, err := os.Stat(testFile); err != nil { testFile = "../../../fixtures/earth.records" } - + if _, err := os.Stat(testFile); err != nil { t.Skipf("skipping test, fixture file not found") } - - kernel, err := lastKernelFromFile(testFile) + + kernel, err := lastKernelFromFile(context.Background(), testFile) if err != nil { t.Fatalf("failed to get last kernel: %v", err) } - + if kernel == "" { t.Error("expected non-empty kernel string") } } func TestLastKernelFromFileNonExistent(t *testing.T) { - _, err := lastKernelFromFile("/nonexistent/file.records") + _, err := lastKernelFromFile(context.Background(), "/nonexistent/file.records") if err == nil { t.Error("expected error for non-existent file") } @@ -139,31 +139,31 @@ func TestProcessRecordsFile(t *testing.T) { // Create a temporary test file tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "test.records") - + content := []byte("86400:1000000:Linux 5.10.0-test\n" + "86400:1000001:Linux 5.10.0-test\n") - + if err := os.WriteFile(testFile, content, 0644); err != nil { t.Fatalf("failed to create test file: %v", err) } - + aggs := &Aggregates{ Host: make(map[string]*HostAggregate), Kernel: make(map[string]*Aggregate), KernelMajor: make(map[string]*Aggregate), KernelName: make(map[string]*Aggregate), } - + // Add host aggs.Host["test"] = NewHostAggregate("test", "") - + ctx := context.Background() err := processRecordsFile(ctx, testFile, "test", aggs) - + if err != nil { t.Fatalf("failed to process records: %v", err) } - + if aggs.Host["test"].Boots != 2 { t.Errorf("expected 2 boots, got %d", aggs.Host["test"].Boots) } @@ -171,10 +171,10 @@ func TestProcessRecordsFile(t *testing.T) { func TestContextCancellation(t *testing.T) { agg := NewAggregator("./fixtures") - + ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - + _, err := agg.Aggregate(ctx) if err == nil { t.Error("expected error for cancelled context") diff --git a/internal/goprecords/db.go b/internal/goprecords/db.go index a07f6be..430701b 100644 --- a/internal/goprecords/db.go +++ b/internal/goprecords/db.go @@ -9,8 +9,8 @@ import ( ) // OpenDB opens the SQLite database at path, creating the file if needed. -func OpenDB(path string) (*sql.DB, error) { - return storage.Open(path) +func OpenDB(ctx context.Context, path string) (*sql.DB, error) { + return storage.Open(ctx, path) } // CreateSchema creates the record table and indexes (idempotent). diff --git a/internal/goprecords/db_test.go b/internal/goprecords/db_test.go index dfeda78..943ee4d 100644 --- a/internal/goprecords/db_test.go +++ b/internal/goprecords/db_test.go @@ -10,13 +10,13 @@ import ( func TestOpenDB(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") - - db, err := OpenDB(dbPath) + + db, err := OpenDB(context.Background(), dbPath) if err != nil { t.Fatalf("failed to open DB: %v", err) } defer db.Close() - + if db == nil { t.Error("expected non-nil database") } @@ -25,19 +25,19 @@ func TestOpenDB(t *testing.T) { func TestCreateSchema(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") - - db, err := OpenDB(dbPath) + + db, err := OpenDB(context.Background(), dbPath) if err != nil { t.Fatalf("failed to open DB: %v", err) } defer db.Close() - + ctx := context.Background() err = CreateSchema(ctx, db) if err != nil { t.Fatalf("failed to create schema: %v", err) } - + // Verify schema was created by checking if we can query it _, err = db.ExecContext(ctx, "SELECT 1 FROM record LIMIT 1") if err != nil { @@ -48,39 +48,39 @@ func TestCreateSchema(t *testing.T) { func TestResetRecords(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") - - db, err := OpenDB(dbPath) + + db, err := OpenDB(context.Background(), dbPath) if err != nil { t.Fatalf("failed to open DB: %v", err) } defer db.Close() - + ctx := context.Background() if err := CreateSchema(ctx, db); err != nil { t.Fatalf("failed to create schema: %v", err) } - + // Insert a record - _, err = db.ExecContext(ctx, + _, err = db.ExecContext(ctx, "INSERT INTO record (host, uptime_sec, boot_time, os, os_kernel_name, os_kernel_major) VALUES (?, ?, ?, ?, ?, ?)", "host1", 1000, 2000, "Linux 5.10", "Linux", "Linux 5...") if err != nil { t.Fatalf("failed to insert record: %v", err) } - + // Reset records err = ResetRecords(ctx, db) if err != nil { t.Fatalf("failed to reset records: %v", err) } - + // Verify records are empty var count int err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM record").Scan(&count) if err != nil { t.Fatalf("failed to count records: %v", err) } - + if count != 0 { t.Errorf("expected 0 records after reset, got %d", count) } @@ -89,43 +89,43 @@ func TestResetRecords(t *testing.T) { func TestImportFromDir(t *testing.T) { // Create temp directory with test records tmpDir := t.TempDir() - + // Create a test records file recordsFile := filepath.Join(tmpDir, "testhost.records") content := []byte("86400:1000000:Linux 5.10.0-test\n" + "86400:1000001:Linux 5.10.0-test\n" + "86400:1000002:Linux 5.10.0-test\n") - + if err := os.WriteFile(recordsFile, content, 0644); err != nil { t.Fatalf("failed to create test file: %v", err) } - + // Create database dbPath := filepath.Join(tmpDir, "test.db") - db, err := OpenDB(dbPath) + db, err := OpenDB(context.Background(), dbPath) if err != nil { t.Fatalf("failed to open DB: %v", err) } defer db.Close() - + ctx := context.Background() if err := CreateSchema(ctx, db); err != nil { t.Fatalf("failed to create schema: %v", err) } - + // Import records err = ImportFromDir(ctx, db, tmpDir) if err != nil { t.Fatalf("failed to import records: %v", err) } - + // Verify records were imported var count int err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM record").Scan(&count) if err != nil { t.Fatalf("failed to count records: %v", err) } - + if count != 3 { t.Errorf("expected 3 records after import, got %d", count) } @@ -134,18 +134,18 @@ func TestImportFromDir(t *testing.T) { func TestImportFromDirInvalidPath(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") - - db, err := OpenDB(dbPath) + + db, err := OpenDB(context.Background(), dbPath) if err != nil { t.Fatalf("failed to open DB: %v", err) } defer db.Close() - + ctx := context.Background() if err := CreateSchema(ctx, db); err != nil { t.Fatalf("failed to create schema: %v", err) } - + // Try to import from non-existent directory err = ImportFromDir(ctx, db, "/nonexistent/path") if err == nil { @@ -156,18 +156,18 @@ func TestImportFromDirInvalidPath(t *testing.T) { func TestLoadAggregates(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") - - db, err := OpenDB(dbPath) + + db, err := OpenDB(context.Background(), dbPath) if err != nil { t.Fatalf("failed to open DB: %v", err) } defer db.Close() - + ctx := context.Background() if err := CreateSchema(ctx, db); err != nil { t.Fatalf("failed to create schema: %v", err) } - + // Insert some records _, err = db.ExecContext(ctx, "INSERT INTO record (host, uptime_sec, boot_time, os, os_kernel_name, os_kernel_major) VALUES (?, ?, ?, ?, ?, ?)", @@ -175,28 +175,28 @@ func TestLoadAggregates(t *testing.T) { if err != nil { t.Fatalf("failed to insert: %v", err) } - + _, err = db.ExecContext(ctx, "INSERT INTO record (host, uptime_sec, boot_time, os, os_kernel_name, os_kernel_major) VALUES (?, ?, ?, ?, ?, ?)", "host1", 2000, 3000, "Linux 5.11", "Linux", "Linux 5...") if err != nil { t.Fatalf("failed to insert: %v", err) } - + // Load aggregates aggs, err := LoadAggregates(ctx, db) if err != nil { t.Fatalf("failed to load aggregates: %v", err) } - + if aggs == nil { t.Error("expected non-nil aggregates") } - + if len(aggs.Host) != 1 { t.Errorf("expected 1 host, got %d", len(aggs.Host)) } - + if host, ok := aggs.Host["host1"]; ok { if host.Boots != 2 { t.Errorf("expected 2 boots, got %d", host.Boots) @@ -210,28 +210,28 @@ func TestLoadAggregates(t *testing.T) { func TestLoadAggregatesEmptyDB(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") - - db, err := OpenDB(dbPath) + + db, err := OpenDB(context.Background(), dbPath) if err != nil { t.Fatalf("failed to open DB: %v", err) } defer db.Close() - + ctx := context.Background() if err := CreateSchema(ctx, db); err != nil { t.Fatalf("failed to create schema: %v", err) } - + // Load from empty database aggs, err := LoadAggregates(ctx, db) if err != nil { t.Fatalf("failed to load aggregates: %v", err) } - + if aggs == nil { t.Error("expected non-nil aggregates") } - + if len(aggs.Host) != 0 { t.Errorf("expected 0 hosts, got %d", len(aggs.Host)) } diff --git a/internal/goprecords/integration_test_runner.go b/internal/goprecords/integration_test_runner.go index 29e5f24..ff4a8c3 100644 --- a/internal/goprecords/integration_test_runner.go +++ b/internal/goprecords/integration_test_runner.go @@ -85,7 +85,7 @@ func testImportExport(ctx context.Context, aggregates *Aggregates, fixturesDir s tmpDB := fixturesDir + "/test_import.db" os.Remove(tmpDB) failed := 0 - db, err := OpenDB(tmpDB) + db, err := OpenDB(ctx, tmpDB) if err != nil { fmt.Printf("FAIL: open tmp db: %v\n", err) return 1 diff --git a/internal/storage/db.go b/internal/storage/db.go index 127f45d..ea9d764 100644 --- a/internal/storage/db.go +++ b/internal/storage/db.go @@ -37,12 +37,12 @@ type Record struct { KernelMajor string } -func Open(path string) (*sql.DB, error) { +func Open(ctx context.Context, path string) (*sql.DB, error) { db, err := sql.Open("sqlite", path) if err != nil { return nil, err } - if _, err := db.Exec("PRAGMA foreign_keys = OFF"); err != nil { + if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil { db.Close() return nil, err } @@ -108,6 +108,11 @@ func LoadRecords(ctx context.Context, db *sql.DB) ([]Record, error) { defer rows.Close() var out []Record for rows.Next() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } var rec Record var uptimeSec, bootTime int64 if err := rows.Scan(&rec.Host, &uptimeSec, &bootTime, &rec.OS, &rec.KernelName, &rec.KernelMajor); err != nil { @@ -139,6 +144,11 @@ func importFile(ctx context.Context, insert *sql.Stmt, path, host string) error defer f.Close() sc := bufio.NewScanner(f) for sc.Scan() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } rec, ok := parseRecordLine(sc.Text()) if !ok { continue |
