summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/cli/cli.go8
-rw-r--r--internal/goprecords/aggregate.go9
-rw-r--r--internal/goprecords/aggregate_test.go56
-rw-r--r--internal/goprecords/db.go4
-rw-r--r--internal/goprecords/db_test.go84
-rw-r--r--internal/goprecords/integration_test_runner.go2
-rw-r--r--internal/storage/db.go14
7 files changed, 96 insertions, 81 deletions
diff --git a/internal/cli/cli.go b/internal/cli/cli.go
index 1085ea3..5df3823 100644
--- a/internal/cli/cli.go
+++ b/internal/cli/cli.go
@@ -47,12 +47,12 @@ func runImport(args []string) error {
fs.Usage()
return fmt.Errorf("missing -stats-dir")
}
- db, err := goprecords.OpenDB(*dbPath)
+ ctx := context.Background()
+ db, err := goprecords.OpenDB(ctx, *dbPath)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
- ctx := context.Background()
if err := goprecords.CreateSchema(ctx, db); err != nil {
return fmt.Errorf("schema: %w", err)
}
@@ -70,12 +70,12 @@ func runQuery(args []string) error {
if err := fs.Parse(args); err != nil {
return err
}
- db, err := goprecords.OpenDB(*dbPath)
+ ctx := context.Background()
+ db, err := goprecords.OpenDB(ctx, *dbPath)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
- ctx := context.Background()
aggregates, err := goprecords.LoadAggregates(ctx, db)
if err != nil {
return fmt.Errorf("load: %w", err)
diff --git a/internal/goprecords/aggregate.go b/internal/goprecords/aggregate.go
index 7f7528c..c143426 100644
--- a/internal/goprecords/aggregate.go
+++ b/internal/goprecords/aggregate.go
@@ -55,7 +55,7 @@ func (ag *Aggregator) Aggregate(ctx context.Context) (*Aggregates, error) {
if _, exists := out.Host[host]; exists {
return nil, fmt.Errorf("record file for %s already processed - duplicate inputs?", host)
}
- lastKernel, err := lastKernelFromFile(path)
+ lastKernel, err := lastKernelFromFile(ctx, path)
if err != nil {
return nil, fmt.Errorf("last kernel %s: %w", path, err)
}
@@ -96,7 +96,7 @@ func processRecordsFile(ctx context.Context, path, host string, out *Aggregates)
return nil
}
-func lastKernelFromFile(path string) (string, error) {
+func lastKernelFromFile(ctx context.Context, path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
@@ -106,6 +106,11 @@ func lastKernelFromFile(path string) (string, error) {
var lastOS string
sc := bufio.NewScanner(f)
for sc.Scan() {
+ select {
+ case <-ctx.Done():
+ return "", ctx.Err()
+ default:
+ }
rec, ok := parseRecordLine(sc.Text())
if !ok {
continue
diff --git a/internal/goprecords/aggregate_test.go b/internal/goprecords/aggregate_test.go
index ec19f07..dfc9c91 100644
--- a/internal/goprecords/aggregate_test.go
+++ b/internal/goprecords/aggregate_test.go
@@ -17,7 +17,7 @@ func TestNewAggregator(t *testing.T) {
func TestAggregateInvalidDir(t *testing.T) {
agg := NewAggregator("/nonexistent/path")
ctx := context.Background()
-
+
_, err := agg.Aggregate(ctx)
if err == nil {
t.Error("expected error for non-existent directory")
@@ -29,27 +29,27 @@ func TestAggregateFixtures(t *testing.T) {
if _, err := os.Stat(fixturesPath); err != nil {
fixturesPath = "../../../fixtures"
}
-
+
if _, err := os.Stat(fixturesPath); err != nil {
t.Skipf("skipping test, fixtures directory not found")
}
-
+
agg := NewAggregator(fixturesPath)
ctx := context.Background()
-
+
aggregates, err := agg.Aggregate(ctx)
if err != nil {
t.Fatalf("failed to aggregate fixtures: %v", err)
}
-
+
if aggregates == nil {
t.Error("expected non-nil aggregates")
}
-
+
if len(aggregates.Host) == 0 {
t.Error("expected hosts in aggregates")
}
-
+
if len(aggregates.Kernel) == 0 {
t.Error("expected kernels in aggregates")
}
@@ -60,19 +60,19 @@ func TestAggregateFixturesContent(t *testing.T) {
if _, err := os.Stat(fixturesPath); err != nil {
fixturesPath = "../../../fixtures"
}
-
+
if _, err := os.Stat(fixturesPath); err != nil {
t.Skipf("skipping test, fixtures directory not found")
}
-
+
agg := NewAggregator(fixturesPath)
ctx := context.Background()
-
+
aggregates, err := agg.Aggregate(ctx)
if err != nil {
t.Fatalf("failed to aggregate fixtures: %v", err)
}
-
+
// Check a specific host
if host, ok := aggregates.Host["earth"]; ok {
if host.Boots == 0 {
@@ -91,17 +91,17 @@ func TestAggregateFixturesContent(t *testing.T) {
func TestGetOrNewAggregate(t *testing.T) {
m := make(map[string]*Aggregate)
-
+
agg1 := getOrNewAggregate(m, "kernel1")
if agg1.Name != "kernel1" {
t.Errorf("expected name kernel1, got %q", agg1.Name)
}
-
+
agg2 := getOrNewAggregate(m, "kernel1")
if agg2 != agg1 {
t.Error("expected same aggregate on second call")
}
-
+
if len(m) != 1 {
t.Errorf("expected 1 entry in map, got %d", len(m))
}
@@ -113,23 +113,23 @@ func TestLastKernelFromFile(t *testing.T) {
if _, err := os.Stat(testFile); err != nil {
testFile = "../../../fixtures/earth.records"
}
-
+
if _, err := os.Stat(testFile); err != nil {
t.Skipf("skipping test, fixture file not found")
}
-
- kernel, err := lastKernelFromFile(testFile)
+
+ kernel, err := lastKernelFromFile(context.Background(), testFile)
if err != nil {
t.Fatalf("failed to get last kernel: %v", err)
}
-
+
if kernel == "" {
t.Error("expected non-empty kernel string")
}
}
func TestLastKernelFromFileNonExistent(t *testing.T) {
- _, err := lastKernelFromFile("/nonexistent/file.records")
+ _, err := lastKernelFromFile(context.Background(), "/nonexistent/file.records")
if err == nil {
t.Error("expected error for non-existent file")
}
@@ -139,31 +139,31 @@ func TestProcessRecordsFile(t *testing.T) {
// Create a temporary test file
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.records")
-
+
content := []byte("86400:1000000:Linux 5.10.0-test\n" +
"86400:1000001:Linux 5.10.0-test\n")
-
+
if err := os.WriteFile(testFile, content, 0644); err != nil {
t.Fatalf("failed to create test file: %v", err)
}
-
+
aggs := &Aggregates{
Host: make(map[string]*HostAggregate),
Kernel: make(map[string]*Aggregate),
KernelMajor: make(map[string]*Aggregate),
KernelName: make(map[string]*Aggregate),
}
-
+
// Add host
aggs.Host["test"] = NewHostAggregate("test", "")
-
+
ctx := context.Background()
err := processRecordsFile(ctx, testFile, "test", aggs)
-
+
if err != nil {
t.Fatalf("failed to process records: %v", err)
}
-
+
if aggs.Host["test"].Boots != 2 {
t.Errorf("expected 2 boots, got %d", aggs.Host["test"].Boots)
}
@@ -171,10 +171,10 @@ func TestProcessRecordsFile(t *testing.T) {
func TestContextCancellation(t *testing.T) {
agg := NewAggregator("./fixtures")
-
+
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
-
+
_, err := agg.Aggregate(ctx)
if err == nil {
t.Error("expected error for cancelled context")
diff --git a/internal/goprecords/db.go b/internal/goprecords/db.go
index a07f6be..430701b 100644
--- a/internal/goprecords/db.go
+++ b/internal/goprecords/db.go
@@ -9,8 +9,8 @@ import (
)
// OpenDB opens the SQLite database at path, creating the file if needed.
-func OpenDB(path string) (*sql.DB, error) {
- return storage.Open(path)
+func OpenDB(ctx context.Context, path string) (*sql.DB, error) {
+ return storage.Open(ctx, path)
}
// CreateSchema creates the record table and indexes (idempotent).
diff --git a/internal/goprecords/db_test.go b/internal/goprecords/db_test.go
index dfeda78..943ee4d 100644
--- a/internal/goprecords/db_test.go
+++ b/internal/goprecords/db_test.go
@@ -10,13 +10,13 @@ import (
func TestOpenDB(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
-
- db, err := OpenDB(dbPath)
+
+ db, err := OpenDB(context.Background(), dbPath)
if err != nil {
t.Fatalf("failed to open DB: %v", err)
}
defer db.Close()
-
+
if db == nil {
t.Error("expected non-nil database")
}
@@ -25,19 +25,19 @@ func TestOpenDB(t *testing.T) {
func TestCreateSchema(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
-
- db, err := OpenDB(dbPath)
+
+ db, err := OpenDB(context.Background(), dbPath)
if err != nil {
t.Fatalf("failed to open DB: %v", err)
}
defer db.Close()
-
+
ctx := context.Background()
err = CreateSchema(ctx, db)
if err != nil {
t.Fatalf("failed to create schema: %v", err)
}
-
+
// Verify schema was created by checking if we can query it
_, err = db.ExecContext(ctx, "SELECT 1 FROM record LIMIT 1")
if err != nil {
@@ -48,39 +48,39 @@ func TestCreateSchema(t *testing.T) {
func TestResetRecords(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
-
- db, err := OpenDB(dbPath)
+
+ db, err := OpenDB(context.Background(), dbPath)
if err != nil {
t.Fatalf("failed to open DB: %v", err)
}
defer db.Close()
-
+
ctx := context.Background()
if err := CreateSchema(ctx, db); err != nil {
t.Fatalf("failed to create schema: %v", err)
}
-
+
// Insert a record
- _, err = db.ExecContext(ctx,
+ _, err = db.ExecContext(ctx,
"INSERT INTO record (host, uptime_sec, boot_time, os, os_kernel_name, os_kernel_major) VALUES (?, ?, ?, ?, ?, ?)",
"host1", 1000, 2000, "Linux 5.10", "Linux", "Linux 5...")
if err != nil {
t.Fatalf("failed to insert record: %v", err)
}
-
+
// Reset records
err = ResetRecords(ctx, db)
if err != nil {
t.Fatalf("failed to reset records: %v", err)
}
-
+
// Verify records are empty
var count int
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM record").Scan(&count)
if err != nil {
t.Fatalf("failed to count records: %v", err)
}
-
+
if count != 0 {
t.Errorf("expected 0 records after reset, got %d", count)
}
@@ -89,43 +89,43 @@ func TestResetRecords(t *testing.T) {
func TestImportFromDir(t *testing.T) {
// Create temp directory with test records
tmpDir := t.TempDir()
-
+
// Create a test records file
recordsFile := filepath.Join(tmpDir, "testhost.records")
content := []byte("86400:1000000:Linux 5.10.0-test\n" +
"86400:1000001:Linux 5.10.0-test\n" +
"86400:1000002:Linux 5.10.0-test\n")
-
+
if err := os.WriteFile(recordsFile, content, 0644); err != nil {
t.Fatalf("failed to create test file: %v", err)
}
-
+
// Create database
dbPath := filepath.Join(tmpDir, "test.db")
- db, err := OpenDB(dbPath)
+ db, err := OpenDB(context.Background(), dbPath)
if err != nil {
t.Fatalf("failed to open DB: %v", err)
}
defer db.Close()
-
+
ctx := context.Background()
if err := CreateSchema(ctx, db); err != nil {
t.Fatalf("failed to create schema: %v", err)
}
-
+
// Import records
err = ImportFromDir(ctx, db, tmpDir)
if err != nil {
t.Fatalf("failed to import records: %v", err)
}
-
+
// Verify records were imported
var count int
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM record").Scan(&count)
if err != nil {
t.Fatalf("failed to count records: %v", err)
}
-
+
if count != 3 {
t.Errorf("expected 3 records after import, got %d", count)
}
@@ -134,18 +134,18 @@ func TestImportFromDir(t *testing.T) {
func TestImportFromDirInvalidPath(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
-
- db, err := OpenDB(dbPath)
+
+ db, err := OpenDB(context.Background(), dbPath)
if err != nil {
t.Fatalf("failed to open DB: %v", err)
}
defer db.Close()
-
+
ctx := context.Background()
if err := CreateSchema(ctx, db); err != nil {
t.Fatalf("failed to create schema: %v", err)
}
-
+
// Try to import from non-existent directory
err = ImportFromDir(ctx, db, "/nonexistent/path")
if err == nil {
@@ -156,18 +156,18 @@ func TestImportFromDirInvalidPath(t *testing.T) {
func TestLoadAggregates(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
-
- db, err := OpenDB(dbPath)
+
+ db, err := OpenDB(context.Background(), dbPath)
if err != nil {
t.Fatalf("failed to open DB: %v", err)
}
defer db.Close()
-
+
ctx := context.Background()
if err := CreateSchema(ctx, db); err != nil {
t.Fatalf("failed to create schema: %v", err)
}
-
+
// Insert some records
_, err = db.ExecContext(ctx,
"INSERT INTO record (host, uptime_sec, boot_time, os, os_kernel_name, os_kernel_major) VALUES (?, ?, ?, ?, ?, ?)",
@@ -175,28 +175,28 @@ func TestLoadAggregates(t *testing.T) {
if err != nil {
t.Fatalf("failed to insert: %v", err)
}
-
+
_, err = db.ExecContext(ctx,
"INSERT INTO record (host, uptime_sec, boot_time, os, os_kernel_name, os_kernel_major) VALUES (?, ?, ?, ?, ?, ?)",
"host1", 2000, 3000, "Linux 5.11", "Linux", "Linux 5...")
if err != nil {
t.Fatalf("failed to insert: %v", err)
}
-
+
// Load aggregates
aggs, err := LoadAggregates(ctx, db)
if err != nil {
t.Fatalf("failed to load aggregates: %v", err)
}
-
+
if aggs == nil {
t.Error("expected non-nil aggregates")
}
-
+
if len(aggs.Host) != 1 {
t.Errorf("expected 1 host, got %d", len(aggs.Host))
}
-
+
if host, ok := aggs.Host["host1"]; ok {
if host.Boots != 2 {
t.Errorf("expected 2 boots, got %d", host.Boots)
@@ -210,28 +210,28 @@ func TestLoadAggregates(t *testing.T) {
func TestLoadAggregatesEmptyDB(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
-
- db, err := OpenDB(dbPath)
+
+ db, err := OpenDB(context.Background(), dbPath)
if err != nil {
t.Fatalf("failed to open DB: %v", err)
}
defer db.Close()
-
+
ctx := context.Background()
if err := CreateSchema(ctx, db); err != nil {
t.Fatalf("failed to create schema: %v", err)
}
-
+
// Load from empty database
aggs, err := LoadAggregates(ctx, db)
if err != nil {
t.Fatalf("failed to load aggregates: %v", err)
}
-
+
if aggs == nil {
t.Error("expected non-nil aggregates")
}
-
+
if len(aggs.Host) != 0 {
t.Errorf("expected 0 hosts, got %d", len(aggs.Host))
}
diff --git a/internal/goprecords/integration_test_runner.go b/internal/goprecords/integration_test_runner.go
index 29e5f24..ff4a8c3 100644
--- a/internal/goprecords/integration_test_runner.go
+++ b/internal/goprecords/integration_test_runner.go
@@ -85,7 +85,7 @@ func testImportExport(ctx context.Context, aggregates *Aggregates, fixturesDir s
tmpDB := fixturesDir + "/test_import.db"
os.Remove(tmpDB)
failed := 0
- db, err := OpenDB(tmpDB)
+ db, err := OpenDB(ctx, tmpDB)
if err != nil {
fmt.Printf("FAIL: open tmp db: %v\n", err)
return 1
diff --git a/internal/storage/db.go b/internal/storage/db.go
index 127f45d..ea9d764 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -37,12 +37,12 @@ type Record struct {
KernelMajor string
}
-func Open(path string) (*sql.DB, error) {
+func Open(ctx context.Context, path string) (*sql.DB, error) {
db, err := sql.Open("sqlite", path)
if err != nil {
return nil, err
}
- if _, err := db.Exec("PRAGMA foreign_keys = OFF"); err != nil {
+ if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil {
db.Close()
return nil, err
}
@@ -108,6 +108,11 @@ func LoadRecords(ctx context.Context, db *sql.DB) ([]Record, error) {
defer rows.Close()
var out []Record
for rows.Next() {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
var rec Record
var uptimeSec, bootTime int64
if err := rows.Scan(&rec.Host, &uptimeSec, &bootTime, &rec.OS, &rec.KernelName, &rec.KernelMajor); err != nil {
@@ -139,6 +144,11 @@ func importFile(ctx context.Context, insert *sql.Stmt, path, host string) error
defer f.Close()
sc := bufio.NewScanner(f)
for sc.Scan() {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
rec, ok := parseRecordLine(sc.Text())
if !ok {
continue