diff options
| -rw-r--r-- | internal/authkeys/store.go | 112 | ||||
| -rw-r--r-- | internal/authkeys/store_test.go | 65 | ||||
| -rw-r--r-- | internal/cli/cli.go | 47 | ||||
| -rw-r--r-- | internal/cli/cli_test.go | 27 | ||||
| -rw-r--r-- | internal/daemon/daemon.go | 18 | ||||
| -rw-r--r-- | internal/daemon/daemon_test.go | 122 | ||||
| -rw-r--r-- | internal/daemon/upload.go | 190 |
7 files changed, 576 insertions, 5 deletions
diff --git a/internal/authkeys/store.go b/internal/authkeys/store.go new file mode 100644 index 0000000..9f0a8b2 --- /dev/null +++ b/internal/authkeys/store.go @@ -0,0 +1,112 @@ +package authkeys + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "database/sql" + "encoding/base64" + "fmt" + "path/filepath" + "time" + + _ "modernc.org/sqlite" +) + +const schemaSQL = ` +CREATE TABLE IF NOT EXISTS client_key ( + hostname TEXT NOT NULL PRIMARY KEY, + key_hash BLOB NOT NULL, + created_at INTEGER NOT NULL +); +` + +// Store holds per-client API key hashes in SQLite. +type Store struct { + db *sql.DB +} + +// DefaultPath returns the default auth database path under statsDir. +func DefaultPath(statsDir string) string { + return filepath.Join(statsDir, "goprecords-auth.db") +} + +// OpenStore opens or creates the SQLite auth database at path. +func OpenStore(ctx context.Context, path string) (*Store, error) { + db, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("open auth db: %w", err) + } + if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = OFF"); err != nil { + db.Close() + return nil, fmt.Errorf("pragma: %w", err) + } + return &Store{db: db}, nil +} + +// Close releases the database handle. +func (s *Store) Close() error { + if s == nil || s.db == nil { + return nil + } + return s.db.Close() +} + +// EnsureSchema creates the client_key table if missing. +func (s *Store) EnsureSchema(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, schemaSQL) + if err != nil { + return fmt.Errorf("auth schema: %w", err) + } + return nil +} + +// KeyCount returns how many client keys are stored. When zero, upload auth is not enforced. +func (s *Store) KeyCount(ctx context.Context) (int, error) { + var n int + err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM client_key").Scan(&n) + if err != nil { + return 0, fmt.Errorf("count keys: %w", err) + } + return n, nil +} + +// CreateKey inserts or replaces the key for hostname and returns the plaintext token once. +func (s *Store) CreateKey(ctx context.Context, hostname string) (token string, err error) { + if hostname == "" { + return "", fmt.Errorf("empty hostname") + } + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("random token: %w", err) + } + tok := base64.RawURLEncoding.EncodeToString(raw) + sum := sha256.Sum256([]byte(tok)) + _, err = s.db.ExecContext(ctx, + `INSERT INTO client_key (hostname, key_hash, created_at) VALUES (?, ?, ?) + ON CONFLICT(hostname) DO UPDATE SET key_hash = excluded.key_hash, created_at = excluded.created_at`, + hostname, sum[:], time.Now().Unix(), + ) + if err != nil { + return "", fmt.Errorf("insert key: %w", err) + } + return tok, nil +} + +// Verify checks that token matches the stored hash for hostname. +func (s *Store) Verify(ctx context.Context, hostname, token string) (bool, error) { + var stored []byte + err := s.db.QueryRowContext(ctx, "SELECT key_hash FROM client_key WHERE hostname = ?", hostname).Scan(&stored) + if err == sql.ErrNoRows { + return false, nil + } + if err != nil { + return false, fmt.Errorf("lookup key: %w", err) + } + sum := sha256.Sum256([]byte(token)) + if len(stored) != len(sum) { + return false, nil + } + return subtle.ConstantTimeCompare(stored, sum[:]) == 1, nil +} diff --git a/internal/authkeys/store_test.go b/internal/authkeys/store_test.go new file mode 100644 index 0000000..8331a5b --- /dev/null +++ b/internal/authkeys/store_test.go @@ -0,0 +1,65 @@ +package authkeys + +import ( + "context" + "path/filepath" + "testing" +) + +func TestCreateVerifyReplace(t *testing.T) { + ctx := context.Background() + path := filepath.Join(t.TempDir(), "auth.db") + s, err := OpenStore(ctx, path) + if err != nil { + t.Fatal(err) + } + defer s.Close() + if err := s.EnsureSchema(ctx); err != nil { + t.Fatal(err) + } + n, err := s.KeyCount(ctx) + if err != nil || n != 0 { + t.Fatalf("KeyCount got %d err %v", n, err) + } + tok1, err := s.CreateKey(ctx, "host-a") + if err != nil { + t.Fatal(err) + } + if tok1 == "" { + t.Fatal("empty token") + } + n, err = s.KeyCount(ctx) + if err != nil || n != 1 { + t.Fatalf("KeyCount after create got %d err %v", n, err) + } + ok, err := s.Verify(ctx, "host-a", tok1) + if err != nil || !ok { + t.Fatalf("Verify tok1 got %v ok=%v", err, ok) + } + ok, err = s.Verify(ctx, "host-a", "wrong") + if err != nil || ok { + t.Fatalf("Verify wrong got %v ok=%v", err, ok) + } + tok2, err := s.CreateKey(ctx, "host-a") + if err != nil { + t.Fatal(err) + } + if tok2 == tok1 { + t.Fatal("expected new token after replace") + } + ok, err = s.Verify(ctx, "host-a", tok1) + if err != nil || ok { + t.Fatalf("old token should fail got %v ok=%v", err, ok) + } + ok, err = s.Verify(ctx, "host-a", tok2) + if err != nil || !ok { + t.Fatalf("new token should work got %v ok=%v", err, ok) + } +} + +func TestDefaultPath(t *testing.T) { + p := DefaultPath("/var/stats") + if filepath.Base(p) != "goprecords-auth.db" { + t.Fatalf("got %q", p) + } +} diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 7882758..7dd7cb6 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -9,6 +9,7 @@ import ( "os/signal" "syscall" + "codeberg.org/snonux/goprecords/internal/authkeys" "codeberg.org/snonux/goprecords/internal/daemon" "codeberg.org/snonux/goprecords/internal/goprecords" "codeberg.org/snonux/goprecords/internal/version" @@ -21,6 +22,12 @@ func Execute(args []string) error { fmt.Println(version.Version) return nil } + if len(args) >= 1 && (args[0] == "--create-client-key" || args[0] == "-create-client-key") { + if len(args) < 2 { + return fmt.Errorf("create-client-key: hostname required") + } + return runCreateClientKey(args[1], args[2:]) + } if len(args) > 0 && (args[0] == "-daemon" || args[0] == "--daemon") { return runDaemon(args[1:]) } @@ -138,6 +145,7 @@ func runDaemon(args []string) error { fs.SetOutput(os.Stdout) statsDir := fs.String("stats-dir", os.Getenv("GOPRECORDS_STATS_DIR"), "Uptimed stats directory (required; env GOPRECORDS_STATS_DIR)") listen := fs.String("listen", defaultListenFromEnv(), "TCP listen address (env GOPRECORDS_LISTEN, default :8080)") + authDB := fs.String("auth-db", "", "SQLite file for upload API keys (default: <stats-dir>/goprecords-auth.db)") if err := fs.Parse(args); err != nil { return err } @@ -148,9 +156,46 @@ func runDaemon(args []string) error { } ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - err := daemon.Run(ctx, daemon.Config{StatsDir: *statsDir, Addr: *listen}) + err := daemon.Run(ctx, daemon.Config{StatsDir: *statsDir, Addr: *listen, AuthDB: *authDB}) if err != nil && !errors.Is(err, context.Canceled) { return err } return nil } + +func runCreateClientKey(hostname string, args []string) error { + if hostname == "" { + return fmt.Errorf("create-client-key: hostname required") + } + fs := flag.NewFlagSet("create-client-key", flag.ExitOnError) + fs.SetOutput(os.Stderr) + statsDir := fs.String("stats-dir", "", "Uptimed stats directory (sets default auth-db path)") + authDB := fs.String("auth-db", "", "SQLite file for upload API keys (default: <stats-dir>/goprecords-auth.db)") + if err := fs.Parse(args); err != nil { + return err + } + authPath := *authDB + if authPath == "" { + if *statsDir == "" { + fmt.Fprintln(os.Stderr, "create-client-key: need -stats-dir or -auth-db") + fs.Usage() + return fmt.Errorf("missing -stats-dir or -auth-db") + } + authPath = authkeys.DefaultPath(*statsDir) + } + ctx := context.Background() + store, err := authkeys.OpenStore(ctx, authPath) + if err != nil { + return fmt.Errorf("open auth db: %w", err) + } + defer store.Close() + if err := store.EnsureSchema(ctx); err != nil { + return fmt.Errorf("schema: %w", err) + } + token, err := store.CreateKey(ctx, hostname) + if err != nil { + return fmt.Errorf("create key: %w", err) + } + fmt.Println(token) + return nil +} diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 9fa20e5..ff7b046 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -147,3 +147,30 @@ func TestStableSubcommandsStillRecognized(t *testing.T) { } } } + +func TestCreateClientKeyRequiresHostname(t *testing.T) { + err := Execute([]string{"--create-client-key"}) + if err == nil || !strings.Contains(err.Error(), "hostname") { + t.Fatalf("expected hostname error, got %v", err) + } +} + +func TestCreateClientKeyRequiresStatsOrAuthDB(t *testing.T) { + err := Execute([]string{"--create-client-key", "h1"}) + if err == nil || !strings.Contains(err.Error(), "stats-dir") { + t.Fatalf("expected stats-dir/auth-db error, got %v", err) + } +} + +func TestCreateClientKeyWritesToken(t *testing.T) { + dir := t.TempDir() + out := captureStdout(t, func() { + if err := Execute([]string{"--create-client-key", "mybox", "-stats-dir", dir}); err != nil { + t.Fatal(err) + } + }) + tok := strings.TrimSpace(out) + if len(tok) < 20 { + t.Fatalf("token too short %q", tok) + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 13e7311..8713797 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -10,24 +10,31 @@ import ( "os" "time" + "codeberg.org/snonux/goprecords/internal/authkeys" "codeberg.org/snonux/goprecords/internal/goprecords" ) type Config struct { StatsDir string Addr string + AuthDB string LogOutput io.Writer } -func routes(statsDir string) http.Handler { +func routes(statsDir string, store *authkeys.Store) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/health", health) mux.HandleFunc("/report", report(statsDir)) + mux.Handle("/upload/", uploadHandler(statsDir, store)) return mux } func Handler(statsDir string) http.Handler { - return routes(statsDir) + store, err := openAuthStore(context.Background(), statsDir, "") + if err != nil { + panic(err) + } + return routes(statsDir, store) } func logWriter(cfg Config) io.Writer { @@ -83,9 +90,14 @@ func Run(ctx context.Context, cfg Config) error { } w := logWriter(cfg) log, textHandler := newDaemonLogger(w) + store, err := openAuthStore(ctx, cfg.StatsDir, cfg.AuthDB) + if err != nil { + return fmt.Errorf("auth db: %w", err) + } + defer store.Close() srv := &http.Server{ Addr: cfg.Addr, - Handler: withAccessLog(log, routes(cfg.StatsDir)), + Handler: withAccessLog(log, routes(cfg.StatsDir, store)), ErrorLog: slog.NewLogLogger(textHandler, slog.LevelError), } log.Info("daemon_listen", "addr", cfg.Addr) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 87b3dd8..d4133f2 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "os" "path/filepath" "strings" "testing" @@ -182,7 +183,13 @@ func TestAccessLogLineToWriter(t *testing.T) { var buf bytes.Buffer h := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo}) log := slog.New(h) - srv := httptest.NewServer(withAccessLog(log, routes(t.TempDir()))) + statsDir := t.TempDir() + store, err := openAuthStore(context.Background(), statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + srv := httptest.NewServer(withAccessLog(log, routes(statsDir, store))) defer srv.Close() res, err := http.Get(srv.URL + "/health") if err != nil { @@ -197,3 +204,116 @@ func TestAccessLogLineToWriter(t *testing.T) { t.Fatalf("expected path and status in log, got %q", body) } } + +func TestUploadOpenWhenNoKeys(t *testing.T) { + statsDir := t.TempDir() + srv := httptest.NewServer(Handler(statsDir)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/txt", strings.NewReader("hello")) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusNoContent { + t.Fatalf("status %d", res.StatusCode) + } + b, err := os.ReadFile(filepath.Join(statsDir, "myhost.txt")) + if err != nil { + t.Fatal(err) + } + if string(b) != "hello" { + t.Fatalf("file %q", b) + } +} + +func TestUploadRequiresBearerWhenKeysExist(t *testing.T) { + statsDir := t.TempDir() + ctx := context.Background() + store, err := openAuthStore(ctx, statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + if _, err := store.CreateKey(ctx, "myhost"); err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes(statsDir, store)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/txt", strings.NewReader("x")) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusUnauthorized { + t.Fatalf("status %d want 401", res.StatusCode) + } +} + +func TestUploadWithValidBearer(t *testing.T) { + statsDir := t.TempDir() + ctx := context.Background() + store, err := openAuthStore(ctx, statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + tok, err := store.CreateKey(ctx, "myhost") + if err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes(statsDir, store)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/os.txt", strings.NewReader("os")) + req.Header.Set("Authorization", "Bearer "+tok) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusNoContent { + t.Fatalf("status %d", res.StatusCode) + } +} + +func TestUploadWrongHostForbidden(t *testing.T) { + statsDir := t.TempDir() + ctx := context.Background() + store, err := openAuthStore(ctx, statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + tok, err := store.CreateKey(ctx, "myhost") + if err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes(statsDir, store)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/other/txt", strings.NewReader("x")) + req.Header.Set("Authorization", "Bearer "+tok) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusForbidden { + t.Fatalf("status %d want 403", res.StatusCode) + } +} + +func TestUploadBadKind(t *testing.T) { + statsDir := t.TempDir() + srv := httptest.NewServer(Handler(statsDir)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/nope", strings.NewReader("x")) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusBadRequest { + t.Fatalf("status %d", res.StatusCode) + } +} diff --git a/internal/daemon/upload.go b/internal/daemon/upload.go new file mode 100644 index 0000000..be6c6d9 --- /dev/null +++ b/internal/daemon/upload.go @@ -0,0 +1,190 @@ +package daemon + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "codeberg.org/snonux/goprecords/internal/authkeys" +) + +const maxUploadBytes = 8 << 20 + +var uploadKinds = map[string]string{ + "txt": ".txt", + "cur.txt": ".cur.txt", + "records": ".records", + "os.txt": ".os.txt", + "cpuinfo.txt": ".cpuinfo.txt", +} + +func uploadHandler(statsDir string, store *authkeys.Store) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + host, kind, ok := parseUploadPath(r.URL.Path) + if !ok { + http.Error(w, "bad path", http.StatusBadRequest) + return + } + ext, ok := uploadKinds[kind] + if !ok { + http.Error(w, "unknown file kind", http.StatusBadRequest) + return + } + ctx := r.Context() + nKeys, err := store.KeyCount(ctx) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if nKeys > 0 { + tok, ok := parseBearer(r.Header.Get("Authorization")) + if !ok || tok == "" { + w.Header().Set("WWW-Authenticate", `Bearer realm="upload"`) + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + valid, err := store.Verify(ctx, host, tok) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !valid { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + } + rel := host + ext + target := filepath.Join(statsDir, rel) + if !fileUnderDir(statsDir, target) { + http.Error(w, "bad path", http.StatusBadRequest) + return + } + if err := writeUploadBody(target, r.Body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusNoContent) + }) +} + +func parseUploadPath(path string) (host, kind string, ok bool) { + const prefix = "/upload/" + if !strings.HasPrefix(path, prefix) { + return "", "", false + } + rest := strings.TrimPrefix(path, prefix) + if rest == "" || strings.Contains(rest, "..") { + return "", "", false + } + i := strings.IndexByte(rest, '/') + if i <= 0 || i >= len(rest)-1 { + return "", "", false + } + host = rest[:i] + kind = rest[i+1:] + if !safeHostSegment(host) || strings.Contains(kind, "/") { + return "", "", false + } + return host, kind, true +} + +func safeHostSegment(s string) bool { + if s == "" || len(s) > 253 { + return false + } + for _, c := range s { + switch { + case c >= 'a' && c <= 'z': + case c >= 'A' && c <= 'Z': + case c >= '0' && c <= '9': + case c == '.' || c == '-' || c == '_': + default: + return false + } + } + return true +} + +func parseBearer(h string) (token string, ok bool) { + h = strings.TrimSpace(h) + const prefix = "Bearer " + if len(h) < len(prefix) { + return "", false + } + if !strings.EqualFold(h[:len(prefix)], prefix) { + return "", false + } + t := strings.TrimSpace(h[len(prefix):]) + return t, t != "" +} + +func fileUnderDir(dir, file string) bool { + absDir, err := filepath.Abs(dir) + if err != nil { + return false + } + absFile, err := filepath.Abs(file) + if err != nil { + return false + } + rel, err := filepath.Rel(absDir, absFile) + if err != nil { + return false + } + if rel == "." { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + +func writeUploadBody(path string, body io.Reader) error { + tmp := path + ".tmp" + f, err := os.Create(tmp) + if err != nil { + return fmt.Errorf("create temp: %w", err) + } + lr := &io.LimitedReader{R: body, N: maxUploadBytes + 1} + n, err := io.Copy(f, lr) + if err != nil { + f.Close() + os.Remove(tmp) + return fmt.Errorf("write: %w", err) + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return fmt.Errorf("close temp: %w", err) + } + if n > maxUploadBytes { + os.Remove(tmp) + return fmt.Errorf("body too large") + } + if err := os.Rename(tmp, path); err != nil { + os.Remove(tmp) + return fmt.Errorf("rename: %w", err) + } + return nil +} + +func openAuthStore(ctx context.Context, statsDir, authDB string) (*authkeys.Store, error) { + path := authDB + if path == "" { + path = authkeys.DefaultPath(statsDir) + } + s, err := authkeys.OpenStore(ctx, path) + if err != nil { + return nil, err + } + if err := s.EnsureSchema(ctx); err != nil { + s.Close() + return nil, err + } + return s, nil +} |
