summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-03 11:11:46 +0200
committerPaul Buetow <paul@buetow.org>2026-03-03 11:11:46 +0200
commit36286212ca5a6e7de85fd05338ca70194707841f (patch)
treeb08931749c2fa424e59029f5864802795201944e
parent2de007f9ef8ae2724b9fbe2808ee25cbfe4ca876 (diff)
Add auth-key fast reconnect integration coverage
-rw-r--r--integrationtests/authkey_test.go363
-rw-r--r--internal/clients/connectors/serverconnection.go13
-rw-r--r--internal/server/handlers/serverhandler.go11
-rw-r--r--internal/ssh/client/authmethods.go77
-rw-r--r--internal/ssh/client/authmethods_test.go100
-rw-r--r--internal/ssh/ssh.go50
6 files changed, 532 insertions, 82 deletions
diff --git a/integrationtests/authkey_test.go b/integrationtests/authkey_test.go
new file mode 100644
index 0000000..40e9ad7
--- /dev/null
+++ b/integrationtests/authkey_test.go
@@ -0,0 +1,363 @@
+package integrationtests
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+const (
+ authKeyFastPathLog = "Authorized by in-memory auth key store"
+ dcatExpectedFirstOutput = "1 Sat 2 Oct 13:46:45 EEST 2021"
+)
+
+func TestAuthKeyFastReconnectIntegration(t *testing.T) {
+ skipIfNotIntegrationTest(t)
+ cleanupTmpFiles(t)
+
+ t.Run("RegistrationFastPathAndFallback", testAuthKeyRegistrationFastPathAndFallback)
+ t.Run("TTLExpiry", testAuthKeyTTLExpiry)
+ t.Run("MaxKeysPerUser", testAuthKeyMaxKeysPerUser)
+ t.Run("NoAuthKeyFlag", testNoAuthKeyFlagDisablesFeature)
+}
+
+func testAuthKeyRegistrationFastPathAndFallback(t *testing.T) {
+ authKeyPath := createAuthKeyPair(t, "authkey-registration")
+ server := startAuthKeyServer(t, "")
+ defer server.Stop()
+
+ exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_registration_1.tmp", server.Address(), authKeyPath, false)
+ if err != nil || exitCode != 0 {
+ t.Fatalf("Expected first connection to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_registration_1.tmp")
+ waitForServerLogs()
+ if got := server.CountLogLinesContaining(authKeyFastPathLog); got != 0 {
+ t.Fatalf("Expected first connection to use fallback, fast-path count=%d", got)
+ }
+
+ exitCode, err = runDCatWithAuthKey(server.Context(), t, "authkey_registration_2.tmp", server.Address(), authKeyPath, false)
+ if err != nil || exitCode != 0 {
+ t.Fatalf("Expected second connection to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_registration_2.tmp")
+ waitForServerLogs()
+ if got := server.CountLogLinesContaining(authKeyFastPathLog); got < 1 {
+ t.Fatalf("Expected fast-path authorization after registration, fast-path count=%d", got)
+ }
+
+ server.Stop()
+ time.Sleep(300 * time.Millisecond)
+
+ restartedServer := startAuthKeyServer(t, "")
+ defer restartedServer.Stop()
+
+ exitCode, err = runDCatWithAuthKey(restartedServer.Context(), t, "authkey_registration_3.tmp", restartedServer.Address(), authKeyPath, false)
+ if err != nil || exitCode != 0 {
+ t.Fatalf("Expected fallback after restart to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_registration_3.tmp")
+ waitForServerLogs()
+ if got := restartedServer.CountLogLinesContaining(authKeyFastPathLog); got != 0 {
+ t.Fatalf("Expected no fast-path hit on first post-restart connection, fast-path count=%d", got)
+ }
+}
+
+func testAuthKeyTTLExpiry(t *testing.T) {
+ authKeyPath := createAuthKeyPair(t, "authkey-ttl")
+ ttlSeconds := 8
+ cfgFile := writeAuthKeyServerConfig(t, ttlSeconds, 5)
+ server := startAuthKeyServer(t, cfgFile)
+ defer server.Stop()
+
+ exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_ttl_1.tmp", server.Address(), authKeyPath, false)
+ if err != nil || exitCode != 0 {
+ t.Fatalf("Expected first connection to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_ttl_1.tmp")
+
+ exitCode, err = runDCatWithAuthKey(server.Context(), t, "authkey_ttl_2.tmp", server.Address(), authKeyPath, false)
+ if err != nil || exitCode != 0 {
+ t.Fatalf("Expected second connection to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_ttl_2.tmp")
+ waitForServerLogs()
+ fastPathCountAfterSecond := server.CountLogLinesContaining(authKeyFastPathLog)
+ if fastPathCountAfterSecond < 1 {
+ t.Fatalf("Expected fast-path hit before TTL expiry, count=%d", fastPathCountAfterSecond)
+ }
+
+ time.Sleep(time.Duration(ttlSeconds+1) * time.Second)
+ exitCode, err = runDCatWithAuthKey(server.Context(), t, "authkey_ttl_3.tmp", server.Address(), authKeyPath, false)
+ if err != nil || exitCode != 0 {
+ t.Fatalf("Expected fallback after TTL expiry to still connect, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_ttl_3.tmp")
+ waitForServerLogs()
+ fastPathCountAfterThird := server.CountLogLinesContaining(authKeyFastPathLog)
+ if fastPathCountAfterThird != fastPathCountAfterSecond {
+ t.Fatalf("Expected TTL-expired key to stop fast-path hits: before=%d after=%d",
+ fastPathCountAfterSecond, fastPathCountAfterThird)
+ }
+}
+
+func testAuthKeyMaxKeysPerUser(t *testing.T) {
+ authKeyOne := createAuthKeyPair(t, "authkey-max-one")
+ authKeyTwo := createAuthKeyPair(t, "authkey-max-two")
+ cfgFile := writeAuthKeyServerConfig(t, 3600, 1)
+ server := startAuthKeyServer(t, cfgFile)
+ defer server.Stop()
+
+ if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_1.tmp", server.Address(), authKeyOne, false); err != nil || exitCode != 0 {
+ t.Fatalf("Expected first key registration to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_max_1.tmp")
+ if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_2.tmp", server.Address(), authKeyTwo, false); err != nil || exitCode != 0 {
+ t.Fatalf("Expected second key registration to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_max_2.tmp")
+ waitForServerLogs()
+ initialFastPathCount := server.CountLogLinesContaining(authKeyFastPathLog)
+
+ if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_3.tmp", server.Address(), authKeyOne, false); err != nil || exitCode != 0 {
+ t.Fatalf("Expected first key connection (after max eviction) to succeed via fallback, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_max_3.tmp")
+ waitForServerLogs()
+ afterOldKeyCount := server.CountLogLinesContaining(authKeyFastPathLog)
+ if afterOldKeyCount != initialFastPathCount {
+ t.Fatalf("Expected evicted old key to avoid fast-path hit: before=%d after=%d",
+ initialFastPathCount, afterOldKeyCount)
+ }
+
+ if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_4.tmp", server.Address(), authKeyOne, false); err != nil || exitCode != 0 {
+ t.Fatalf("Expected re-registered first key to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_max_4.tmp")
+ waitForServerLogs()
+ afterNewKeyCount := server.CountLogLinesContaining(authKeyFastPathLog)
+ if afterNewKeyCount <= afterOldKeyCount {
+ t.Fatalf("Expected re-registered key to use fast-path: old-count=%d new-count=%d", afterOldKeyCount, afterNewKeyCount)
+ }
+}
+
+func testNoAuthKeyFlagDisablesFeature(t *testing.T) {
+ authKeyPath := createAuthKeyPair(t, "authkey-noauth")
+ server := startAuthKeyServer(t, "")
+ defer server.Stop()
+
+ if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_noauth_1.tmp", server.Address(), authKeyPath, true); err != nil || exitCode != 0 {
+ t.Fatalf("Expected first --no-auth-key connection to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_noauth_1.tmp")
+ if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_noauth_2.tmp", server.Address(), authKeyPath, true); err != nil || exitCode != 0 {
+ t.Fatalf("Expected second --no-auth-key connection to succeed, exit=%d err=%v", exitCode, err)
+ }
+ assertDCatSuccessfulOutput(t, "authkey_noauth_2.tmp")
+
+ waitForServerLogs()
+ if got := server.CountLogLinesContaining(authKeyFastPathLog); got != 0 {
+ t.Fatalf("Expected --no-auth-key to prevent fast-path registration, fast-path count=%d", got)
+ }
+}
+
+type authKeyServer struct {
+ ctx context.Context
+ cancel context.CancelFunc
+ addr string
+ logs *authKeyServerLogs
+}
+
+func (s *authKeyServer) Stop() {
+ s.cancel()
+}
+
+func (s *authKeyServer) Context() context.Context {
+ return s.ctx
+}
+
+func (s *authKeyServer) Address() string {
+ return s.addr
+}
+
+func (s *authKeyServer) CountLogLinesContaining(substring string) int {
+ return s.logs.countContaining(substring)
+}
+
+type authKeyServerLogs struct {
+ mu sync.Mutex
+ lines []string
+}
+
+func newAuthKeyServerLogs() *authKeyServerLogs {
+ return &authKeyServerLogs{
+ lines: make([]string, 0, 128),
+ }
+}
+
+func (l *authKeyServerLogs) append(line string) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.lines = append(l.lines, line)
+}
+
+func (l *authKeyServerLogs) countContaining(substring string) int {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ count := 0
+ for _, line := range l.lines {
+ if strings.Contains(line, substring) {
+ count++
+ }
+ }
+ return count
+}
+
+func startAuthKeyServer(t *testing.T, cfgFile string) *authKeyServer {
+ t.Helper()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ port := getUniquePortNumber()
+ args := []string{
+ "--cfg", "none",
+ "--logger", "stdout",
+ "--logLevel", "info",
+ "--bindAddress", "localhost",
+ "--port", fmt.Sprintf("%d", port),
+ }
+ if cfgFile != "" {
+ args = append(args, "--cfg", cfgFile)
+ }
+
+ stdoutCh, stderrCh, cmdErrCh, err := startCommand(ctx, t, "", "../dserver", args...)
+ if err != nil {
+ cancel()
+ t.Fatalf("Unable to start dserver: %v", err)
+ }
+
+ logs := newAuthKeyServerLogs()
+ go func() {
+ for {
+ select {
+ case line, ok := <-stdoutCh:
+ if ok {
+ logs.append(line)
+ }
+ case line, ok := <-stderrCh:
+ if ok {
+ logs.append(line)
+ }
+ case err := <-cmdErrCh:
+ if err != nil {
+ logs.append(err.Error())
+ }
+ return
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+
+ time.Sleep(500 * time.Millisecond)
+ return &authKeyServer{
+ ctx: ctx,
+ cancel: cancel,
+ addr: fmt.Sprintf("localhost:%d", port),
+ logs: logs,
+ }
+}
+
+func runDCatWithAuthKey(ctx context.Context, t *testing.T, outFile,
+ serverAddress, authKeyPath string, noAuthKey bool) (int, error) {
+ t.Helper()
+
+ args := []string{
+ "--plain",
+ "--cfg", "none",
+ "--servers", serverAddress,
+ "--files", "dcat1a.txt",
+ "--trustAllHosts",
+ "--noColor",
+ "--auth-key-path", authKeyPath,
+ }
+ if noAuthKey {
+ args = append(args, "--no-auth-key")
+ }
+
+ return runCommand(ctx, t, outFile, "../dcat", args...)
+}
+
+func assertDCatSuccessfulOutput(t *testing.T, outFile string) {
+ t.Helper()
+
+ outBytes, err := os.ReadFile(outFile)
+ if err != nil {
+ t.Fatalf("Unable to read dcat output file %s: %v", outFile, err)
+ }
+
+ output := string(outBytes)
+ if strings.Contains(output, "SSH handshake failed") {
+ t.Fatalf("Expected successful SSH connection, got handshake failure in %s:\n%s", outFile, output)
+ }
+ if !strings.Contains(output, dcatExpectedFirstOutput) {
+ t.Fatalf("Expected dcat output to contain %q in %s, got:\n%s", dcatExpectedFirstOutput, outFile, output)
+ }
+}
+
+func writeAuthKeyServerConfig(t *testing.T, ttlSeconds, maxPerUser int) string {
+ t.Helper()
+
+ cfgPath := filepath.Join(t.TempDir(), "authkey_server_config.json")
+ cfgContent := fmt.Sprintf(
+ `{"Server":{"AuthKeyEnabled":true,"AuthKeyTTLSeconds":%d,"AuthKeyMaxPerUser":%d}}`,
+ ttlSeconds, maxPerUser,
+ )
+ if err := os.WriteFile(cfgPath, []byte(cfgContent), 0600); err != nil {
+ t.Fatalf("Unable to write auth-key server config: %v", err)
+ }
+ return cfgPath
+}
+
+func createAuthKeyPair(t *testing.T, keyName string) string {
+ t.Helper()
+
+ keyPath := filepath.Join(t.TempDir(), keyName)
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatalf("Unable to generate private key: %v", err)
+ }
+
+ privateKeyBytes := pem.EncodeToMemory(&pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
+ })
+ if err := os.WriteFile(keyPath, privateKeyBytes, 0600); err != nil {
+ t.Fatalf("Unable to write private key: %v", err)
+ }
+
+ publicKey, err := gossh.NewPublicKey(&privateKey.PublicKey)
+ if err != nil {
+ t.Fatalf("Unable to generate public key: %v", err)
+ }
+ if err := os.WriteFile(keyPath+".pub", gossh.MarshalAuthorizedKey(publicKey), 0600); err != nil {
+ t.Fatalf("Unable to write public key: %v", err)
+ }
+
+ return keyPath
+}
+
+func waitForServerLogs() {
+ time.Sleep(300 * time.Millisecond)
+}
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index fbeb1bc..649fe30 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -224,18 +224,19 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
}
}()
- // Send all commands to client.
+ if c.authKeyDisabled {
+ dlog.Client.Debug(c.server, "Skipping AUTHKEY registration because auth-key is disabled")
+ } else {
+ c.sendAuthKeyRegistrationCommand()
+ }
+
+ // Send all requested commands to the server.
for _, command := range c.commands {
dlog.Client.Debug(command)
if err := c.handler.SendMessage(command); err != nil {
dlog.Client.Debug(err)
}
}
- if c.authKeyDisabled {
- dlog.Client.Debug(c.server, "Skipping AUTHKEY registration because auth-key is disabled")
- } else {
- c.sendAuthKeyRegistrationCommand()
- }
if !c.throttlingDone {
dlog.Client.Debug(c.server, "Unthrottling connection (2)",
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 53ab4e3..5d5a78c 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -79,14 +79,17 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon
argc int, args []string, commandName string) {
dlog.Server.Debug(h.user, "Handling user command", argc, args)
+ shutdownOnCompletion := shouldShutdownOnCommandCompletion(commandName)
h.incrementActiveCommands()
commandFinished := func() {
activeCommands := h.decrementActiveCommands()
pendingFiles := atomic.LoadInt32(&h.pendingFiles)
dlog.Server.Debug(h.user, "Command finished", "activeCommands", activeCommands, "pendingFiles", pendingFiles)
- // Only shutdown if no active commands AND no pending files
- if activeCommands == 0 && pendingFiles == 0 {
+ // Only shutdown if no active commands AND no pending files.
+ // AUTHKEY is a session-side effect command and should not terminate the shell
+ // because user commands may still follow in the same session.
+ if shutdownOnCompletion && activeCommands == 0 && pendingFiles == 0 {
h.shutdown()
}
}
@@ -102,6 +105,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon
handler(ctx, ltx, argc, args, commandFinished)
}
+func shouldShutdownOnCommandCompletion(commandName string) bool {
+ return !strings.EqualFold(commandName, "AUTHKEY")
+}
+
func (h *ServerHandler) newCommandRegistry() map[string]commandHandler {
return map[string]commandHandler{
"grep": h.makeReadCommandHandler(omode.GrepClient, 1),
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index a414ade..7ac4d0c 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -12,8 +12,8 @@ import (
)
var (
- privateKeyAuthMethod = ssh.PrivateKey
- agentAuthMethod = ssh.AgentWithKeyIndex
+ privateKeySigner = ssh.PrivateKeySigner
+ agentSigners = ssh.AgentSignersWithKeyIndex
)
// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
@@ -31,21 +31,6 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex)
}
-func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod {
- var sshAuthMethods []gossh.AuthMethod
- privateKeyPath := "./id_rsa"
-
- GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096)
- authMethod, err := ssh.PrivateKey(privateKeyPath)
- if err != nil {
- dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
- }
-
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath)
- return sshAuthMethods
-}
-
func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) {
@@ -62,7 +47,10 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsFile)
if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
- return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback
+ if privateKeyPath == "" {
+ privateKeyPath = "./id_rsa"
+ }
+ GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096)
}
sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex)
@@ -74,7 +62,15 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
}
func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod {
- var sshAuthMethods []gossh.AuthMethod
+ signers := collectKnownHostsSigners(privateKeyPath, agentKeyIndex)
+ if len(signers) == 0 {
+ return nil
+ }
+ return []gossh.AuthMethod{gossh.PublicKeys(signers...)}
+}
+
+func collectKnownHostsSigners(privateKeyPath string, agentKeyIndex int) []gossh.Signer {
+ var signers []gossh.Signer
home := os.Getenv("HOME")
defaultPrivateKeyPaths := []string{
@@ -83,13 +79,32 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go
home + "/.ssh/id_ecdsa",
home + "/.ssh/id_ed25519",
}
+ if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
+ defaultPrivateKeyPaths = append([]string{"./id_rsa"}, defaultPrivateKeyPaths...)
+ }
if privateKeyPath == "" {
privateKeyPath = defaultPrivateKeyPaths[0]
}
addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1)
- addPrivateKeyAuthMethod := func(path string) {
+ addedPublicKeys := make(map[string]bool, len(defaultPrivateKeyPaths)+1)
+ addSigner := func(source string, signer gossh.Signer) {
+ if signer == nil {
+ return
+ }
+
+ pubKey := string(signer.PublicKey().Marshal())
+ if addedPublicKeys[pubKey] {
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Skipping duplicate signer", source)
+ return
+ }
+
+ addedPublicKeys[pubKey] = true
+ signers = append(signers, signer)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added signer", source)
+ }
+ addPrivateKeySigner := func(path string) {
if path == "" {
return
}
@@ -97,33 +112,33 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go
return
}
- authMethod, err := privateKeyAuthMethod(path)
+ signer, err := privateKeySigner(path)
if err != nil {
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load private key signer", path, err)
return
}
- sshAuthMethods = append(sshAuthMethods, authMethod)
addedPrivateKeyPaths[path] = true
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path)
+ addSigner(path, signer)
}
// First, the explicit auth key path (or default ~/.ssh/id_rsa).
- addPrivateKeyAuthMethod(privateKeyPath)
+ addPrivateKeySigner(privateKeyPath)
// Second, SSH agent (YubiKey-backed keys are typically exposed here).
- authMethod, err := agentAuthMethod(agentKeyIndex)
+ loadedAgentSigners, err := agentSigners(agentKeyIndex)
if err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method")
+ for i, signer := range loadedAgentSigners {
+ addSigner(fmt.Sprintf("agent:%d:%d", agentKeyIndex, i), signer)
+ }
} else {
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load SSH agent signers", err)
}
// Third, additional default private key paths.
for _, path := range defaultPrivateKeyPaths {
- addPrivateKeyAuthMethod(path)
+ addPrivateKeySigner(path)
}
- return sshAuthMethods
+ return signers
}
diff --git a/internal/ssh/client/authmethods_test.go b/internal/ssh/client/authmethods_test.go
index 04751f5..e1e92b0 100644
--- a/internal/ssh/client/authmethods_test.go
+++ b/internal/ssh/client/authmethods_test.go
@@ -2,6 +2,7 @@ package client
import (
"fmt"
+ "io"
"reflect"
"testing"
@@ -10,42 +11,84 @@ import (
gossh "golang.org/x/crypto/ssh"
)
+type mockPublicKey struct {
+ id string
+}
+
+func (k *mockPublicKey) Type() string {
+ return "ssh-rsa"
+}
+
+func (k *mockPublicKey) Marshal() []byte {
+ return []byte(k.id)
+}
+
+func (k *mockPublicKey) Verify(_ []byte, _ *gossh.Signature) error {
+ return nil
+}
+
+type mockSigner struct {
+ key gossh.PublicKey
+}
+
+func newMockSigner(id string) gossh.Signer {
+ return &mockSigner{key: &mockPublicKey{id: id}}
+}
+
+func (s *mockSigner) PublicKey() gossh.PublicKey {
+ return s.key
+}
+
+func (s *mockSigner) Sign(_ io.Reader, _ []byte) (*gossh.Signature, error) {
+ return &gossh.Signature{
+ Format: "ssh-rsa",
+ Blob: []byte("sig"),
+ }, nil
+}
+
func TestCollectKnownHostsAuthMethodsOrder(t *testing.T) {
homeDir := "/tmp/dtail-auth-order"
t.Setenv("HOME", homeDir)
- originalPrivateKeyAuthMethod := privateKeyAuthMethod
- originalAgentAuthMethod := agentAuthMethod
+ originalPrivateKeySigner := privateKeySigner
+ originalAgentSigners := agentSigners
originalLogger := dlog.Client
dlog.Client = &dlog.DLog{}
t.Cleanup(func() {
- privateKeyAuthMethod = originalPrivateKeyAuthMethod
- agentAuthMethod = originalAgentAuthMethod
+ privateKeySigner = originalPrivateKeySigner
+ agentSigners = originalAgentSigners
dlog.Client = originalLogger
})
var callOrder []string
- successfulPrivateKeys := map[string]bool{
- "/custom/id_fast": true,
- homeDir + "/.ssh/id_rsa": true,
- homeDir + "/.ssh/id_dsa": true,
+ successfulPrivateKeys := map[string]gossh.Signer{
+ "/custom/id_fast": newMockSigner("custom"),
+ homeDir + "/.ssh/id_rsa": newMockSigner("default-rsa"),
+ homeDir + "/.ssh/id_dsa": newMockSigner("default-dsa"),
}
- privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ privateKeySigner = func(path string) (gossh.Signer, error) {
callOrder = append(callOrder, "private:"+path)
- if !successfulPrivateKeys[path] {
+ signer, found := successfulPrivateKeys[path]
+ if !found {
return nil, fmt.Errorf("missing private key: %s", path)
}
- return gossh.Password(path), nil
+ return signer, nil
}
- agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ agentSigners = func(keyIndex int) ([]gossh.Signer, error) {
callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
- return gossh.Password("agent"), nil
+ return []gossh.Signer{newMockSigner("agent")}, nil
}
methods := collectKnownHostsAuthMethods("/custom/id_fast", 7)
- if len(methods) != 4 {
- t.Fatalf("Expected 4 auth methods, got %d", len(methods))
+ if len(methods) != 1 {
+ t.Fatalf("Expected 1 auth method, got %d", len(methods))
+ }
+
+ callOrder = nil
+ signers := collectKnownHostsSigners("/custom/id_fast", 7)
+ if len(signers) != 4 {
+ t.Fatalf("Expected 4 signers, got %d", len(signers))
}
expectedOrder := []string{
@@ -65,32 +108,39 @@ func TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath(t *testing.T) {
homeDir := "/tmp/dtail-auth-dedupe"
t.Setenv("HOME", homeDir)
- originalPrivateKeyAuthMethod := privateKeyAuthMethod
- originalAgentAuthMethod := agentAuthMethod
+ originalPrivateKeySigner := privateKeySigner
+ originalAgentSigners := agentSigners
originalLogger := dlog.Client
dlog.Client = &dlog.DLog{}
t.Cleanup(func() {
- privateKeyAuthMethod = originalPrivateKeyAuthMethod
- agentAuthMethod = originalAgentAuthMethod
+ privateKeySigner = originalPrivateKeySigner
+ agentSigners = originalAgentSigners
dlog.Client = originalLogger
})
+ sharedSigner := newMockSigner("shared")
var callOrder []string
- privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ privateKeySigner = func(path string) (gossh.Signer, error) {
callOrder = append(callOrder, "private:"+path)
if path == homeDir+"/.ssh/id_rsa" {
- return gossh.Password(path), nil
+ return sharedSigner, nil
}
return nil, fmt.Errorf("missing private key: %s", path)
}
- agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ agentSigners = func(keyIndex int) ([]gossh.Signer, error) {
callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
- return gossh.Password("agent"), nil
+ return []gossh.Signer{sharedSigner}, nil
}
methods := collectKnownHostsAuthMethods(homeDir+"/.ssh/id_rsa", 2)
- if len(methods) != 2 {
- t.Fatalf("Expected 2 auth methods, got %d", len(methods))
+ if len(methods) != 1 {
+ t.Fatalf("Expected 1 auth method, got %d", len(methods))
+ }
+
+ callOrder = nil
+ signers := collectKnownHostsSigners(homeDir+"/.ssh/id_rsa", 2)
+ if len(signers) != 1 {
+ t.Fatalf("Expected duplicate keys to collapse to 1 signer, got %d", len(signers))
}
expectedOrder := []string{
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 7088e89..a191fd5 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -48,9 +48,9 @@ func Agent() (gossh.AuthMethod, error) {
return AgentWithKeyIndex(-1)
}
-// AgentWithKeyIndex used for SSH auth with a specific key index from the agent.
+// AgentSignersWithKeyIndex returns SSH agent signers.
// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used.
-func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
+func AgentSignersWithKeyIndex(keyIndex int) ([]gossh.Signer, error) {
// Use context-aware dialing for SSH agent connection (local Unix socket).
// 2-second timeout is reasonable for local socket connections.
dialer := &net.Dialer{
@@ -73,28 +73,33 @@ func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
dlog.Common.Debug("Public key", i, key)
}
+ signers, err := agentClient.Signers()
+ if err != nil {
+ return nil, fmt.Errorf("failed to load SSH agent signers: %w", err)
+ }
+
// If no specific key index requested, use all keys (backwards compatible default)
if keyIndex < 0 {
- return gossh.PublicKeysCallback(agentClient.Signers), nil
+ return signers, nil
}
// Use only the specified key index (0-based)
- if keyIndex >= len(keys) {
- return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys))
+ if keyIndex >= len(signers) {
+ return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers))
}
dlog.Common.Debug("Using SSH agent key at index", keyIndex)
- return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) {
- signers, err := agentClient.Signers()
- if err != nil {
- return nil, err
- }
- if keyIndex >= len(signers) {
- return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers))
- }
- // Return only the specified signer
- return []gossh.Signer{signers[keyIndex]}, nil
- }), nil
+ return []gossh.Signer{signers[keyIndex]}, nil
+}
+
+// AgentWithKeyIndex used for SSH auth with a specific key index from the agent.
+// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used.
+func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
+ signers, err := AgentSignersWithKeyIndex(keyIndex)
+ if err != nil {
+ return nil, err
+ }
+ return gossh.PublicKeys(signers...), nil
}
// EnterKeyPhrase is required to read phrase protected private keys.
@@ -108,8 +113,8 @@ func EnterKeyPhrase(keyFile string) []byte {
return phrase
}
-// KeyFile returns the key as a SSH auth method.
-func KeyFile(keyFile string) (gossh.AuthMethod, error) {
+// PrivateKeySigner returns an SSH signer from the provided private key file.
+func PrivateKeySigner(keyFile string) (gossh.Signer, error) {
buffer, err := os.ReadFile(keyFile)
if err != nil {
return nil, err
@@ -118,6 +123,15 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) {
if err != nil {
return nil, err
}
+ return key, nil
+}
+
+// KeyFile returns the key as a SSH auth method.
+func KeyFile(keyFile string) (gossh.AuthMethod, error) {
+ key, err := PrivateKeySigner(keyFile)
+ if err != nil {
+ return nil, err
+ }
// Key phrase support disabled as password will be printed to stdout!
/*