summaryrefslogtreecommitdiff
path: root/internal/storage/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/storage/db.go')
-rw-r--r--internal/storage/db.go18
1 files changed, 12 insertions, 6 deletions
diff --git a/internal/storage/db.go b/internal/storage/db.go
index fe789d6..45eb71a 100644
--- a/internal/storage/db.go
+++ b/internal/storage/db.go
@@ -5,6 +5,7 @@ import (
"context"
"database/sql"
"fmt"
+ "io/fs"
"os"
"codeberg.org/snonux/goprecords/internal/recordline"
@@ -63,10 +64,15 @@ func ResetRecords(ctx context.Context, db *sql.DB) error {
}
func ImportFromDir(ctx context.Context, db *sql.DB, statsDir string) error {
+ return ImportFromFS(ctx, db, os.DirFS(statsDir))
+}
+
+// ImportFromFS reads non-empty .records files from the root of fsys into the database.
+func ImportFromFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if err := ResetRecords(ctx, db); err != nil {
return fmt.Errorf("reset records: %w", err)
}
- files, err := recordsdir.ListNonEmptyFiles(statsDir)
+ files, err := recordsdir.ListNonEmptyFilesFS(fsys, ".")
if err != nil {
return fmt.Errorf("read dir: %w", err)
}
@@ -81,7 +87,7 @@ func ImportFromDir(ctx context.Context, db *sql.DB, statsDir string) error {
}
defer insert.Close()
for _, f := range files {
- if err := importFile(ctx, insert, f.Path, f.Host); err != nil {
+ if err := importFile(ctx, insert, fsys, f.Path, f.Host); err != nil {
return err
}
}
@@ -119,10 +125,10 @@ func LoadRecords(ctx context.Context, db *sql.DB) ([]Record, error) {
return out, nil
}
-func importFile(ctx context.Context, insert *sql.Stmt, path, host string) error {
- f, err := os.Open(path)
+func importFile(ctx context.Context, insert *sql.Stmt, fsys fs.FS, relPath, host string) error {
+ f, err := fsys.Open(relPath)
if err != nil {
- return fmt.Errorf("open %s: %w", path, err)
+ return fmt.Errorf("open %s: %w", relPath, err)
}
defer f.Close()
sc := bufio.NewScanner(f)
@@ -141,7 +147,7 @@ func importFile(ctx context.Context, insert *sql.Stmt, path, host string) error
}
}
if err := sc.Err(); err != nil {
- return fmt.Errorf("scan %s: %w", path, err)
+ return fmt.Errorf("scan %s: %w", relPath, err)
}
return nil
}