summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-08 09:32:13 +0200
committerPaul Buetow <paul@buetow.org>2026-03-08 09:32:13 +0200
commit91b83a9ffcabf7264888cf84b95f08b8cc88c832 (patch)
tree009b7bded9db99dcb02e3a55314c4b624304bdba /internal
parent2007054d77b5bc40c943a9fd64874e850c750f2d (diff)
task: scope auth key dependencies to server instances (task 375)
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/connectors/serverless.go2
-rw-r--r--internal/server/handlers/authkeycommand_test.go8
-rw-r--r--internal/server/handlers/serverhandler.go32
-rw-r--r--internal/server/server.go21
-rw-r--r--internal/ssh/server/publickeycallback.go48
-rw-r--r--internal/ssh/server/publickeycallback_test.go8
6 files changed, 87 insertions, 32 deletions
diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go
index 74cd9e6..f4c4e9e 100644
--- a/internal/clients/connectors/serverless.go
+++ b/internal/clients/connectors/serverless.go
@@ -8,6 +8,7 @@ import (
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
serverHandlers "github.com/mimecast/dtail/internal/server/handlers"
+ sshserver "github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
)
@@ -76,6 +77,7 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro
make(chan struct{}, config.Server.MaxConcurrentCats),
make(chan struct{}, config.Server.MaxConcurrentTails),
config.Server,
+ sshserver.AuthKeys(),
)
}
diff --git a/internal/server/handlers/authkeycommand_test.go b/internal/server/handlers/authkeycommand_test.go
index f510038..a454e94 100644
--- a/internal/server/handlers/authkeycommand_test.go
+++ b/internal/server/handlers/authkeycommand_test.go
@@ -33,11 +33,10 @@ func TestHandleAuthKeyCommandSuccess(t *testing.T) {
if message := readServerMessage(t, handler.serverMessages); message != "AUTHKEY OK\n" {
t.Fatalf("Unexpected response: %q", message)
}
- if !sshserver.AuthKeys().Has(handler.user.Name, key) {
+ if !handler.authKeyStore.Has(handler.user.Name, key) {
t.Fatalf("Expected key to be stored for user")
}
-
- sshserver.AuthKeys().Remove(handler.user.Name, key)
+ handler.authKeyStore.Remove(handler.user.Name, key)
}
func TestHandleAuthKeyCommandFeatureDisabled(t *testing.T) {
@@ -51,7 +50,7 @@ func TestHandleAuthKeyCommandFeatureDisabled(t *testing.T) {
if message := readServerMessage(t, handler.serverMessages); message != "AUTHKEY ERR feature disabled\n" {
t.Fatalf("Unexpected response: %q", message)
}
- if sshserver.AuthKeys().Has(handler.user.Name, key) {
+ if handler.authKeyStore.Has(handler.user.Name, key) {
t.Fatalf("Expected no key to be stored while feature is disabled")
}
}
@@ -84,6 +83,7 @@ func newAuthKeyTestHandler(userName string, authKeyEnabled bool) *ServerHandler
serverCfg: &config.ServerConfig{
AuthKeyEnabled: authKeyEnabled,
},
+ authKeyStore: sshserver.NewAuthKeyStore(time.Hour, 5),
}
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 078fd27..732cc06 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -23,11 +23,12 @@ import (
// This handler implements the handler of the SSH server.
type ServerHandler struct {
baseHandler
- catLimiter chan struct{}
- tailLimiter chan struct{}
- serverCfg *config.ServerConfig
- regex string
- commands map[string]commandHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ serverCfg *config.ServerConfig
+ authKeyStore *sshserver.AuthKeyStore
+ regex string
+ commands map[string]commandHandler
// Track pending files waiting for limiter slots
pendingFiles int32
}
@@ -38,7 +39,8 @@ var _ Handler = (*ServerHandler)(nil)
// NewServerHandler returns the server handler.
func NewServerHandler(user *user.User, catLimiter,
- tailLimiter chan struct{}, serverCfg *config.ServerConfig) *ServerHandler {
+ tailLimiter chan struct{}, serverCfg *config.ServerConfig,
+ authKeyStore *sshserver.AuthKeyStore) *ServerHandler {
dlog.Server.Debug(user, "Creating new server handler")
if serverCfg == nil {
@@ -55,10 +57,14 @@ func NewServerHandler(user *user.User, catLimiter,
user: user,
codec: newProtocolCodec(user),
},
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- serverCfg: serverCfg,
- regex: ".",
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ serverCfg: serverCfg,
+ authKeyStore: authKeyStore,
+ regex: ".",
+ }
+ if h.authKeyStore == nil {
+ h.authKeyStore = sshserver.AuthKeys()
}
h.handleCommandCb = h.handleUserCommand
h.commands = h.newCommandRegistry()
@@ -180,6 +186,10 @@ func (h *ServerHandler) handleAuthKeyCommand(_ context.Context, _ lcontext.LCont
return
}
- sshserver.AuthKeys().Add(h.user.Name, pubKey)
+ if h.authKeyStore == nil {
+ h.sendln(h.serverMessages, "AUTHKEY ERR internal key store unavailable")
+ return
+ }
+ h.authKeyStore.Add(h.user.Name, pubKey)
h.sendln(h.serverMessages, "AUTHKEY OK")
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 943defa..72094ef 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -37,6 +37,8 @@ type Server struct {
cont *continuous
// Authentication strategies keyed by SSH username.
authStrategies map[string]authStrategy
+ // In-memory auth key cache for fast reconnect.
+ authKeyStore *server.AuthKeyStore
}
type authStrategy func(*user.User, string, string) bool
@@ -48,7 +50,6 @@ func New(cfg config.RuntimeConfig) *Server {
}
dlog.Server.Info("Starting server", version.String())
- server.ConfigureAuthKeyStore(cfg.Server.AuthKeyTTLSeconds, cfg.Server.AuthKeyMaxPerUser)
s := Server{
cfg: cfg,
@@ -64,11 +65,19 @@ func New(cfg config.RuntimeConfig) *Server {
tailLimiter: make(chan struct{}, cfg.Server.MaxConcurrentTails),
sched: newScheduler(cfg),
cont: newContinuous(cfg),
+ authKeyStore: server.NewAuthKeyStore(
+ time.Duration(cfg.Server.AuthKeyTTLSeconds)*time.Second,
+ cfg.Server.AuthKeyMaxPerUser,
+ ),
}
s.authStrategies = s.newAuthStrategies()
s.sshServerConfig.PasswordCallback = s.Callback
- s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback
+ s.sshServerConfig.PublicKeyCallback = server.NewPublicKeyCallback(
+ cfg.Server.AuthKeyEnabled,
+ cfg.Common.CacheDir,
+ s.authKeyStore,
+ )
private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
if err != nil {
@@ -222,7 +231,13 @@ func (s *Server) handleShellRequest(ctx context.Context, sshConn gossh.Conn,
case config.HealthUser:
handler = handlers.NewHealthHandler(user)
default:
- handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.cfg.Server)
+ handler = handlers.NewServerHandler(
+ user,
+ s.catLimiter,
+ s.tailLimiter,
+ s.cfg.Server,
+ s.authKeyStore,
+ )
}
terminate := func() {
diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go
index c4624f4..ccf9111 100644
--- a/internal/ssh/server/publickeycallback.go
+++ b/internal/ssh/server/publickeycallback.go
@@ -17,20 +17,44 @@ import (
func PublicKeyCallback(c gossh.ConnMetadata,
offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {
+ authKeyEnabled := config.Server != nil && config.Server.AuthKeyEnabled
+ cacheDir := ""
+ if config.Common != nil {
+ cacheDir = config.Common.CacheDir
+ }
+ return publicKeyCallback(c, offeredPubKey, authKeyEnabled, cacheDir, authKeyStore)
+}
+
+// NewPublicKeyCallback creates an instance-scoped SSH public key callback.
+// It avoids relying on package-level mutable configuration/state.
+func NewPublicKeyCallback(authKeyEnabled bool, cacheDir string,
+ keyStore *AuthKeyStore) func(gossh.ConnMetadata, gossh.PublicKey) (*gossh.Permissions, error) {
+
+ if keyStore == nil {
+ keyStore = authKeyStore
+ }
+ return func(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {
+ return publicKeyCallback(c, offeredPubKey, authKeyEnabled, cacheDir, keyStore)
+ }
+}
+
+func publicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey,
+ authKeyEnabled bool, cacheDir string, keyStore *AuthKeyStore) (*gossh.Permissions, error) {
+
user, err := user.New(c.User(), c.RemoteAddr().String())
if err != nil {
return nil, err
}
dlog.Server.Info(user, "Incoming authorization")
- if config.Server != nil && config.Server.AuthKeyEnabled {
- if permissions := authKeyStorePermissions(user.Name, offeredPubKey); permissions != nil {
+ if authKeyEnabled {
+ if permissions := authKeyStorePermissions(keyStore, user.Name, offeredPubKey); permissions != nil {
dlog.Server.Info(user, "Authorized by in-memory auth key store")
return permissions, nil
}
}
- authorizedKeysFile, err := authorizedKeysFile(user)
+ authorizedKeysFile, err := authorizedKeysFile(user, cacheDir)
if err != nil {
return nil, err
}
@@ -69,8 +93,10 @@ func verifyAuthorizedKeys(user *user.User, authorizedKeysBytes []byte,
return nil, fmt.Errorf("%s|public key of user not authorized", user)
}
-func authKeyStorePermissions(userName string, offeredPubKey gossh.PublicKey) *gossh.Permissions {
- if !authKeyStore.Has(userName, offeredPubKey) {
+func authKeyStorePermissions(keyStore *AuthKeyStore, userName string,
+ offeredPubKey gossh.PublicKey) *gossh.Permissions {
+
+ if keyStore == nil || !keyStore.Has(userName, offeredPubKey) {
return nil
}
@@ -83,7 +109,7 @@ func permissionsFromPublicKey(offeredPubKey gossh.PublicKey) *gossh.Permissions
}
}
-func authorizedKeysFile(user *user.User) (string, error) {
+func authorizedKeysFile(user *user.User, cacheDir string) (string, error) {
if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
// In this case, we expect a pub key in the current directory.
return "./id_rsa.pub", nil
@@ -95,10 +121,12 @@ func authorizedKeysFile(user *user.User) (string, error) {
}
// Check for cached version in the dserver directory.
- authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd,
- config.Common.CacheDir, user.Name)
- if _, err = os.Stat(authorizedKeysFile); err == nil {
- return authorizedKeysFile, nil
+ var authorizedKeysFile string
+ if cacheDir != "" {
+ authorizedKeysFile = fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, cacheDir, user.Name)
+ if _, err = os.Stat(authorizedKeysFile); err == nil {
+ return authorizedKeysFile, nil
+ }
}
// As the last option, check the regular SSH path.
diff --git a/internal/ssh/server/publickeycallback_test.go b/internal/ssh/server/publickeycallback_test.go
index 7ded4f3..97baa72 100644
--- a/internal/ssh/server/publickeycallback_test.go
+++ b/internal/ssh/server/publickeycallback_test.go
@@ -16,13 +16,13 @@ func TestAuthKeyStorePermissions(t *testing.T) {
key := testPublicKey(t, 21)
- if permissions := authKeyStorePermissions("alice", key); permissions != nil {
+ if permissions := authKeyStorePermissions(authKeyStore, "alice", key); permissions != nil {
t.Fatalf("Expected nil permissions when no key is cached")
}
authKeyStore.Add("alice", key)
- permissions := authKeyStorePermissions("alice", key)
+ permissions := authKeyStorePermissions(authKeyStore, "alice", key)
if permissions == nil {
t.Fatalf("Expected permissions when key is cached")
}
@@ -30,12 +30,12 @@ func TestAuthKeyStorePermissions(t *testing.T) {
t.Fatalf("Unexpected fingerprint: %s", fingerprint)
}
- if permissions := authKeyStorePermissions("bob", key); permissions != nil {
+ if permissions := authKeyStorePermissions(authKeyStore, "bob", key); permissions != nil {
t.Fatalf("Expected nil permissions for different user")
}
unknownKey := testPublicKey(t, 22)
- if permissions := authKeyStorePermissions("alice", unknownKey); permissions != nil {
+ if permissions := authKeyStorePermissions(authKeyStore, "alice", unknownKey); permissions != nil {
t.Fatalf("Expected nil permissions for unknown key")
}
}