diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-03 10:15:02 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-03 10:15:02 +0200 |
| commit | f17ffe1bae2f176e4dda90ff4dd2cb267332a7b4 (patch) | |
| tree | 2feecb488d4f90efb25b0dd7cbb2c5718926f3e2 /internal/ssh | |
| parent | 7d3685a5ed4bfac85673793f8ae6d9c5a6cff962 (diff) | |
feat(ssh-client): collect auth methods in fallback order
Diffstat (limited to 'internal/ssh')
| -rw-r--r-- | internal/ssh/client/authmethods.go | 93 | ||||
| -rw-r--r-- | internal/ssh/client/authmethods_test.go | 106 |
2 files changed, 159 insertions, 40 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 1a4cb3f..a414ade 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -11,7 +11,10 @@ import ( gossh "golang.org/x/crypto/ssh" ) -const addedPathStr string = "Added path to list of auth methods, not adding further methods" +var ( + privateKeyAuthMethod = ssh.PrivateKey + agentAuthMethod = ssh.AgentWithKeyIndex +) // InitSSHAuthMethods initialises all known SSH auth methods on the client side. func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, @@ -39,14 +42,13 @@ func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { } sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", addedPathStr, privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath) return sshAuthMethods } func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { - var sshAuthMethods []gossh.AuthMethod knownHostsFile := fmt.Sprintf("%s/.ssh/known_hosts", os.Getenv("HOME")) if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { // In case of integration test, override known hosts file path. @@ -63,54 +65,65 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback } - // Try to read custom private key path. - if privateKeyPath != "" { - authMethod, err := ssh.PrivateKey(privateKeyPath) - if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback - } - dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) + sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex) + if len(sshAuthMethods) == 0 { + dlog.Client.FatalPanic("Unable to find private SSH key information") } - // Second, try SSH Agent - authMethod, err := ssh.AgentWithKeyIndex(agentKeyIndex) - if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+ - "to list of auth methods, not adding further methods") - return sshAuthMethods, knownHostsCallback + return sshAuthMethods, knownHostsCallback +} + +func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod { + var sshAuthMethods []gossh.AuthMethod + + home := os.Getenv("HOME") + defaultPrivateKeyPaths := []string{ + home + "/.ssh/id_rsa", + home + "/.ssh/id_dsa", + home + "/.ssh/id_ecdsa", + home + "/.ssh/id_ed25519", } - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) - // Third, try Linux/UNIX default key paths - privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa" - authMethod, err = ssh.PrivateKey(privateKeyPath) - if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback + if privateKeyPath == "" { + privateKeyPath = defaultPrivateKeyPaths[0] } - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) - privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa" - authMethod, err = ssh.PrivateKey(privateKeyPath) - if err == nil { + addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1) + addPrivateKeyAuthMethod := func(path string) { + if path == "" { + return + } + if addedPrivateKeyPaths[path] { + return + } + + authMethod, err := privateKeyAuthMethod(path) + if err != nil { + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err) + return + } + sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback + addedPrivateKeyPaths[path] = true + dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path) } - privateKeyPath = os.Getenv("HOME") + "/.ssh/id_ecdsa" - authMethod, err = ssh.PrivateKey(privateKeyPath) + // First, the explicit auth key path (or default ~/.ssh/id_rsa). + addPrivateKeyAuthMethod(privateKeyPath) + + // Second, SSH agent (YubiKey-backed keys are typically exposed here). + authMethod, err := agentAuthMethod(agentKeyIndex) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback + dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method") + } else { + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) } - dlog.Client.FatalPanic("Unable to find private SSH key information", privateKeyPath, err) - // Never reach this point. - return sshAuthMethods, knownHostsCallback + // Third, additional default private key paths. + for _, path := range defaultPrivateKeyPaths { + addPrivateKeyAuthMethod(path) + } + + return sshAuthMethods } diff --git a/internal/ssh/client/authmethods_test.go b/internal/ssh/client/authmethods_test.go new file mode 100644 index 0000000..04751f5 --- /dev/null +++ b/internal/ssh/client/authmethods_test.go @@ -0,0 +1,106 @@ +package client + +import ( + "fmt" + "reflect" + "testing" + + "github.com/mimecast/dtail/internal/io/dlog" + + gossh "golang.org/x/crypto/ssh" +) + +func TestCollectKnownHostsAuthMethodsOrder(t *testing.T) { + homeDir := "/tmp/dtail-auth-order" + t.Setenv("HOME", homeDir) + + originalPrivateKeyAuthMethod := privateKeyAuthMethod + originalAgentAuthMethod := agentAuthMethod + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + privateKeyAuthMethod = originalPrivateKeyAuthMethod + agentAuthMethod = originalAgentAuthMethod + dlog.Client = originalLogger + }) + + var callOrder []string + successfulPrivateKeys := map[string]bool{ + "/custom/id_fast": true, + homeDir + "/.ssh/id_rsa": true, + homeDir + "/.ssh/id_dsa": true, + } + + privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) { + callOrder = append(callOrder, "private:"+path) + if !successfulPrivateKeys[path] { + return nil, fmt.Errorf("missing private key: %s", path) + } + return gossh.Password(path), nil + } + agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) { + callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex)) + return gossh.Password("agent"), nil + } + + methods := collectKnownHostsAuthMethods("/custom/id_fast", 7) + if len(methods) != 4 { + t.Fatalf("Expected 4 auth methods, got %d", len(methods)) + } + + expectedOrder := []string{ + "private:/custom/id_fast", + "agent:7", + "private:/tmp/dtail-auth-order/.ssh/id_rsa", + "private:/tmp/dtail-auth-order/.ssh/id_dsa", + "private:/tmp/dtail-auth-order/.ssh/id_ecdsa", + "private:/tmp/dtail-auth-order/.ssh/id_ed25519", + } + if !reflect.DeepEqual(callOrder, expectedOrder) { + t.Fatalf("Unexpected auth method call order.\nexpected: %v\ngot: %v", expectedOrder, callOrder) + } +} + +func TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath(t *testing.T) { + homeDir := "/tmp/dtail-auth-dedupe" + t.Setenv("HOME", homeDir) + + originalPrivateKeyAuthMethod := privateKeyAuthMethod + originalAgentAuthMethod := agentAuthMethod + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + privateKeyAuthMethod = originalPrivateKeyAuthMethod + agentAuthMethod = originalAgentAuthMethod + dlog.Client = originalLogger + }) + + var callOrder []string + privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) { + callOrder = append(callOrder, "private:"+path) + if path == homeDir+"/.ssh/id_rsa" { + return gossh.Password(path), nil + } + return nil, fmt.Errorf("missing private key: %s", path) + } + agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) { + callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex)) + return gossh.Password("agent"), nil + } + + methods := collectKnownHostsAuthMethods(homeDir+"/.ssh/id_rsa", 2) + if len(methods) != 2 { + t.Fatalf("Expected 2 auth methods, got %d", len(methods)) + } + + expectedOrder := []string{ + "private:/tmp/dtail-auth-dedupe/.ssh/id_rsa", + "agent:2", + "private:/tmp/dtail-auth-dedupe/.ssh/id_dsa", + "private:/tmp/dtail-auth-dedupe/.ssh/id_ecdsa", + "private:/tmp/dtail-auth-dedupe/.ssh/id_ed25519", + } + if !reflect.DeepEqual(callOrder, expectedOrder) { + t.Fatalf("Unexpected auth method call order.\nexpected: %v\ngot: %v", expectedOrder, callOrder) + } +} |
