summaryrefslogtreecommitdiff
path: root/internal/ssh
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-03 10:15:02 +0200
committerPaul Buetow <paul@buetow.org>2026-03-03 10:15:02 +0200
commitf17ffe1bae2f176e4dda90ff4dd2cb267332a7b4 (patch)
tree2feecb488d4f90efb25b0dd7cbb2c5718926f3e2 /internal/ssh
parent7d3685a5ed4bfac85673793f8ae6d9c5a6cff962 (diff)
feat(ssh-client): collect auth methods in fallback order
Diffstat (limited to 'internal/ssh')
-rw-r--r--internal/ssh/client/authmethods.go93
-rw-r--r--internal/ssh/client/authmethods_test.go106
2 files changed, 159 insertions, 40 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index 1a4cb3f..a414ade 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -11,7 +11,10 @@ import (
gossh "golang.org/x/crypto/ssh"
)
-const addedPathStr string = "Added path to list of auth methods, not adding further methods"
+var (
+ privateKeyAuthMethod = ssh.PrivateKey
+ agentAuthMethod = ssh.AgentWithKeyIndex
+)
// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
@@ -39,14 +42,13 @@ func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod {
}
sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", addedPathStr, privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath)
return sshAuthMethods
}
func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) {
- var sshAuthMethods []gossh.AuthMethod
knownHostsFile := fmt.Sprintf("%s/.ssh/known_hosts", os.Getenv("HOME"))
if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
// In case of integration test, override known hosts file path.
@@ -63,54 +65,65 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback
}
- // Try to read custom private key path.
- if privateKeyPath != "" {
- authMethod, err := ssh.PrivateKey(privateKeyPath)
- if err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", addedPathStr, privateKeyPath)
- return sshAuthMethods, knownHostsCallback
- }
- dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
+ sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex)
+ if len(sshAuthMethods) == 0 {
+ dlog.Client.FatalPanic("Unable to find private SSH key information")
}
- // Second, try SSH Agent
- authMethod, err := ssh.AgentWithKeyIndex(agentKeyIndex)
- if err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+
- "to list of auth methods, not adding further methods")
- return sshAuthMethods, knownHostsCallback
+ return sshAuthMethods, knownHostsCallback
+}
+
+func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod {
+ var sshAuthMethods []gossh.AuthMethod
+
+ home := os.Getenv("HOME")
+ defaultPrivateKeyPaths := []string{
+ home + "/.ssh/id_rsa",
+ home + "/.ssh/id_dsa",
+ home + "/.ssh/id_ecdsa",
+ home + "/.ssh/id_ed25519",
}
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
- // Third, try Linux/UNIX default key paths
- privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa"
- authMethod, err = ssh.PrivateKey(privateKeyPath)
- if err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath)
- return sshAuthMethods, knownHostsCallback
+ if privateKeyPath == "" {
+ privateKeyPath = defaultPrivateKeyPaths[0]
}
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
- privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa"
- authMethod, err = ssh.PrivateKey(privateKeyPath)
- if err == nil {
+ addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1)
+ addPrivateKeyAuthMethod := func(path string) {
+ if path == "" {
+ return
+ }
+ if addedPrivateKeyPaths[path] {
+ return
+ }
+
+ authMethod, err := privateKeyAuthMethod(path)
+ if err != nil {
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err)
+ return
+ }
+
sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath)
- return sshAuthMethods, knownHostsCallback
+ addedPrivateKeyPaths[path] = true
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path)
}
- privateKeyPath = os.Getenv("HOME") + "/.ssh/id_ecdsa"
- authMethod, err = ssh.PrivateKey(privateKeyPath)
+ // First, the explicit auth key path (or default ~/.ssh/id_rsa).
+ addPrivateKeyAuthMethod(privateKeyPath)
+
+ // Second, SSH agent (YubiKey-backed keys are typically exposed here).
+ authMethod, err := agentAuthMethod(agentKeyIndex)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath)
- return sshAuthMethods, knownHostsCallback
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method")
+ } else {
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
}
- dlog.Client.FatalPanic("Unable to find private SSH key information", privateKeyPath, err)
- // Never reach this point.
- return sshAuthMethods, knownHostsCallback
+ // Third, additional default private key paths.
+ for _, path := range defaultPrivateKeyPaths {
+ addPrivateKeyAuthMethod(path)
+ }
+
+ return sshAuthMethods
}
diff --git a/internal/ssh/client/authmethods_test.go b/internal/ssh/client/authmethods_test.go
new file mode 100644
index 0000000..04751f5
--- /dev/null
+++ b/internal/ssh/client/authmethods_test.go
@@ -0,0 +1,106 @@
+package client
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/mimecast/dtail/internal/io/dlog"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+func TestCollectKnownHostsAuthMethodsOrder(t *testing.T) {
+ homeDir := "/tmp/dtail-auth-order"
+ t.Setenv("HOME", homeDir)
+
+ originalPrivateKeyAuthMethod := privateKeyAuthMethod
+ originalAgentAuthMethod := agentAuthMethod
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ privateKeyAuthMethod = originalPrivateKeyAuthMethod
+ agentAuthMethod = originalAgentAuthMethod
+ dlog.Client = originalLogger
+ })
+
+ var callOrder []string
+ successfulPrivateKeys := map[string]bool{
+ "/custom/id_fast": true,
+ homeDir + "/.ssh/id_rsa": true,
+ homeDir + "/.ssh/id_dsa": true,
+ }
+
+ privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, "private:"+path)
+ if !successfulPrivateKeys[path] {
+ return nil, fmt.Errorf("missing private key: %s", path)
+ }
+ return gossh.Password(path), nil
+ }
+ agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
+ return gossh.Password("agent"), nil
+ }
+
+ methods := collectKnownHostsAuthMethods("/custom/id_fast", 7)
+ if len(methods) != 4 {
+ t.Fatalf("Expected 4 auth methods, got %d", len(methods))
+ }
+
+ expectedOrder := []string{
+ "private:/custom/id_fast",
+ "agent:7",
+ "private:/tmp/dtail-auth-order/.ssh/id_rsa",
+ "private:/tmp/dtail-auth-order/.ssh/id_dsa",
+ "private:/tmp/dtail-auth-order/.ssh/id_ecdsa",
+ "private:/tmp/dtail-auth-order/.ssh/id_ed25519",
+ }
+ if !reflect.DeepEqual(callOrder, expectedOrder) {
+ t.Fatalf("Unexpected auth method call order.\nexpected: %v\ngot: %v", expectedOrder, callOrder)
+ }
+}
+
+func TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath(t *testing.T) {
+ homeDir := "/tmp/dtail-auth-dedupe"
+ t.Setenv("HOME", homeDir)
+
+ originalPrivateKeyAuthMethod := privateKeyAuthMethod
+ originalAgentAuthMethod := agentAuthMethod
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ privateKeyAuthMethod = originalPrivateKeyAuthMethod
+ agentAuthMethod = originalAgentAuthMethod
+ dlog.Client = originalLogger
+ })
+
+ var callOrder []string
+ privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, "private:"+path)
+ if path == homeDir+"/.ssh/id_rsa" {
+ return gossh.Password(path), nil
+ }
+ return nil, fmt.Errorf("missing private key: %s", path)
+ }
+ agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
+ return gossh.Password("agent"), nil
+ }
+
+ methods := collectKnownHostsAuthMethods(homeDir+"/.ssh/id_rsa", 2)
+ if len(methods) != 2 {
+ t.Fatalf("Expected 2 auth methods, got %d", len(methods))
+ }
+
+ expectedOrder := []string{
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_rsa",
+ "agent:2",
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_dsa",
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_ecdsa",
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_ed25519",
+ }
+ if !reflect.DeepEqual(callOrder, expectedOrder) {
+ t.Fatalf("Unexpected auth method call order.\nexpected: %v\ngot: %v", expectedOrder, callOrder)
+ }
+}