diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-03 10:03:31 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-03 10:03:31 +0200 |
| commit | d0436c0040732592db861c6eebbf05a1d04e09f1 (patch) | |
| tree | 24189e8fa9178201b6abe63c0365fcc637568bd1 | |
| parent | a426a2f9f33b1125a05d3aac29e7b98afdc36a99 (diff) | |
feat(ssh-server): add in-memory auth key store
| -rw-r--r-- | internal/ssh/server/authkeystore.go | 183 | ||||
| -rw-r--r-- | internal/ssh/server/authkeystore_test.go | 178 |
2 files changed, 361 insertions, 0 deletions
diff --git a/internal/ssh/server/authkeystore.go b/internal/ssh/server/authkeystore.go new file mode 100644 index 0000000..45855ff --- /dev/null +++ b/internal/ssh/server/authkeystore.go @@ -0,0 +1,183 @@ +package server + +import ( + "sync" + "time" + + gossh "golang.org/x/crypto/ssh" +) + +const ( + defaultAuthKeyTTL = 24 * time.Hour + defaultAuthKeyMaxPerUser = 5 +) + +type authKeyEntry struct { + pubKey gossh.PublicKey + registeredAt time.Time +} + +// AuthKeyStore is an in-memory, per-user cache of SSH public keys. +type AuthKeyStore struct { + mu sync.RWMutex + keysByUser map[string][]authKeyEntry + ttl time.Duration + maxKeysPerUser int + now func() time.Time +} + +// NewAuthKeyStore builds a thread-safe auth key store. +func NewAuthKeyStore(ttl time.Duration, maxKeysPerUser int) *AuthKeyStore { + return newAuthKeyStoreWithClock(ttl, maxKeysPerUser, time.Now) +} + +func newAuthKeyStoreWithClock(ttl time.Duration, maxKeysPerUser int, + nowFn func() time.Time) *AuthKeyStore { + + if ttl <= 0 { + ttl = defaultAuthKeyTTL + } + if maxKeysPerUser <= 0 { + maxKeysPerUser = defaultAuthKeyMaxPerUser + } + if nowFn == nil { + nowFn = time.Now + } + + return &AuthKeyStore{ + keysByUser: make(map[string][]authKeyEntry), + ttl: ttl, + maxKeysPerUser: maxKeysPerUser, + now: nowFn, + } +} + +// Add stores or refreshes a key for a user. +func (s *AuthKeyStore) Add(user string, pubKey gossh.PublicKey) { + if user == "" || pubKey == nil { + return + } + + now := s.now() + offeredKey := marshalKey(pubKey) + + s.mu.Lock() + defer s.mu.Unlock() + + userEntries := s.pruneExpiredLocked(user, now) + + newEntries := make([]authKeyEntry, 0, len(userEntries)+1) + for _, entry := range userEntries { + if marshalKey(entry.pubKey) == offeredKey { + continue + } + newEntries = append(newEntries, entry) + } + + newEntries = append(newEntries, authKeyEntry{ + pubKey: pubKey, + registeredAt: now, + }) + if len(newEntries) > s.maxKeysPerUser { + newEntries = newEntries[len(newEntries)-s.maxKeysPerUser:] + } + + s.keysByUser[user] = newEntries +} + +// Has returns true if a non-expired key exists for a user. +func (s *AuthKeyStore) Has(user string, pubKey gossh.PublicKey) bool { + if user == "" || pubKey == nil { + return false + } + + now := s.now() + offeredKey := marshalKey(pubKey) + + s.mu.Lock() + defer s.mu.Unlock() + + userEntries := s.pruneExpiredLocked(user, now) + for _, entry := range userEntries { + if marshalKey(entry.pubKey) == offeredKey { + return true + } + } + + return false +} + +// Remove deletes a key for a user if it exists. +func (s *AuthKeyStore) Remove(user string, pubKey gossh.PublicKey) { + if user == "" || pubKey == nil { + return + } + + offeredKey := marshalKey(pubKey) + + s.mu.Lock() + defer s.mu.Unlock() + + userEntries := s.pruneExpiredLocked(user, s.now()) + if len(userEntries) == 0 { + return + } + + remaining := make([]authKeyEntry, 0, len(userEntries)) + for _, entry := range userEntries { + if marshalKey(entry.pubKey) == offeredKey { + continue + } + remaining = append(remaining, entry) + } + + if len(remaining) == 0 { + delete(s.keysByUser, user) + return + } + + s.keysByUser[user] = remaining +} + +func (s *AuthKeyStore) pruneExpiredLocked(user string, now time.Time) []authKeyEntry { + userEntries, ok := s.keysByUser[user] + if !ok || len(userEntries) == 0 { + delete(s.keysByUser, user) + return nil + } + + hasExpiredEntries := false + for _, entry := range userEntries { + if s.expired(entry, now) { + hasExpiredEntries = true + break + } + } + if !hasExpiredEntries { + return userEntries + } + + activeEntries := make([]authKeyEntry, 0, len(userEntries)) + for _, entry := range userEntries { + if s.expired(entry, now) { + continue + } + activeEntries = append(activeEntries, entry) + } + + if len(activeEntries) == 0 { + delete(s.keysByUser, user) + return nil + } + + s.keysByUser[user] = activeEntries + return activeEntries +} + +func (s *AuthKeyStore) expired(entry authKeyEntry, now time.Time) bool { + return !entry.registeredAt.Add(s.ttl).After(now) +} + +func marshalKey(pubKey gossh.PublicKey) string { + return string(pubKey.Marshal()) +} diff --git a/internal/ssh/server/authkeystore_test.go b/internal/ssh/server/authkeystore_test.go new file mode 100644 index 0000000..056db7b --- /dev/null +++ b/internal/ssh/server/authkeystore_test.go @@ -0,0 +1,178 @@ +package server + +import ( + "crypto/ed25519" + "sync" + "testing" + "time" + + gossh "golang.org/x/crypto/ssh" +) + +func TestAuthKeyStoreAddHasRemove(t *testing.T) { + store := NewAuthKeyStore(time.Hour, 5) + key := testPublicKey(t, 1) + + if store.Has("alice", key) { + t.Fatalf("Store should not contain key before add") + } + + store.Add("alice", key) + if !store.Has("alice", key) { + t.Fatalf("Store should contain key after add") + } + + store.Remove("alice", key) + if store.Has("alice", key) { + t.Fatalf("Store should not contain key after remove") + } +} + +func TestAuthKeyStoreHasExpiresKeysLazily(t *testing.T) { + now := time.Date(2026, 3, 3, 10, 0, 0, 0, time.UTC) + store := newAuthKeyStoreWithClock(10*time.Second, 5, func() time.Time { return now }) + key := testPublicKey(t, 2) + + store.Add("alice", key) + if !store.Has("alice", key) { + t.Fatalf("Store should contain fresh key") + } + + now = now.Add(11 * time.Second) + if store.Has("alice", key) { + t.Fatalf("Store should expire key when ttl is exceeded") + } + + store.mu.RLock() + defer store.mu.RUnlock() + if len(store.keysByUser["alice"]) != 0 { + t.Fatalf("Expired entries should be removed on Has call") + } +} + +func TestAuthKeyStoreEnforcesPerUserKeyLimit(t *testing.T) { + now := time.Date(2026, 3, 3, 10, 0, 0, 0, time.UTC) + store := newAuthKeyStoreWithClock(time.Hour, 2, func() time.Time { return now }) + + keyOne := testPublicKey(t, 3) + keyTwo := testPublicKey(t, 4) + keyThree := testPublicKey(t, 5) + + store.Add("alice", keyOne) + now = now.Add(1 * time.Second) + store.Add("alice", keyTwo) + now = now.Add(1 * time.Second) + store.Add("alice", keyThree) + + if store.Has("alice", keyOne) { + t.Fatalf("Oldest key should be evicted once max key limit is reached") + } + if !store.Has("alice", keyTwo) { + t.Fatalf("Second key should remain in store") + } + if !store.Has("alice", keyThree) { + t.Fatalf("Newest key should remain in store") + } +} + +func TestAuthKeyStoreAddRefreshesExistingKey(t *testing.T) { + now := time.Date(2026, 3, 3, 10, 0, 0, 0, time.UTC) + store := newAuthKeyStoreWithClock(10*time.Second, 5, func() time.Time { return now }) + key := testPublicKey(t, 6) + + store.Add("alice", key) + now = now.Add(9 * time.Second) + store.Add("alice", key) + + now = now.Add(5 * time.Second) + if !store.Has("alice", key) { + t.Fatalf("Key should stay valid after it is refreshed") + } + + now = now.Add(6 * time.Second) + if store.Has("alice", key) { + t.Fatalf("Refreshed key should expire once ttl is exceeded from latest add") + } +} + +func TestAuthKeyStoreUserIsolation(t *testing.T) { + store := NewAuthKeyStore(time.Hour, 5) + key := testPublicKey(t, 7) + + store.Add("alice", key) + if store.Has("bob", key) { + t.Fatalf("Key lookup must be isolated by user") + } +} + +func TestAuthKeyStoreIgnoresInvalidInput(t *testing.T) { + store := NewAuthKeyStore(time.Hour, 5) + key := testPublicKey(t, 8) + + store.Add("", key) + store.Add("alice", nil) + store.Remove("", key) + store.Remove("alice", nil) + + if store.Has("", key) { + t.Fatalf("Empty user should not match") + } + if store.Has("alice", nil) { + t.Fatalf("Nil key should not match") + } +} + +func TestAuthKeyStoreConcurrentAccess(t *testing.T) { + store := NewAuthKeyStore(time.Hour, 5) + users := []string{"alice", "bob", "carol"} + keys := []gossh.PublicKey{ + testPublicKey(t, 11), + testPublicKey(t, 12), + testPublicKey(t, 13), + testPublicKey(t, 14), + } + + var wg sync.WaitGroup + for worker := 0; worker < 32; worker++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + + user := users[workerID%len(users)] + for i := 0; i < 200; i++ { + key := keys[(workerID+i)%len(keys)] + store.Add(user, key) + _ = store.Has(user, key) + if i%3 == 0 { + store.Remove(user, key) + } + } + }(worker) + } + wg.Wait() + + store.mu.RLock() + defer store.mu.RUnlock() + for user, userEntries := range store.keysByUser { + if len(userEntries) > store.maxKeysPerUser { + t.Fatalf("User %s exceeded max key limit: %d", user, len(userEntries)) + } + } +} + +func testPublicKey(t *testing.T, seedByte byte) gossh.PublicKey { + t.Helper() + + seed := make([]byte, ed25519.SeedSize) + for i := range seed { + seed[i] = seedByte + } + + privateKey := ed25519.NewKeyFromSeed(seed) + publicKey, err := gossh.NewPublicKey(privateKey.Public()) + if err != nil { + t.Fatalf("Unable to build ssh public key: %s", err.Error()) + } + + return publicKey +} |
