summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-03 10:03:31 +0200
committerPaul Buetow <paul@buetow.org>2026-03-03 10:03:31 +0200
commitd0436c0040732592db861c6eebbf05a1d04e09f1 (patch)
tree24189e8fa9178201b6abe63c0365fcc637568bd1
parenta426a2f9f33b1125a05d3aac29e7b98afdc36a99 (diff)
feat(ssh-server): add in-memory auth key store
-rw-r--r--internal/ssh/server/authkeystore.go183
-rw-r--r--internal/ssh/server/authkeystore_test.go178
2 files changed, 361 insertions, 0 deletions
diff --git a/internal/ssh/server/authkeystore.go b/internal/ssh/server/authkeystore.go
new file mode 100644
index 0000000..45855ff
--- /dev/null
+++ b/internal/ssh/server/authkeystore.go
@@ -0,0 +1,183 @@
+package server
+
+import (
+ "sync"
+ "time"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+const (
+ defaultAuthKeyTTL = 24 * time.Hour
+ defaultAuthKeyMaxPerUser = 5
+)
+
+type authKeyEntry struct {
+ pubKey gossh.PublicKey
+ registeredAt time.Time
+}
+
+// AuthKeyStore is an in-memory, per-user cache of SSH public keys.
+type AuthKeyStore struct {
+ mu sync.RWMutex
+ keysByUser map[string][]authKeyEntry
+ ttl time.Duration
+ maxKeysPerUser int
+ now func() time.Time
+}
+
+// NewAuthKeyStore builds a thread-safe auth key store.
+func NewAuthKeyStore(ttl time.Duration, maxKeysPerUser int) *AuthKeyStore {
+ return newAuthKeyStoreWithClock(ttl, maxKeysPerUser, time.Now)
+}
+
+func newAuthKeyStoreWithClock(ttl time.Duration, maxKeysPerUser int,
+ nowFn func() time.Time) *AuthKeyStore {
+
+ if ttl <= 0 {
+ ttl = defaultAuthKeyTTL
+ }
+ if maxKeysPerUser <= 0 {
+ maxKeysPerUser = defaultAuthKeyMaxPerUser
+ }
+ if nowFn == nil {
+ nowFn = time.Now
+ }
+
+ return &AuthKeyStore{
+ keysByUser: make(map[string][]authKeyEntry),
+ ttl: ttl,
+ maxKeysPerUser: maxKeysPerUser,
+ now: nowFn,
+ }
+}
+
+// Add stores or refreshes a key for a user.
+func (s *AuthKeyStore) Add(user string, pubKey gossh.PublicKey) {
+ if user == "" || pubKey == nil {
+ return
+ }
+
+ now := s.now()
+ offeredKey := marshalKey(pubKey)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ userEntries := s.pruneExpiredLocked(user, now)
+
+ newEntries := make([]authKeyEntry, 0, len(userEntries)+1)
+ for _, entry := range userEntries {
+ if marshalKey(entry.pubKey) == offeredKey {
+ continue
+ }
+ newEntries = append(newEntries, entry)
+ }
+
+ newEntries = append(newEntries, authKeyEntry{
+ pubKey: pubKey,
+ registeredAt: now,
+ })
+ if len(newEntries) > s.maxKeysPerUser {
+ newEntries = newEntries[len(newEntries)-s.maxKeysPerUser:]
+ }
+
+ s.keysByUser[user] = newEntries
+}
+
+// Has returns true if a non-expired key exists for a user.
+func (s *AuthKeyStore) Has(user string, pubKey gossh.PublicKey) bool {
+ if user == "" || pubKey == nil {
+ return false
+ }
+
+ now := s.now()
+ offeredKey := marshalKey(pubKey)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ userEntries := s.pruneExpiredLocked(user, now)
+ for _, entry := range userEntries {
+ if marshalKey(entry.pubKey) == offeredKey {
+ return true
+ }
+ }
+
+ return false
+}
+
+// Remove deletes a key for a user if it exists.
+func (s *AuthKeyStore) Remove(user string, pubKey gossh.PublicKey) {
+ if user == "" || pubKey == nil {
+ return
+ }
+
+ offeredKey := marshalKey(pubKey)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ userEntries := s.pruneExpiredLocked(user, s.now())
+ if len(userEntries) == 0 {
+ return
+ }
+
+ remaining := make([]authKeyEntry, 0, len(userEntries))
+ for _, entry := range userEntries {
+ if marshalKey(entry.pubKey) == offeredKey {
+ continue
+ }
+ remaining = append(remaining, entry)
+ }
+
+ if len(remaining) == 0 {
+ delete(s.keysByUser, user)
+ return
+ }
+
+ s.keysByUser[user] = remaining
+}
+
+func (s *AuthKeyStore) pruneExpiredLocked(user string, now time.Time) []authKeyEntry {
+ userEntries, ok := s.keysByUser[user]
+ if !ok || len(userEntries) == 0 {
+ delete(s.keysByUser, user)
+ return nil
+ }
+
+ hasExpiredEntries := false
+ for _, entry := range userEntries {
+ if s.expired(entry, now) {
+ hasExpiredEntries = true
+ break
+ }
+ }
+ if !hasExpiredEntries {
+ return userEntries
+ }
+
+ activeEntries := make([]authKeyEntry, 0, len(userEntries))
+ for _, entry := range userEntries {
+ if s.expired(entry, now) {
+ continue
+ }
+ activeEntries = append(activeEntries, entry)
+ }
+
+ if len(activeEntries) == 0 {
+ delete(s.keysByUser, user)
+ return nil
+ }
+
+ s.keysByUser[user] = activeEntries
+ return activeEntries
+}
+
+func (s *AuthKeyStore) expired(entry authKeyEntry, now time.Time) bool {
+ return !entry.registeredAt.Add(s.ttl).After(now)
+}
+
+func marshalKey(pubKey gossh.PublicKey) string {
+ return string(pubKey.Marshal())
+}
diff --git a/internal/ssh/server/authkeystore_test.go b/internal/ssh/server/authkeystore_test.go
new file mode 100644
index 0000000..056db7b
--- /dev/null
+++ b/internal/ssh/server/authkeystore_test.go
@@ -0,0 +1,178 @@
+package server
+
+import (
+ "crypto/ed25519"
+ "sync"
+ "testing"
+ "time"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+func TestAuthKeyStoreAddHasRemove(t *testing.T) {
+ store := NewAuthKeyStore(time.Hour, 5)
+ key := testPublicKey(t, 1)
+
+ if store.Has("alice", key) {
+ t.Fatalf("Store should not contain key before add")
+ }
+
+ store.Add("alice", key)
+ if !store.Has("alice", key) {
+ t.Fatalf("Store should contain key after add")
+ }
+
+ store.Remove("alice", key)
+ if store.Has("alice", key) {
+ t.Fatalf("Store should not contain key after remove")
+ }
+}
+
+func TestAuthKeyStoreHasExpiresKeysLazily(t *testing.T) {
+ now := time.Date(2026, 3, 3, 10, 0, 0, 0, time.UTC)
+ store := newAuthKeyStoreWithClock(10*time.Second, 5, func() time.Time { return now })
+ key := testPublicKey(t, 2)
+
+ store.Add("alice", key)
+ if !store.Has("alice", key) {
+ t.Fatalf("Store should contain fresh key")
+ }
+
+ now = now.Add(11 * time.Second)
+ if store.Has("alice", key) {
+ t.Fatalf("Store should expire key when ttl is exceeded")
+ }
+
+ store.mu.RLock()
+ defer store.mu.RUnlock()
+ if len(store.keysByUser["alice"]) != 0 {
+ t.Fatalf("Expired entries should be removed on Has call")
+ }
+}
+
+func TestAuthKeyStoreEnforcesPerUserKeyLimit(t *testing.T) {
+ now := time.Date(2026, 3, 3, 10, 0, 0, 0, time.UTC)
+ store := newAuthKeyStoreWithClock(time.Hour, 2, func() time.Time { return now })
+
+ keyOne := testPublicKey(t, 3)
+ keyTwo := testPublicKey(t, 4)
+ keyThree := testPublicKey(t, 5)
+
+ store.Add("alice", keyOne)
+ now = now.Add(1 * time.Second)
+ store.Add("alice", keyTwo)
+ now = now.Add(1 * time.Second)
+ store.Add("alice", keyThree)
+
+ if store.Has("alice", keyOne) {
+ t.Fatalf("Oldest key should be evicted once max key limit is reached")
+ }
+ if !store.Has("alice", keyTwo) {
+ t.Fatalf("Second key should remain in store")
+ }
+ if !store.Has("alice", keyThree) {
+ t.Fatalf("Newest key should remain in store")
+ }
+}
+
+func TestAuthKeyStoreAddRefreshesExistingKey(t *testing.T) {
+ now := time.Date(2026, 3, 3, 10, 0, 0, 0, time.UTC)
+ store := newAuthKeyStoreWithClock(10*time.Second, 5, func() time.Time { return now })
+ key := testPublicKey(t, 6)
+
+ store.Add("alice", key)
+ now = now.Add(9 * time.Second)
+ store.Add("alice", key)
+
+ now = now.Add(5 * time.Second)
+ if !store.Has("alice", key) {
+ t.Fatalf("Key should stay valid after it is refreshed")
+ }
+
+ now = now.Add(6 * time.Second)
+ if store.Has("alice", key) {
+ t.Fatalf("Refreshed key should expire once ttl is exceeded from latest add")
+ }
+}
+
+func TestAuthKeyStoreUserIsolation(t *testing.T) {
+ store := NewAuthKeyStore(time.Hour, 5)
+ key := testPublicKey(t, 7)
+
+ store.Add("alice", key)
+ if store.Has("bob", key) {
+ t.Fatalf("Key lookup must be isolated by user")
+ }
+}
+
+func TestAuthKeyStoreIgnoresInvalidInput(t *testing.T) {
+ store := NewAuthKeyStore(time.Hour, 5)
+ key := testPublicKey(t, 8)
+
+ store.Add("", key)
+ store.Add("alice", nil)
+ store.Remove("", key)
+ store.Remove("alice", nil)
+
+ if store.Has("", key) {
+ t.Fatalf("Empty user should not match")
+ }
+ if store.Has("alice", nil) {
+ t.Fatalf("Nil key should not match")
+ }
+}
+
+func TestAuthKeyStoreConcurrentAccess(t *testing.T) {
+ store := NewAuthKeyStore(time.Hour, 5)
+ users := []string{"alice", "bob", "carol"}
+ keys := []gossh.PublicKey{
+ testPublicKey(t, 11),
+ testPublicKey(t, 12),
+ testPublicKey(t, 13),
+ testPublicKey(t, 14),
+ }
+
+ var wg sync.WaitGroup
+ for worker := 0; worker < 32; worker++ {
+ wg.Add(1)
+ go func(workerID int) {
+ defer wg.Done()
+
+ user := users[workerID%len(users)]
+ for i := 0; i < 200; i++ {
+ key := keys[(workerID+i)%len(keys)]
+ store.Add(user, key)
+ _ = store.Has(user, key)
+ if i%3 == 0 {
+ store.Remove(user, key)
+ }
+ }
+ }(worker)
+ }
+ wg.Wait()
+
+ store.mu.RLock()
+ defer store.mu.RUnlock()
+ for user, userEntries := range store.keysByUser {
+ if len(userEntries) > store.maxKeysPerUser {
+ t.Fatalf("User %s exceeded max key limit: %d", user, len(userEntries))
+ }
+ }
+}
+
+func testPublicKey(t *testing.T, seedByte byte) gossh.PublicKey {
+ t.Helper()
+
+ seed := make([]byte, ed25519.SeedSize)
+ for i := range seed {
+ seed[i] = seedByte
+ }
+
+ privateKey := ed25519.NewKeyFromSeed(seed)
+ publicKey, err := gossh.NewPublicKey(privateKey.Public())
+ if err != nil {
+ t.Fatalf("Unable to build ssh public key: %s", err.Error())
+ }
+
+ return publicKey
+}