From 36286212ca5a6e7de85fd05338ca70194707841f Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Tue, 3 Mar 2026 11:11:46 +0200 Subject: Add auth-key fast reconnect integration coverage --- integrationtests/authkey_test.go | 363 ++++++++++++++++++++++++ internal/clients/connectors/serverconnection.go | 13 +- internal/server/handlers/serverhandler.go | 11 +- internal/ssh/client/authmethods.go | 77 +++-- internal/ssh/client/authmethods_test.go | 100 +++++-- internal/ssh/ssh.go | 50 ++-- 6 files changed, 532 insertions(+), 82 deletions(-) create mode 100644 integrationtests/authkey_test.go diff --git a/integrationtests/authkey_test.go b/integrationtests/authkey_test.go new file mode 100644 index 0000000..40e9ad7 --- /dev/null +++ b/integrationtests/authkey_test.go @@ -0,0 +1,363 @@ +package integrationtests + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + gossh "golang.org/x/crypto/ssh" +) + +const ( + authKeyFastPathLog = "Authorized by in-memory auth key store" + dcatExpectedFirstOutput = "1 Sat 2 Oct 13:46:45 EEST 2021" +) + +func TestAuthKeyFastReconnectIntegration(t *testing.T) { + skipIfNotIntegrationTest(t) + cleanupTmpFiles(t) + + t.Run("RegistrationFastPathAndFallback", testAuthKeyRegistrationFastPathAndFallback) + t.Run("TTLExpiry", testAuthKeyTTLExpiry) + t.Run("MaxKeysPerUser", testAuthKeyMaxKeysPerUser) + t.Run("NoAuthKeyFlag", testNoAuthKeyFlagDisablesFeature) +} + +func testAuthKeyRegistrationFastPathAndFallback(t *testing.T) { + authKeyPath := createAuthKeyPair(t, "authkey-registration") + server := startAuthKeyServer(t, "") + defer server.Stop() + + exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_registration_1.tmp", server.Address(), authKeyPath, false) + if err != nil || exitCode != 0 { + t.Fatalf("Expected first connection to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_registration_1.tmp") + waitForServerLogs() + if got := server.CountLogLinesContaining(authKeyFastPathLog); got != 0 { + t.Fatalf("Expected first connection to use fallback, fast-path count=%d", got) + } + + exitCode, err = runDCatWithAuthKey(server.Context(), t, "authkey_registration_2.tmp", server.Address(), authKeyPath, false) + if err != nil || exitCode != 0 { + t.Fatalf("Expected second connection to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_registration_2.tmp") + waitForServerLogs() + if got := server.CountLogLinesContaining(authKeyFastPathLog); got < 1 { + t.Fatalf("Expected fast-path authorization after registration, fast-path count=%d", got) + } + + server.Stop() + time.Sleep(300 * time.Millisecond) + + restartedServer := startAuthKeyServer(t, "") + defer restartedServer.Stop() + + exitCode, err = runDCatWithAuthKey(restartedServer.Context(), t, "authkey_registration_3.tmp", restartedServer.Address(), authKeyPath, false) + if err != nil || exitCode != 0 { + t.Fatalf("Expected fallback after restart to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_registration_3.tmp") + waitForServerLogs() + if got := restartedServer.CountLogLinesContaining(authKeyFastPathLog); got != 0 { + t.Fatalf("Expected no fast-path hit on first post-restart connection, fast-path count=%d", got) + } +} + +func testAuthKeyTTLExpiry(t *testing.T) { + authKeyPath := createAuthKeyPair(t, "authkey-ttl") + ttlSeconds := 8 + cfgFile := writeAuthKeyServerConfig(t, ttlSeconds, 5) + server := startAuthKeyServer(t, cfgFile) + defer server.Stop() + + exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_ttl_1.tmp", server.Address(), authKeyPath, false) + if err != nil || exitCode != 0 { + t.Fatalf("Expected first connection to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_ttl_1.tmp") + + exitCode, err = runDCatWithAuthKey(server.Context(), t, "authkey_ttl_2.tmp", server.Address(), authKeyPath, false) + if err != nil || exitCode != 0 { + t.Fatalf("Expected second connection to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_ttl_2.tmp") + waitForServerLogs() + fastPathCountAfterSecond := server.CountLogLinesContaining(authKeyFastPathLog) + if fastPathCountAfterSecond < 1 { + t.Fatalf("Expected fast-path hit before TTL expiry, count=%d", fastPathCountAfterSecond) + } + + time.Sleep(time.Duration(ttlSeconds+1) * time.Second) + exitCode, err = runDCatWithAuthKey(server.Context(), t, "authkey_ttl_3.tmp", server.Address(), authKeyPath, false) + if err != nil || exitCode != 0 { + t.Fatalf("Expected fallback after TTL expiry to still connect, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_ttl_3.tmp") + waitForServerLogs() + fastPathCountAfterThird := server.CountLogLinesContaining(authKeyFastPathLog) + if fastPathCountAfterThird != fastPathCountAfterSecond { + t.Fatalf("Expected TTL-expired key to stop fast-path hits: before=%d after=%d", + fastPathCountAfterSecond, fastPathCountAfterThird) + } +} + +func testAuthKeyMaxKeysPerUser(t *testing.T) { + authKeyOne := createAuthKeyPair(t, "authkey-max-one") + authKeyTwo := createAuthKeyPair(t, "authkey-max-two") + cfgFile := writeAuthKeyServerConfig(t, 3600, 1) + server := startAuthKeyServer(t, cfgFile) + defer server.Stop() + + if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_1.tmp", server.Address(), authKeyOne, false); err != nil || exitCode != 0 { + t.Fatalf("Expected first key registration to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_max_1.tmp") + if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_2.tmp", server.Address(), authKeyTwo, false); err != nil || exitCode != 0 { + t.Fatalf("Expected second key registration to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_max_2.tmp") + waitForServerLogs() + initialFastPathCount := server.CountLogLinesContaining(authKeyFastPathLog) + + if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_3.tmp", server.Address(), authKeyOne, false); err != nil || exitCode != 0 { + t.Fatalf("Expected first key connection (after max eviction) to succeed via fallback, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_max_3.tmp") + waitForServerLogs() + afterOldKeyCount := server.CountLogLinesContaining(authKeyFastPathLog) + if afterOldKeyCount != initialFastPathCount { + t.Fatalf("Expected evicted old key to avoid fast-path hit: before=%d after=%d", + initialFastPathCount, afterOldKeyCount) + } + + if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_max_4.tmp", server.Address(), authKeyOne, false); err != nil || exitCode != 0 { + t.Fatalf("Expected re-registered first key to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_max_4.tmp") + waitForServerLogs() + afterNewKeyCount := server.CountLogLinesContaining(authKeyFastPathLog) + if afterNewKeyCount <= afterOldKeyCount { + t.Fatalf("Expected re-registered key to use fast-path: old-count=%d new-count=%d", afterOldKeyCount, afterNewKeyCount) + } +} + +func testNoAuthKeyFlagDisablesFeature(t *testing.T) { + authKeyPath := createAuthKeyPair(t, "authkey-noauth") + server := startAuthKeyServer(t, "") + defer server.Stop() + + if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_noauth_1.tmp", server.Address(), authKeyPath, true); err != nil || exitCode != 0 { + t.Fatalf("Expected first --no-auth-key connection to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_noauth_1.tmp") + if exitCode, err := runDCatWithAuthKey(server.Context(), t, "authkey_noauth_2.tmp", server.Address(), authKeyPath, true); err != nil || exitCode != 0 { + t.Fatalf("Expected second --no-auth-key connection to succeed, exit=%d err=%v", exitCode, err) + } + assertDCatSuccessfulOutput(t, "authkey_noauth_2.tmp") + + waitForServerLogs() + if got := server.CountLogLinesContaining(authKeyFastPathLog); got != 0 { + t.Fatalf("Expected --no-auth-key to prevent fast-path registration, fast-path count=%d", got) + } +} + +type authKeyServer struct { + ctx context.Context + cancel context.CancelFunc + addr string + logs *authKeyServerLogs +} + +func (s *authKeyServer) Stop() { + s.cancel() +} + +func (s *authKeyServer) Context() context.Context { + return s.ctx +} + +func (s *authKeyServer) Address() string { + return s.addr +} + +func (s *authKeyServer) CountLogLinesContaining(substring string) int { + return s.logs.countContaining(substring) +} + +type authKeyServerLogs struct { + mu sync.Mutex + lines []string +} + +func newAuthKeyServerLogs() *authKeyServerLogs { + return &authKeyServerLogs{ + lines: make([]string, 0, 128), + } +} + +func (l *authKeyServerLogs) append(line string) { + l.mu.Lock() + defer l.mu.Unlock() + l.lines = append(l.lines, line) +} + +func (l *authKeyServerLogs) countContaining(substring string) int { + l.mu.Lock() + defer l.mu.Unlock() + + count := 0 + for _, line := range l.lines { + if strings.Contains(line, substring) { + count++ + } + } + return count +} + +func startAuthKeyServer(t *testing.T, cfgFile string) *authKeyServer { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + port := getUniquePortNumber() + args := []string{ + "--cfg", "none", + "--logger", "stdout", + "--logLevel", "info", + "--bindAddress", "localhost", + "--port", fmt.Sprintf("%d", port), + } + if cfgFile != "" { + args = append(args, "--cfg", cfgFile) + } + + stdoutCh, stderrCh, cmdErrCh, err := startCommand(ctx, t, "", "../dserver", args...) + if err != nil { + cancel() + t.Fatalf("Unable to start dserver: %v", err) + } + + logs := newAuthKeyServerLogs() + go func() { + for { + select { + case line, ok := <-stdoutCh: + if ok { + logs.append(line) + } + case line, ok := <-stderrCh: + if ok { + logs.append(line) + } + case err := <-cmdErrCh: + if err != nil { + logs.append(err.Error()) + } + return + case <-ctx.Done(): + return + } + } + }() + + time.Sleep(500 * time.Millisecond) + return &authKeyServer{ + ctx: ctx, + cancel: cancel, + addr: fmt.Sprintf("localhost:%d", port), + logs: logs, + } +} + +func runDCatWithAuthKey(ctx context.Context, t *testing.T, outFile, + serverAddress, authKeyPath string, noAuthKey bool) (int, error) { + t.Helper() + + args := []string{ + "--plain", + "--cfg", "none", + "--servers", serverAddress, + "--files", "dcat1a.txt", + "--trustAllHosts", + "--noColor", + "--auth-key-path", authKeyPath, + } + if noAuthKey { + args = append(args, "--no-auth-key") + } + + return runCommand(ctx, t, outFile, "../dcat", args...) +} + +func assertDCatSuccessfulOutput(t *testing.T, outFile string) { + t.Helper() + + outBytes, err := os.ReadFile(outFile) + if err != nil { + t.Fatalf("Unable to read dcat output file %s: %v", outFile, err) + } + + output := string(outBytes) + if strings.Contains(output, "SSH handshake failed") { + t.Fatalf("Expected successful SSH connection, got handshake failure in %s:\n%s", outFile, output) + } + if !strings.Contains(output, dcatExpectedFirstOutput) { + t.Fatalf("Expected dcat output to contain %q in %s, got:\n%s", dcatExpectedFirstOutput, outFile, output) + } +} + +func writeAuthKeyServerConfig(t *testing.T, ttlSeconds, maxPerUser int) string { + t.Helper() + + cfgPath := filepath.Join(t.TempDir(), "authkey_server_config.json") + cfgContent := fmt.Sprintf( + `{"Server":{"AuthKeyEnabled":true,"AuthKeyTTLSeconds":%d,"AuthKeyMaxPerUser":%d}}`, + ttlSeconds, maxPerUser, + ) + if err := os.WriteFile(cfgPath, []byte(cfgContent), 0600); err != nil { + t.Fatalf("Unable to write auth-key server config: %v", err) + } + return cfgPath +} + +func createAuthKeyPair(t *testing.T, keyName string) string { + t.Helper() + + keyPath := filepath.Join(t.TempDir(), keyName) + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Unable to generate private key: %v", err) + } + + privateKeyBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + if err := os.WriteFile(keyPath, privateKeyBytes, 0600); err != nil { + t.Fatalf("Unable to write private key: %v", err) + } + + publicKey, err := gossh.NewPublicKey(&privateKey.PublicKey) + if err != nil { + t.Fatalf("Unable to generate public key: %v", err) + } + if err := os.WriteFile(keyPath+".pub", gossh.MarshalAuthorizedKey(publicKey), 0600); err != nil { + t.Fatalf("Unable to write public key: %v", err) + } + + return keyPath +} + +func waitForServerLogs() { + time.Sleep(300 * time.Millisecond) +} diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index fbeb1bc..649fe30 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -224,18 +224,19 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc } }() - // Send all commands to client. + if c.authKeyDisabled { + dlog.Client.Debug(c.server, "Skipping AUTHKEY registration because auth-key is disabled") + } else { + c.sendAuthKeyRegistrationCommand() + } + + // Send all requested commands to the server. for _, command := range c.commands { dlog.Client.Debug(command) if err := c.handler.SendMessage(command); err != nil { dlog.Client.Debug(err) } } - if c.authKeyDisabled { - dlog.Client.Debug(c.server, "Skipping AUTHKEY registration because auth-key is disabled") - } else { - c.sendAuthKeyRegistrationCommand() - } if !c.throttlingDone { dlog.Client.Debug(c.server, "Unthrottling connection (2)", diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 53ab4e3..5d5a78c 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -79,14 +79,17 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon argc int, args []string, commandName string) { dlog.Server.Debug(h.user, "Handling user command", argc, args) + shutdownOnCompletion := shouldShutdownOnCommandCompletion(commandName) h.incrementActiveCommands() commandFinished := func() { activeCommands := h.decrementActiveCommands() pendingFiles := atomic.LoadInt32(&h.pendingFiles) dlog.Server.Debug(h.user, "Command finished", "activeCommands", activeCommands, "pendingFiles", pendingFiles) - // Only shutdown if no active commands AND no pending files - if activeCommands == 0 && pendingFiles == 0 { + // Only shutdown if no active commands AND no pending files. + // AUTHKEY is a session-side effect command and should not terminate the shell + // because user commands may still follow in the same session. + if shutdownOnCompletion && activeCommands == 0 && pendingFiles == 0 { h.shutdown() } } @@ -102,6 +105,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon handler(ctx, ltx, argc, args, commandFinished) } +func shouldShutdownOnCommandCompletion(commandName string) bool { + return !strings.EqualFold(commandName, "AUTHKEY") +} + func (h *ServerHandler) newCommandRegistry() map[string]commandHandler { return map[string]commandHandler{ "grep": h.makeReadCommandHandler(omode.GrepClient, 1), diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index a414ade..7ac4d0c 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -12,8 +12,8 @@ import ( ) var ( - privateKeyAuthMethod = ssh.PrivateKey - agentAuthMethod = ssh.AgentWithKeyIndex + privateKeySigner = ssh.PrivateKeySigner + agentSigners = ssh.AgentSignersWithKeyIndex ) // InitSSHAuthMethods initialises all known SSH auth methods on the client side. @@ -31,21 +31,6 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex) } -func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { - var sshAuthMethods []gossh.AuthMethod - privateKeyPath := "./id_rsa" - - GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096) - authMethod, err := ssh.PrivateKey(privateKeyPath) - if err != nil { - dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) - } - - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath) - return sshAuthMethods -} - func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { @@ -62,7 +47,10 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsFile) if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { - return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback + if privateKeyPath == "" { + privateKeyPath = "./id_rsa" + } + GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096) } sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex) @@ -74,7 +62,15 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, } func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod { - var sshAuthMethods []gossh.AuthMethod + signers := collectKnownHostsSigners(privateKeyPath, agentKeyIndex) + if len(signers) == 0 { + return nil + } + return []gossh.AuthMethod{gossh.PublicKeys(signers...)} +} + +func collectKnownHostsSigners(privateKeyPath string, agentKeyIndex int) []gossh.Signer { + var signers []gossh.Signer home := os.Getenv("HOME") defaultPrivateKeyPaths := []string{ @@ -83,13 +79,32 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go home + "/.ssh/id_ecdsa", home + "/.ssh/id_ed25519", } + if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { + defaultPrivateKeyPaths = append([]string{"./id_rsa"}, defaultPrivateKeyPaths...) + } if privateKeyPath == "" { privateKeyPath = defaultPrivateKeyPaths[0] } addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1) - addPrivateKeyAuthMethod := func(path string) { + addedPublicKeys := make(map[string]bool, len(defaultPrivateKeyPaths)+1) + addSigner := func(source string, signer gossh.Signer) { + if signer == nil { + return + } + + pubKey := string(signer.PublicKey().Marshal()) + if addedPublicKeys[pubKey] { + dlog.Client.Debug("initKnownHostsAuthMethods", "Skipping duplicate signer", source) + return + } + + addedPublicKeys[pubKey] = true + signers = append(signers, signer) + dlog.Client.Debug("initKnownHostsAuthMethods", "Added signer", source) + } + addPrivateKeySigner := func(path string) { if path == "" { return } @@ -97,33 +112,33 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go return } - authMethod, err := privateKeyAuthMethod(path) + signer, err := privateKeySigner(path) if err != nil { - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load private key signer", path, err) return } - sshAuthMethods = append(sshAuthMethods, authMethod) addedPrivateKeyPaths[path] = true - dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path) + addSigner(path, signer) } // First, the explicit auth key path (or default ~/.ssh/id_rsa). - addPrivateKeyAuthMethod(privateKeyPath) + addPrivateKeySigner(privateKeyPath) // Second, SSH agent (YubiKey-backed keys are typically exposed here). - authMethod, err := agentAuthMethod(agentKeyIndex) + loadedAgentSigners, err := agentSigners(agentKeyIndex) if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method") + for i, signer := range loadedAgentSigners { + addSigner(fmt.Sprintf("agent:%d:%d", agentKeyIndex, i), signer) + } } else { - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load SSH agent signers", err) } // Third, additional default private key paths. for _, path := range defaultPrivateKeyPaths { - addPrivateKeyAuthMethod(path) + addPrivateKeySigner(path) } - return sshAuthMethods + return signers } diff --git a/internal/ssh/client/authmethods_test.go b/internal/ssh/client/authmethods_test.go index 04751f5..e1e92b0 100644 --- a/internal/ssh/client/authmethods_test.go +++ b/internal/ssh/client/authmethods_test.go @@ -2,6 +2,7 @@ package client import ( "fmt" + "io" "reflect" "testing" @@ -10,42 +11,84 @@ import ( gossh "golang.org/x/crypto/ssh" ) +type mockPublicKey struct { + id string +} + +func (k *mockPublicKey) Type() string { + return "ssh-rsa" +} + +func (k *mockPublicKey) Marshal() []byte { + return []byte(k.id) +} + +func (k *mockPublicKey) Verify(_ []byte, _ *gossh.Signature) error { + return nil +} + +type mockSigner struct { + key gossh.PublicKey +} + +func newMockSigner(id string) gossh.Signer { + return &mockSigner{key: &mockPublicKey{id: id}} +} + +func (s *mockSigner) PublicKey() gossh.PublicKey { + return s.key +} + +func (s *mockSigner) Sign(_ io.Reader, _ []byte) (*gossh.Signature, error) { + return &gossh.Signature{ + Format: "ssh-rsa", + Blob: []byte("sig"), + }, nil +} + func TestCollectKnownHostsAuthMethodsOrder(t *testing.T) { homeDir := "/tmp/dtail-auth-order" t.Setenv("HOME", homeDir) - originalPrivateKeyAuthMethod := privateKeyAuthMethod - originalAgentAuthMethod := agentAuthMethod + originalPrivateKeySigner := privateKeySigner + originalAgentSigners := agentSigners originalLogger := dlog.Client dlog.Client = &dlog.DLog{} t.Cleanup(func() { - privateKeyAuthMethod = originalPrivateKeyAuthMethod - agentAuthMethod = originalAgentAuthMethod + privateKeySigner = originalPrivateKeySigner + agentSigners = originalAgentSigners dlog.Client = originalLogger }) var callOrder []string - successfulPrivateKeys := map[string]bool{ - "/custom/id_fast": true, - homeDir + "/.ssh/id_rsa": true, - homeDir + "/.ssh/id_dsa": true, + successfulPrivateKeys := map[string]gossh.Signer{ + "/custom/id_fast": newMockSigner("custom"), + homeDir + "/.ssh/id_rsa": newMockSigner("default-rsa"), + homeDir + "/.ssh/id_dsa": newMockSigner("default-dsa"), } - privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) { + privateKeySigner = func(path string) (gossh.Signer, error) { callOrder = append(callOrder, "private:"+path) - if !successfulPrivateKeys[path] { + signer, found := successfulPrivateKeys[path] + if !found { return nil, fmt.Errorf("missing private key: %s", path) } - return gossh.Password(path), nil + return signer, nil } - agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) { + agentSigners = func(keyIndex int) ([]gossh.Signer, error) { callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex)) - return gossh.Password("agent"), nil + return []gossh.Signer{newMockSigner("agent")}, nil } methods := collectKnownHostsAuthMethods("/custom/id_fast", 7) - if len(methods) != 4 { - t.Fatalf("Expected 4 auth methods, got %d", len(methods)) + if len(methods) != 1 { + t.Fatalf("Expected 1 auth method, got %d", len(methods)) + } + + callOrder = nil + signers := collectKnownHostsSigners("/custom/id_fast", 7) + if len(signers) != 4 { + t.Fatalf("Expected 4 signers, got %d", len(signers)) } expectedOrder := []string{ @@ -65,32 +108,39 @@ func TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath(t *testing.T) { homeDir := "/tmp/dtail-auth-dedupe" t.Setenv("HOME", homeDir) - originalPrivateKeyAuthMethod := privateKeyAuthMethod - originalAgentAuthMethod := agentAuthMethod + originalPrivateKeySigner := privateKeySigner + originalAgentSigners := agentSigners originalLogger := dlog.Client dlog.Client = &dlog.DLog{} t.Cleanup(func() { - privateKeyAuthMethod = originalPrivateKeyAuthMethod - agentAuthMethod = originalAgentAuthMethod + privateKeySigner = originalPrivateKeySigner + agentSigners = originalAgentSigners dlog.Client = originalLogger }) + sharedSigner := newMockSigner("shared") var callOrder []string - privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) { + privateKeySigner = func(path string) (gossh.Signer, error) { callOrder = append(callOrder, "private:"+path) if path == homeDir+"/.ssh/id_rsa" { - return gossh.Password(path), nil + return sharedSigner, nil } return nil, fmt.Errorf("missing private key: %s", path) } - agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) { + agentSigners = func(keyIndex int) ([]gossh.Signer, error) { callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex)) - return gossh.Password("agent"), nil + return []gossh.Signer{sharedSigner}, nil } methods := collectKnownHostsAuthMethods(homeDir+"/.ssh/id_rsa", 2) - if len(methods) != 2 { - t.Fatalf("Expected 2 auth methods, got %d", len(methods)) + if len(methods) != 1 { + t.Fatalf("Expected 1 auth method, got %d", len(methods)) + } + + callOrder = nil + signers := collectKnownHostsSigners(homeDir+"/.ssh/id_rsa", 2) + if len(signers) != 1 { + t.Fatalf("Expected duplicate keys to collapse to 1 signer, got %d", len(signers)) } expectedOrder := []string{ diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 7088e89..a191fd5 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -48,9 +48,9 @@ func Agent() (gossh.AuthMethod, error) { return AgentWithKeyIndex(-1) } -// AgentWithKeyIndex used for SSH auth with a specific key index from the agent. +// AgentSignersWithKeyIndex returns SSH agent signers. // If keyIndex is -1, all keys are used. Otherwise, only the specified key is used. -func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) { +func AgentSignersWithKeyIndex(keyIndex int) ([]gossh.Signer, error) { // Use context-aware dialing for SSH agent connection (local Unix socket). // 2-second timeout is reasonable for local socket connections. dialer := &net.Dialer{ @@ -73,28 +73,33 @@ func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) { dlog.Common.Debug("Public key", i, key) } + signers, err := agentClient.Signers() + if err != nil { + return nil, fmt.Errorf("failed to load SSH agent signers: %w", err) + } + // If no specific key index requested, use all keys (backwards compatible default) if keyIndex < 0 { - return gossh.PublicKeysCallback(agentClient.Signers), nil + return signers, nil } // Use only the specified key index (0-based) - if keyIndex >= len(keys) { - return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys)) + if keyIndex >= len(signers) { + return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers)) } dlog.Common.Debug("Using SSH agent key at index", keyIndex) - return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) { - signers, err := agentClient.Signers() - if err != nil { - return nil, err - } - if keyIndex >= len(signers) { - return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers)) - } - // Return only the specified signer - return []gossh.Signer{signers[keyIndex]}, nil - }), nil + return []gossh.Signer{signers[keyIndex]}, nil +} + +// AgentWithKeyIndex used for SSH auth with a specific key index from the agent. +// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used. +func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) { + signers, err := AgentSignersWithKeyIndex(keyIndex) + if err != nil { + return nil, err + } + return gossh.PublicKeys(signers...), nil } // EnterKeyPhrase is required to read phrase protected private keys. @@ -108,8 +113,8 @@ func EnterKeyPhrase(keyFile string) []byte { return phrase } -// KeyFile returns the key as a SSH auth method. -func KeyFile(keyFile string) (gossh.AuthMethod, error) { +// PrivateKeySigner returns an SSH signer from the provided private key file. +func PrivateKeySigner(keyFile string) (gossh.Signer, error) { buffer, err := os.ReadFile(keyFile) if err != nil { return nil, err @@ -118,6 +123,15 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) { if err != nil { return nil, err } + return key, nil +} + +// KeyFile returns the key as a SSH auth method. +func KeyFile(keyFile string) (gossh.AuthMethod, error) { + key, err := PrivateKeySigner(keyFile) + if err != nil { + return nil, err + } // Key phrase support disabled as password will be printed to stdout! /* -- cgit v1.2.3