diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-03 10:20:05 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-03 10:20:05 +0200 |
| commit | 6d50a475114699911f2ebe1376915cd8317f1881 (patch) | |
| tree | 09d742d296024d028e1b35e45cb93baaf95a15d7 /internal | |
| parent | f17ffe1bae2f176e4dda90ff4dd2cb267332a7b4 (diff) | |
feat(client): register AUTHKEY after SSH session start
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/baseclient.go | 3 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 59 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection_test.go | 112 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 45 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler_test.go | 59 |
5 files changed, 276 insertions, 2 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 95fa721..766f05d 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -203,5 +203,6 @@ func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMe c.maker.makeCommands()) } return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, - hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands()) + hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands(), + c.Args.SSHPrivateKeyFilePath) } diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 3c29ac0..ca1fc43 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -2,9 +2,11 @@ package connectors import ( "context" + "encoding/base64" "fmt" "io" "net" + "os" "strconv" "strings" "time" @@ -29,6 +31,7 @@ type ServerConnection struct { config *ssh.ClientConfig handler handlers.Handler commands []string + authKeyPath string hostKeyCallback client.HostKeyCallback throttlingDone bool } @@ -38,7 +41,7 @@ var _ Connector = (*ServerConnection)(nil) // NewServerConnection returns a new DTail SSH server connection. func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, - handler handlers.Handler, commands []string) *ServerConnection { + handler handlers.Handler, commands []string, authKeyPath string) *ServerConnection { dlog.Client.Debug(server, "Creating new connection", server, handler, commands) sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond @@ -51,6 +54,7 @@ func NewServerConnection(server string, userName string, server: server, handler: handler, commands: commands, + authKeyPath: resolveAuthKeyPath(authKeyPath), config: &ssh.ClientConfig{ User: userName, Auth: authMethods, @@ -224,6 +228,7 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc dlog.Client.Debug(err) } } + c.sendAuthKeyRegistrationCommand() if !c.throttlingDone { dlog.Client.Debug(c.server, "Unthrottling connection (2)", @@ -236,3 +241,55 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc c.handler.Shutdown() return nil } + +func resolveAuthKeyPath(authKeyPath string) string { + if strings.TrimSpace(authKeyPath) != "" { + return authKeyPath + } + return os.Getenv("HOME") + "/.ssh/id_rsa" +} + +func (c *ServerConnection) sendAuthKeyRegistrationCommand() { + authKeyPubPath := c.authKeyPath + ".pub" + authKeyPubBytes, err := os.ReadFile(authKeyPubPath) + if err != nil { + dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, unable to read public key", authKeyPubPath, err) + return + } + + authKeyBase64, err := extractAuthKeyBase64(authKeyPubBytes) + if err != nil { + dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, invalid public key file", authKeyPubPath, err) + return + } + + if err := c.handler.SendMessage("AUTHKEY " + authKeyBase64); err != nil { + dlog.Client.Debug(c.server, "Unable to send AUTHKEY registration command", err) + return + } + dlog.Client.Debug(c.server, "Sent AUTHKEY registration command", authKeyPubPath) +} + +func extractAuthKeyBase64(authKeyPubBytes []byte) (string, error) { + authKeyPubContent := string(authKeyPubBytes) + for _, line := range strings.Split(authKeyPubContent, "\n") { + trimmedLine := strings.TrimSpace(line) + if trimmedLine == "" || strings.HasPrefix(trimmedLine, "#") { + continue + } + + fields := strings.Fields(trimmedLine) + if len(fields) < 2 { + return "", fmt.Errorf("expected authorized key format '<type> <base64-key> [comment]'") + } + + authKeyBase64 := strings.TrimSpace(fields[1]) + if _, err := base64.StdEncoding.DecodeString(authKeyBase64); err != nil { + return "", fmt.Errorf("invalid base64 public key: %w", err) + } + + return authKeyBase64, nil + } + + return "", fmt.Errorf("no public key found") +} diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go new file mode 100644 index 0000000..8ab126b --- /dev/null +++ b/internal/clients/connectors/serverconnection_test.go @@ -0,0 +1,112 @@ +package connectors + +import ( + "os" + "path/filepath" + "testing" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/io/dlog" +) + +func TestExtractAuthKeyBase64(t *testing.T) { + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + dlog.Client = originalLogger + }) + + t.Run("valid authorized key line", func(t *testing.T) { + pubKey := []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA user@host\n") + + got, err := extractAuthKeyBase64(pubKey) + if err != nil { + t.Fatalf("Expected valid key, got error: %v", err) + } + if got != "AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" { + t.Fatalf("Unexpected base64 payload: %s", got) + } + }) + + t.Run("invalid key format", func(t *testing.T) { + _, err := extractAuthKeyBase64([]byte("not-a-valid-authorized-key-line")) + if err == nil { + t.Fatalf("Expected parse error for invalid key format") + } + }) + + t.Run("invalid base64 payload", func(t *testing.T) { + _, err := extractAuthKeyBase64([]byte("ssh-ed25519 !!! not-valid\n")) + if err == nil { + t.Fatalf("Expected error for invalid base64 payload") + } + }) +} + +func TestSendAuthKeyRegistrationCommand(t *testing.T) { + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + dlog.Client = originalLogger + }) + + tempDir := t.TempDir() + privateKeyPath := filepath.Join(tempDir, "id_rsa") + publicKeyPath := privateKeyPath + ".pub" + if err := os.WriteFile(publicKeyPath, + []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA user@host\n"), 0600); err != nil { + t.Fatalf("Unable to write public key test file: %v", err) + } + + handler := &mockHandler{} + conn := &ServerConnection{ + server: "srv1", + handler: handler, + authKeyPath: privateKeyPath, + } + + conn.sendAuthKeyRegistrationCommand() + + if len(handler.commands) != 1 { + t.Fatalf("Expected one AUTHKEY command, got %d", len(handler.commands)) + } + expected := "AUTHKEY AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + if handler.commands[0] != expected { + t.Fatalf("Unexpected AUTHKEY command.\nexpected: %s\ngot: %s", expected, handler.commands[0]) + } +} + +type mockHandler struct { + commands []string +} + +var _ handlers.Handler = (*mockHandler)(nil) + +func (m *mockHandler) SendMessage(command string) error { + m.commands = append(m.commands, command) + return nil +} + +func (m *mockHandler) Server() string { + return "mock" +} + +func (m *mockHandler) Status() int { + return 0 +} + +func (m *mockHandler) Shutdown() {} + +func (m *mockHandler) Done() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + +func (m *mockHandler) Read(_ []byte) (int, error) { + return 0, nil +} + +func (m *mockHandler) Write(p []byte) (int, error) { + return len(p), nil +} diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 1a500dc..1debb98 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -90,6 +90,9 @@ func (h *baseHandler) handleMessage(message string) { h.handleHiddenMessage(message) return } + if h.handleAuthKeyMessage(message) { + return + } // Add newline only if the message doesn't already end with one if len(message) > 0 && message[len(message)-1] == '\n' { @@ -99,6 +102,48 @@ func (h *baseHandler) handleMessage(message string) { } } +func (h *baseHandler) handleAuthKeyMessage(message string) bool { + isAuthKeyMessage, authKeyOK, authKeyDetail := parseAuthKeyMessage(message) + if !isAuthKeyMessage { + return false + } + + if authKeyOK { + dlog.Client.Debug(h.server, "AUTHKEY registration accepted by server") + return true + } + + if authKeyDetail == "" { + dlog.Client.Warn(h.server, "AUTHKEY registration failed") + return true + } + + dlog.Client.Warn(h.server, "AUTHKEY registration failed", authKeyDetail) + return true +} + +func parseAuthKeyMessage(message string) (isAuthKeyMessage bool, ok bool, detail string) { + if message == "" { + return false, false, "" + } + + payload := strings.TrimSpace(message) + parts := strings.Split(payload, protocol.FieldDelimiter) + if len(parts) > 0 { + payload = strings.TrimSpace(parts[len(parts)-1]) + } + + switch { + case payload == "AUTHKEY OK": + return true, true, "" + case strings.HasPrefix(payload, "AUTHKEY ERR"): + detail := strings.TrimSpace(strings.TrimPrefix(payload, "AUTHKEY ERR")) + return true, false, detail + default: + return false, false, "" + } +} + // Handle messages received from server which are not meant to be displayed // to the end user. func (h *baseHandler) handleHiddenMessage(message string) { diff --git a/internal/clients/handlers/basehandler_test.go b/internal/clients/handlers/basehandler_test.go new file mode 100644 index 0000000..996cd23 --- /dev/null +++ b/internal/clients/handlers/basehandler_test.go @@ -0,0 +1,59 @@ +package handlers + +import ( + "fmt" + "testing" + + "github.com/mimecast/dtail/internal/protocol" +) + +func TestParseAuthKeyMessage(t *testing.T) { + tests := []struct { + name string + message string + wantAuth bool + wantOK bool + wantInfo string + }{ + { + name: "server formatted success", + message: fmt.Sprintf("SERVER%s%s%sAUTHKEY OK\n", protocol.FieldDelimiter, "host1", protocol.FieldDelimiter), + wantAuth: true, + wantOK: true, + }, + { + name: "server formatted error", + message: fmt.Sprintf("SERVER%s%s%sAUTHKEY ERR feature disabled\n", protocol.FieldDelimiter, "host1", protocol.FieldDelimiter), + wantAuth: true, + wantOK: false, + wantInfo: "feature disabled", + }, + { + name: "plain response success", + message: "AUTHKEY OK", + wantAuth: true, + wantOK: true, + }, + { + name: "not an authkey message", + message: fmt.Sprintf("SERVER%s%s%ssome other message", protocol.FieldDelimiter, "host1", protocol.FieldDelimiter), + wantAuth: false, + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotAuth, gotOK, gotInfo := parseAuthKeyMessage(tc.message) + if gotAuth != tc.wantAuth { + t.Fatalf("Unexpected auth marker: got %v want %v", gotAuth, tc.wantAuth) + } + if gotOK != tc.wantOK { + t.Fatalf("Unexpected ok marker: got %v want %v", gotOK, tc.wantOK) + } + if gotInfo != tc.wantInfo { + t.Fatalf("Unexpected info: got %q want %q", gotInfo, tc.wantInfo) + } + }) + } +} |
