summaryrefslogtreecommitdiff
path: root/internal/ssh/client/authmethods.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ssh/client/authmethods.go')
-rw-r--r--internal/ssh/client/authmethods.go77
1 files changed, 46 insertions, 31 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index a414ade..7ac4d0c 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -12,8 +12,8 @@ import (
)
var (
- privateKeyAuthMethod = ssh.PrivateKey
- agentAuthMethod = ssh.AgentWithKeyIndex
+ privateKeySigner = ssh.PrivateKeySigner
+ agentSigners = ssh.AgentSignersWithKeyIndex
)
// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
@@ -31,21 +31,6 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex)
}
-func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod {
- var sshAuthMethods []gossh.AuthMethod
- privateKeyPath := "./id_rsa"
-
- GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096)
- authMethod, err := ssh.PrivateKey(privateKeyPath)
- if err != nil {
- dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
- }
-
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath)
- return sshAuthMethods
-}
-
func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) {
@@ -62,7 +47,10 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsFile)
if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
- return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback
+ if privateKeyPath == "" {
+ privateKeyPath = "./id_rsa"
+ }
+ GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096)
}
sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex)
@@ -74,7 +62,15 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
}
func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod {
- var sshAuthMethods []gossh.AuthMethod
+ signers := collectKnownHostsSigners(privateKeyPath, agentKeyIndex)
+ if len(signers) == 0 {
+ return nil
+ }
+ return []gossh.AuthMethod{gossh.PublicKeys(signers...)}
+}
+
+func collectKnownHostsSigners(privateKeyPath string, agentKeyIndex int) []gossh.Signer {
+ var signers []gossh.Signer
home := os.Getenv("HOME")
defaultPrivateKeyPaths := []string{
@@ -83,13 +79,32 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go
home + "/.ssh/id_ecdsa",
home + "/.ssh/id_ed25519",
}
+ if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
+ defaultPrivateKeyPaths = append([]string{"./id_rsa"}, defaultPrivateKeyPaths...)
+ }
if privateKeyPath == "" {
privateKeyPath = defaultPrivateKeyPaths[0]
}
addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1)
- addPrivateKeyAuthMethod := func(path string) {
+ addedPublicKeys := make(map[string]bool, len(defaultPrivateKeyPaths)+1)
+ addSigner := func(source string, signer gossh.Signer) {
+ if signer == nil {
+ return
+ }
+
+ pubKey := string(signer.PublicKey().Marshal())
+ if addedPublicKeys[pubKey] {
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Skipping duplicate signer", source)
+ return
+ }
+
+ addedPublicKeys[pubKey] = true
+ signers = append(signers, signer)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added signer", source)
+ }
+ addPrivateKeySigner := func(path string) {
if path == "" {
return
}
@@ -97,33 +112,33 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go
return
}
- authMethod, err := privateKeyAuthMethod(path)
+ signer, err := privateKeySigner(path)
if err != nil {
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load private key signer", path, err)
return
}
- sshAuthMethods = append(sshAuthMethods, authMethod)
addedPrivateKeyPaths[path] = true
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path)
+ addSigner(path, signer)
}
// First, the explicit auth key path (or default ~/.ssh/id_rsa).
- addPrivateKeyAuthMethod(privateKeyPath)
+ addPrivateKeySigner(privateKeyPath)
// Second, SSH agent (YubiKey-backed keys are typically exposed here).
- authMethod, err := agentAuthMethod(agentKeyIndex)
+ loadedAgentSigners, err := agentSigners(agentKeyIndex)
if err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method")
+ for i, signer := range loadedAgentSigners {
+ addSigner(fmt.Sprintf("agent:%d:%d", agentKeyIndex, i), signer)
+ }
} else {
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load SSH agent signers", err)
}
// Third, additional default private key paths.
for _, path := range defaultPrivateKeyPaths {
- addPrivateKeyAuthMethod(path)
+ addPrivateKeySigner(path)
}
- return sshAuthMethods
+ return signers
}