diff options
Diffstat (limited to 'internal/daemon')
| -rw-r--r-- | internal/daemon/daemon.go | 18 | ||||
| -rw-r--r-- | internal/daemon/daemon_test.go | 122 | ||||
| -rw-r--r-- | internal/daemon/upload.go | 190 |
3 files changed, 326 insertions, 4 deletions
diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 13e7311..8713797 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -10,24 +10,31 @@ import ( "os" "time" + "codeberg.org/snonux/goprecords/internal/authkeys" "codeberg.org/snonux/goprecords/internal/goprecords" ) type Config struct { StatsDir string Addr string + AuthDB string LogOutput io.Writer } -func routes(statsDir string) http.Handler { +func routes(statsDir string, store *authkeys.Store) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/health", health) mux.HandleFunc("/report", report(statsDir)) + mux.Handle("/upload/", uploadHandler(statsDir, store)) return mux } func Handler(statsDir string) http.Handler { - return routes(statsDir) + store, err := openAuthStore(context.Background(), statsDir, "") + if err != nil { + panic(err) + } + return routes(statsDir, store) } func logWriter(cfg Config) io.Writer { @@ -83,9 +90,14 @@ func Run(ctx context.Context, cfg Config) error { } w := logWriter(cfg) log, textHandler := newDaemonLogger(w) + store, err := openAuthStore(ctx, cfg.StatsDir, cfg.AuthDB) + if err != nil { + return fmt.Errorf("auth db: %w", err) + } + defer store.Close() srv := &http.Server{ Addr: cfg.Addr, - Handler: withAccessLog(log, routes(cfg.StatsDir)), + Handler: withAccessLog(log, routes(cfg.StatsDir, store)), ErrorLog: slog.NewLogLogger(textHandler, slog.LevelError), } log.Info("daemon_listen", "addr", cfg.Addr) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 87b3dd8..d4133f2 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "os" "path/filepath" "strings" "testing" @@ -182,7 +183,13 @@ func TestAccessLogLineToWriter(t *testing.T) { var buf bytes.Buffer h := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo}) log := slog.New(h) - srv := httptest.NewServer(withAccessLog(log, routes(t.TempDir()))) + statsDir := t.TempDir() + store, err := openAuthStore(context.Background(), statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + srv := httptest.NewServer(withAccessLog(log, routes(statsDir, store))) defer srv.Close() res, err := http.Get(srv.URL + "/health") if err != nil { @@ -197,3 +204,116 @@ func TestAccessLogLineToWriter(t *testing.T) { t.Fatalf("expected path and status in log, got %q", body) } } + +func TestUploadOpenWhenNoKeys(t *testing.T) { + statsDir := t.TempDir() + srv := httptest.NewServer(Handler(statsDir)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/txt", strings.NewReader("hello")) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusNoContent { + t.Fatalf("status %d", res.StatusCode) + } + b, err := os.ReadFile(filepath.Join(statsDir, "myhost.txt")) + if err != nil { + t.Fatal(err) + } + if string(b) != "hello" { + t.Fatalf("file %q", b) + } +} + +func TestUploadRequiresBearerWhenKeysExist(t *testing.T) { + statsDir := t.TempDir() + ctx := context.Background() + store, err := openAuthStore(ctx, statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + if _, err := store.CreateKey(ctx, "myhost"); err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes(statsDir, store)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/txt", strings.NewReader("x")) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusUnauthorized { + t.Fatalf("status %d want 401", res.StatusCode) + } +} + +func TestUploadWithValidBearer(t *testing.T) { + statsDir := t.TempDir() + ctx := context.Background() + store, err := openAuthStore(ctx, statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + tok, err := store.CreateKey(ctx, "myhost") + if err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes(statsDir, store)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/os.txt", strings.NewReader("os")) + req.Header.Set("Authorization", "Bearer "+tok) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusNoContent { + t.Fatalf("status %d", res.StatusCode) + } +} + +func TestUploadWrongHostForbidden(t *testing.T) { + statsDir := t.TempDir() + ctx := context.Background() + store, err := openAuthStore(ctx, statsDir, "") + if err != nil { + t.Fatal(err) + } + defer store.Close() + tok, err := store.CreateKey(ctx, "myhost") + if err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes(statsDir, store)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/other/txt", strings.NewReader("x")) + req.Header.Set("Authorization", "Bearer "+tok) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusForbidden { + t.Fatalf("status %d want 403", res.StatusCode) + } +} + +func TestUploadBadKind(t *testing.T) { + statsDir := t.TempDir() + srv := httptest.NewServer(Handler(statsDir)) + defer srv.Close() + req, _ := http.NewRequest(http.MethodPut, srv.URL+"/upload/myhost/nope", strings.NewReader("x")) + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusBadRequest { + t.Fatalf("status %d", res.StatusCode) + } +} diff --git a/internal/daemon/upload.go b/internal/daemon/upload.go new file mode 100644 index 0000000..be6c6d9 --- /dev/null +++ b/internal/daemon/upload.go @@ -0,0 +1,190 @@ +package daemon + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "codeberg.org/snonux/goprecords/internal/authkeys" +) + +const maxUploadBytes = 8 << 20 + +var uploadKinds = map[string]string{ + "txt": ".txt", + "cur.txt": ".cur.txt", + "records": ".records", + "os.txt": ".os.txt", + "cpuinfo.txt": ".cpuinfo.txt", +} + +func uploadHandler(statsDir string, store *authkeys.Store) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + host, kind, ok := parseUploadPath(r.URL.Path) + if !ok { + http.Error(w, "bad path", http.StatusBadRequest) + return + } + ext, ok := uploadKinds[kind] + if !ok { + http.Error(w, "unknown file kind", http.StatusBadRequest) + return + } + ctx := r.Context() + nKeys, err := store.KeyCount(ctx) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if nKeys > 0 { + tok, ok := parseBearer(r.Header.Get("Authorization")) + if !ok || tok == "" { + w.Header().Set("WWW-Authenticate", `Bearer realm="upload"`) + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + valid, err := store.Verify(ctx, host, tok) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !valid { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + } + rel := host + ext + target := filepath.Join(statsDir, rel) + if !fileUnderDir(statsDir, target) { + http.Error(w, "bad path", http.StatusBadRequest) + return + } + if err := writeUploadBody(target, r.Body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusNoContent) + }) +} + +func parseUploadPath(path string) (host, kind string, ok bool) { + const prefix = "/upload/" + if !strings.HasPrefix(path, prefix) { + return "", "", false + } + rest := strings.TrimPrefix(path, prefix) + if rest == "" || strings.Contains(rest, "..") { + return "", "", false + } + i := strings.IndexByte(rest, '/') + if i <= 0 || i >= len(rest)-1 { + return "", "", false + } + host = rest[:i] + kind = rest[i+1:] + if !safeHostSegment(host) || strings.Contains(kind, "/") { + return "", "", false + } + return host, kind, true +} + +func safeHostSegment(s string) bool { + if s == "" || len(s) > 253 { + return false + } + for _, c := range s { + switch { + case c >= 'a' && c <= 'z': + case c >= 'A' && c <= 'Z': + case c >= '0' && c <= '9': + case c == '.' || c == '-' || c == '_': + default: + return false + } + } + return true +} + +func parseBearer(h string) (token string, ok bool) { + h = strings.TrimSpace(h) + const prefix = "Bearer " + if len(h) < len(prefix) { + return "", false + } + if !strings.EqualFold(h[:len(prefix)], prefix) { + return "", false + } + t := strings.TrimSpace(h[len(prefix):]) + return t, t != "" +} + +func fileUnderDir(dir, file string) bool { + absDir, err := filepath.Abs(dir) + if err != nil { + return false + } + absFile, err := filepath.Abs(file) + if err != nil { + return false + } + rel, err := filepath.Rel(absDir, absFile) + if err != nil { + return false + } + if rel == "." { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + +func writeUploadBody(path string, body io.Reader) error { + tmp := path + ".tmp" + f, err := os.Create(tmp) + if err != nil { + return fmt.Errorf("create temp: %w", err) + } + lr := &io.LimitedReader{R: body, N: maxUploadBytes + 1} + n, err := io.Copy(f, lr) + if err != nil { + f.Close() + os.Remove(tmp) + return fmt.Errorf("write: %w", err) + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return fmt.Errorf("close temp: %w", err) + } + if n > maxUploadBytes { + os.Remove(tmp) + return fmt.Errorf("body too large") + } + if err := os.Rename(tmp, path); err != nil { + os.Remove(tmp) + return fmt.Errorf("rename: %w", err) + } + return nil +} + +func openAuthStore(ctx context.Context, statsDir, authDB string) (*authkeys.Store, error) { + path := authDB + if path == "" { + path = authkeys.DefaultPath(statsDir) + } + s, err := authkeys.OpenStore(ctx, path) + if err != nil { + return nil, err + } + if err := s.EnsureSchema(ctx); err != nil { + s.Close() + return nil, err + } + return s, nil +} |
