summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/connectors/serverconnection.go13
-rw-r--r--internal/server/handlers/serverhandler.go11
-rw-r--r--internal/ssh/client/authmethods.go77
-rw-r--r--internal/ssh/client/authmethods_test.go100
-rw-r--r--internal/ssh/ssh.go50
5 files changed, 169 insertions, 82 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index fbeb1bc..649fe30 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -224,18 +224,19 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
}
}()
- // Send all commands to client.
+ if c.authKeyDisabled {
+ dlog.Client.Debug(c.server, "Skipping AUTHKEY registration because auth-key is disabled")
+ } else {
+ c.sendAuthKeyRegistrationCommand()
+ }
+
+ // Send all requested commands to the server.
for _, command := range c.commands {
dlog.Client.Debug(command)
if err := c.handler.SendMessage(command); err != nil {
dlog.Client.Debug(err)
}
}
- if c.authKeyDisabled {
- dlog.Client.Debug(c.server, "Skipping AUTHKEY registration because auth-key is disabled")
- } else {
- c.sendAuthKeyRegistrationCommand()
- }
if !c.throttlingDone {
dlog.Client.Debug(c.server, "Unthrottling connection (2)",
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 53ab4e3..5d5a78c 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -79,14 +79,17 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon
argc int, args []string, commandName string) {
dlog.Server.Debug(h.user, "Handling user command", argc, args)
+ shutdownOnCompletion := shouldShutdownOnCommandCompletion(commandName)
h.incrementActiveCommands()
commandFinished := func() {
activeCommands := h.decrementActiveCommands()
pendingFiles := atomic.LoadInt32(&h.pendingFiles)
dlog.Server.Debug(h.user, "Command finished", "activeCommands", activeCommands, "pendingFiles", pendingFiles)
- // Only shutdown if no active commands AND no pending files
- if activeCommands == 0 && pendingFiles == 0 {
+ // Only shutdown if no active commands AND no pending files.
+ // AUTHKEY is a session-side effect command and should not terminate the shell
+ // because user commands may still follow in the same session.
+ if shutdownOnCompletion && activeCommands == 0 && pendingFiles == 0 {
h.shutdown()
}
}
@@ -102,6 +105,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon
handler(ctx, ltx, argc, args, commandFinished)
}
+func shouldShutdownOnCommandCompletion(commandName string) bool {
+ return !strings.EqualFold(commandName, "AUTHKEY")
+}
+
func (h *ServerHandler) newCommandRegistry() map[string]commandHandler {
return map[string]commandHandler{
"grep": h.makeReadCommandHandler(omode.GrepClient, 1),
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index a414ade..7ac4d0c 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -12,8 +12,8 @@ import (
)
var (
- privateKeyAuthMethod = ssh.PrivateKey
- agentAuthMethod = ssh.AgentWithKeyIndex
+ privateKeySigner = ssh.PrivateKeySigner
+ agentSigners = ssh.AgentSignersWithKeyIndex
)
// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
@@ -31,21 +31,6 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex)
}
-func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod {
- var sshAuthMethods []gossh.AuthMethod
- privateKeyPath := "./id_rsa"
-
- GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096)
- authMethod, err := ssh.PrivateKey(privateKeyPath)
- if err != nil {
- dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
- }
-
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath)
- return sshAuthMethods
-}
-
func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) {
@@ -62,7 +47,10 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsFile)
if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
- return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback
+ if privateKeyPath == "" {
+ privateKeyPath = "./id_rsa"
+ }
+ GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096)
}
sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex)
@@ -74,7 +62,15 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
}
func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod {
- var sshAuthMethods []gossh.AuthMethod
+ signers := collectKnownHostsSigners(privateKeyPath, agentKeyIndex)
+ if len(signers) == 0 {
+ return nil
+ }
+ return []gossh.AuthMethod{gossh.PublicKeys(signers...)}
+}
+
+func collectKnownHostsSigners(privateKeyPath string, agentKeyIndex int) []gossh.Signer {
+ var signers []gossh.Signer
home := os.Getenv("HOME")
defaultPrivateKeyPaths := []string{
@@ -83,13 +79,32 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go
home + "/.ssh/id_ecdsa",
home + "/.ssh/id_ed25519",
}
+ if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
+ defaultPrivateKeyPaths = append([]string{"./id_rsa"}, defaultPrivateKeyPaths...)
+ }
if privateKeyPath == "" {
privateKeyPath = defaultPrivateKeyPaths[0]
}
addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1)
- addPrivateKeyAuthMethod := func(path string) {
+ addedPublicKeys := make(map[string]bool, len(defaultPrivateKeyPaths)+1)
+ addSigner := func(source string, signer gossh.Signer) {
+ if signer == nil {
+ return
+ }
+
+ pubKey := string(signer.PublicKey().Marshal())
+ if addedPublicKeys[pubKey] {
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Skipping duplicate signer", source)
+ return
+ }
+
+ addedPublicKeys[pubKey] = true
+ signers = append(signers, signer)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added signer", source)
+ }
+ addPrivateKeySigner := func(path string) {
if path == "" {
return
}
@@ -97,33 +112,33 @@ func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []go
return
}
- authMethod, err := privateKeyAuthMethod(path)
+ signer, err := privateKeySigner(path)
if err != nil {
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load private key signer", path, err)
return
}
- sshAuthMethods = append(sshAuthMethods, authMethod)
addedPrivateKeyPaths[path] = true
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path)
+ addSigner(path, signer)
}
// First, the explicit auth key path (or default ~/.ssh/id_rsa).
- addPrivateKeyAuthMethod(privateKeyPath)
+ addPrivateKeySigner(privateKeyPath)
// Second, SSH agent (YubiKey-backed keys are typically exposed here).
- authMethod, err := agentAuthMethod(agentKeyIndex)
+ loadedAgentSigners, err := agentSigners(agentKeyIndex)
if err == nil {
- sshAuthMethods = append(sshAuthMethods, authMethod)
- dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method")
+ for i, signer := range loadedAgentSigners {
+ addSigner(fmt.Sprintf("agent:%d:%d", agentKeyIndex, i), signer)
+ }
} else {
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to load SSH agent signers", err)
}
// Third, additional default private key paths.
for _, path := range defaultPrivateKeyPaths {
- addPrivateKeyAuthMethod(path)
+ addPrivateKeySigner(path)
}
- return sshAuthMethods
+ return signers
}
diff --git a/internal/ssh/client/authmethods_test.go b/internal/ssh/client/authmethods_test.go
index 04751f5..e1e92b0 100644
--- a/internal/ssh/client/authmethods_test.go
+++ b/internal/ssh/client/authmethods_test.go
@@ -2,6 +2,7 @@ package client
import (
"fmt"
+ "io"
"reflect"
"testing"
@@ -10,42 +11,84 @@ import (
gossh "golang.org/x/crypto/ssh"
)
+type mockPublicKey struct {
+ id string
+}
+
+func (k *mockPublicKey) Type() string {
+ return "ssh-rsa"
+}
+
+func (k *mockPublicKey) Marshal() []byte {
+ return []byte(k.id)
+}
+
+func (k *mockPublicKey) Verify(_ []byte, _ *gossh.Signature) error {
+ return nil
+}
+
+type mockSigner struct {
+ key gossh.PublicKey
+}
+
+func newMockSigner(id string) gossh.Signer {
+ return &mockSigner{key: &mockPublicKey{id: id}}
+}
+
+func (s *mockSigner) PublicKey() gossh.PublicKey {
+ return s.key
+}
+
+func (s *mockSigner) Sign(_ io.Reader, _ []byte) (*gossh.Signature, error) {
+ return &gossh.Signature{
+ Format: "ssh-rsa",
+ Blob: []byte("sig"),
+ }, nil
+}
+
func TestCollectKnownHostsAuthMethodsOrder(t *testing.T) {
homeDir := "/tmp/dtail-auth-order"
t.Setenv("HOME", homeDir)
- originalPrivateKeyAuthMethod := privateKeyAuthMethod
- originalAgentAuthMethod := agentAuthMethod
+ originalPrivateKeySigner := privateKeySigner
+ originalAgentSigners := agentSigners
originalLogger := dlog.Client
dlog.Client = &dlog.DLog{}
t.Cleanup(func() {
- privateKeyAuthMethod = originalPrivateKeyAuthMethod
- agentAuthMethod = originalAgentAuthMethod
+ privateKeySigner = originalPrivateKeySigner
+ agentSigners = originalAgentSigners
dlog.Client = originalLogger
})
var callOrder []string
- successfulPrivateKeys := map[string]bool{
- "/custom/id_fast": true,
- homeDir + "/.ssh/id_rsa": true,
- homeDir + "/.ssh/id_dsa": true,
+ successfulPrivateKeys := map[string]gossh.Signer{
+ "/custom/id_fast": newMockSigner("custom"),
+ homeDir + "/.ssh/id_rsa": newMockSigner("default-rsa"),
+ homeDir + "/.ssh/id_dsa": newMockSigner("default-dsa"),
}
- privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ privateKeySigner = func(path string) (gossh.Signer, error) {
callOrder = append(callOrder, "private:"+path)
- if !successfulPrivateKeys[path] {
+ signer, found := successfulPrivateKeys[path]
+ if !found {
return nil, fmt.Errorf("missing private key: %s", path)
}
- return gossh.Password(path), nil
+ return signer, nil
}
- agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ agentSigners = func(keyIndex int) ([]gossh.Signer, error) {
callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
- return gossh.Password("agent"), nil
+ return []gossh.Signer{newMockSigner("agent")}, nil
}
methods := collectKnownHostsAuthMethods("/custom/id_fast", 7)
- if len(methods) != 4 {
- t.Fatalf("Expected 4 auth methods, got %d", len(methods))
+ if len(methods) != 1 {
+ t.Fatalf("Expected 1 auth method, got %d", len(methods))
+ }
+
+ callOrder = nil
+ signers := collectKnownHostsSigners("/custom/id_fast", 7)
+ if len(signers) != 4 {
+ t.Fatalf("Expected 4 signers, got %d", len(signers))
}
expectedOrder := []string{
@@ -65,32 +108,39 @@ func TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath(t *testing.T) {
homeDir := "/tmp/dtail-auth-dedupe"
t.Setenv("HOME", homeDir)
- originalPrivateKeyAuthMethod := privateKeyAuthMethod
- originalAgentAuthMethod := agentAuthMethod
+ originalPrivateKeySigner := privateKeySigner
+ originalAgentSigners := agentSigners
originalLogger := dlog.Client
dlog.Client = &dlog.DLog{}
t.Cleanup(func() {
- privateKeyAuthMethod = originalPrivateKeyAuthMethod
- agentAuthMethod = originalAgentAuthMethod
+ privateKeySigner = originalPrivateKeySigner
+ agentSigners = originalAgentSigners
dlog.Client = originalLogger
})
+ sharedSigner := newMockSigner("shared")
var callOrder []string
- privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ privateKeySigner = func(path string) (gossh.Signer, error) {
callOrder = append(callOrder, "private:"+path)
if path == homeDir+"/.ssh/id_rsa" {
- return gossh.Password(path), nil
+ return sharedSigner, nil
}
return nil, fmt.Errorf("missing private key: %s", path)
}
- agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ agentSigners = func(keyIndex int) ([]gossh.Signer, error) {
callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
- return gossh.Password("agent"), nil
+ return []gossh.Signer{sharedSigner}, nil
}
methods := collectKnownHostsAuthMethods(homeDir+"/.ssh/id_rsa", 2)
- if len(methods) != 2 {
- t.Fatalf("Expected 2 auth methods, got %d", len(methods))
+ if len(methods) != 1 {
+ t.Fatalf("Expected 1 auth method, got %d", len(methods))
+ }
+
+ callOrder = nil
+ signers := collectKnownHostsSigners(homeDir+"/.ssh/id_rsa", 2)
+ if len(signers) != 1 {
+ t.Fatalf("Expected duplicate keys to collapse to 1 signer, got %d", len(signers))
}
expectedOrder := []string{
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 7088e89..a191fd5 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -48,9 +48,9 @@ func Agent() (gossh.AuthMethod, error) {
return AgentWithKeyIndex(-1)
}
-// AgentWithKeyIndex used for SSH auth with a specific key index from the agent.
+// AgentSignersWithKeyIndex returns SSH agent signers.
// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used.
-func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
+func AgentSignersWithKeyIndex(keyIndex int) ([]gossh.Signer, error) {
// Use context-aware dialing for SSH agent connection (local Unix socket).
// 2-second timeout is reasonable for local socket connections.
dialer := &net.Dialer{
@@ -73,28 +73,33 @@ func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
dlog.Common.Debug("Public key", i, key)
}
+ signers, err := agentClient.Signers()
+ if err != nil {
+ return nil, fmt.Errorf("failed to load SSH agent signers: %w", err)
+ }
+
// If no specific key index requested, use all keys (backwards compatible default)
if keyIndex < 0 {
- return gossh.PublicKeysCallback(agentClient.Signers), nil
+ return signers, nil
}
// Use only the specified key index (0-based)
- if keyIndex >= len(keys) {
- return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys))
+ if keyIndex >= len(signers) {
+ return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers))
}
dlog.Common.Debug("Using SSH agent key at index", keyIndex)
- return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) {
- signers, err := agentClient.Signers()
- if err != nil {
- return nil, err
- }
- if keyIndex >= len(signers) {
- return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers))
- }
- // Return only the specified signer
- return []gossh.Signer{signers[keyIndex]}, nil
- }), nil
+ return []gossh.Signer{signers[keyIndex]}, nil
+}
+
+// AgentWithKeyIndex used for SSH auth with a specific key index from the agent.
+// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used.
+func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
+ signers, err := AgentSignersWithKeyIndex(keyIndex)
+ if err != nil {
+ return nil, err
+ }
+ return gossh.PublicKeys(signers...), nil
}
// EnterKeyPhrase is required to read phrase protected private keys.
@@ -108,8 +113,8 @@ func EnterKeyPhrase(keyFile string) []byte {
return phrase
}
-// KeyFile returns the key as a SSH auth method.
-func KeyFile(keyFile string) (gossh.AuthMethod, error) {
+// PrivateKeySigner returns an SSH signer from the provided private key file.
+func PrivateKeySigner(keyFile string) (gossh.Signer, error) {
buffer, err := os.ReadFile(keyFile)
if err != nil {
return nil, err
@@ -118,6 +123,15 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) {
if err != nil {
return nil, err
}
+ return key, nil
+}
+
+// KeyFile returns the key as a SSH auth method.
+func KeyFile(keyFile string) (gossh.AuthMethod, error) {
+ key, err := PrivateKeySigner(keyFile)
+ if err != nil {
+ return nil, err
+ }
// Key phrase support disabled as password will be printed to stdout!
/*