diff options
Diffstat (limited to 'internal/ssh/client/authmethods.go')
| -rw-r--r-- | internal/ssh/client/authmethods.go | 77 |
1 files changed, 46 insertions, 31 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index a414ade..7ac4d0c 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -12,8 +12,8 @@ import ( ) var ( - privateKeyAuthMethod = ssh.PrivateKey - agentAuthMethod = ssh.AgentWithKeyIndex + privateKeySigner = ssh.PrivateKeySigner + agentSigners = ssh.AgentSignersWithKeyIndex ) // InitSSHAuthMethods initialises all known SSH auth methods on the client side. @@ -31,21 +31,6 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex) } -func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { - var sshAuthMethods []gossh.AuthMethod - privateKeyPath := "./id_rsa" - - GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096) - authMethod, err := ssh.PrivateKey(privateKeyPath) - if err != nil { - dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) - } - - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath) - return sshAuthMethods -} - func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { @@ -62,7 +47,10 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsFile) if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { - return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback + if privateKeyPath == "" { + privateKeyPath = "./id_rsa" + } + GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096) } sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex) @@ -74,7 +62,15 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, } func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod { - var sshAuthMethods []gossh.AuthMethod + signers := collectKnownHostsSigners(privateKeyPath, agentKeyIndex) + if len(signers) == 0 { + return nil + } + return []gossh.AuthMethod{gossh.PublicKeys(signers...)} +} + +func collectKnownHostsSigners(privateKeyPath string, agentKeyIndex int) []gossh.Signer { + var signers []gossh.Signer home := os.Getenv("HOME") defaultPrivateKeyPaths := []string{ @@ -83,13 +79,32 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go home + "/.ssh/id_ecdsa", home + "/.ssh/id_ed25519", } + if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { + defaultPrivateKeyPaths = append([]string{"./id_rsa"}, defaultPrivateKeyPaths...) + } if privateKeyPath == "" { privateKeyPath = defaultPrivateKeyPaths[0] } addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1) - addPrivateKeyAuthMethod := func(path string) { + addedPublicKeys := make(map[string]bool, len(defaultPrivateKeyPaths)+1) + addSigner := func(source string, signer gossh.Signer) { + if signer == nil { + return + } + + pubKey := string(signer.PublicKey().Marshal()) + if addedPublicKeys[pubKey] { + dlog.Client.Debug("initKnownHostsAuthMethods", "Skipping duplicate signer", source) + return + } + + addedPublicKeys[pubKey] = true + signers = append(signers, signer) + dlog.Client.Debug("initKnownHostsAuthMethods", "Added signer", source) + } + addPrivateKeySigner := func(path string) { if path == "" { return } @@ -97,33 +112,33 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go return } - authMethod, err := privateKeyAuthMethod(path) + signer, err := privateKeySigner(path) if err != nil { - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load private key signer", path, err) return } - sshAuthMethods = append(sshAuthMethods, authMethod) addedPrivateKeyPaths[path] = true - dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path) + addSigner(path, signer) } // First, the explicit auth key path (or default ~/.ssh/id_rsa). - addPrivateKeyAuthMethod(privateKeyPath) + addPrivateKeySigner(privateKeyPath) // Second, SSH agent (YubiKey-backed keys are typically exposed here). - authMethod, err := agentAuthMethod(agentKeyIndex) + loadedAgentSigners, err := agentSigners(agentKeyIndex) if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method") + for i, signer := range loadedAgentSigners { + addSigner(fmt.Sprintf("agent:%d:%d", agentKeyIndex, i), signer) + } } else { - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load SSH agent signers", err) } // Third, additional default private key paths. for _, path := range defaultPrivateKeyPaths { - addPrivateKeyAuthMethod(path) + addPrivateKeySigner(path) } - return sshAuthMethods + return signers } |
