summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-03 10:20:05 +0200
committerPaul Buetow <paul@buetow.org>2026-03-03 10:20:05 +0200
commit6d50a475114699911f2ebe1376915cd8317f1881 (patch)
tree09d742d296024d028e1b35e45cb93baaf95a15d7 /internal
parentf17ffe1bae2f176e4dda90ff4dd2cb267332a7b4 (diff)
feat(client): register AUTHKEY after SSH session start
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/baseclient.go3
-rw-r--r--internal/clients/connectors/serverconnection.go59
-rw-r--r--internal/clients/connectors/serverconnection_test.go112
-rw-r--r--internal/clients/handlers/basehandler.go45
-rw-r--r--internal/clients/handlers/basehandler_test.go59
5 files changed, 276 insertions, 2 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 95fa721..766f05d 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -203,5 +203,6 @@ func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMe
c.maker.makeCommands())
}
return connectors.NewServerConnection(server, c.UserName, sshAuthMethods,
- hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands())
+ hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands(),
+ c.Args.SSHPrivateKeyFilePath)
}
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 3c29ac0..ca1fc43 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -2,9 +2,11 @@ package connectors
import (
"context"
+ "encoding/base64"
"fmt"
"io"
"net"
+ "os"
"strconv"
"strings"
"time"
@@ -29,6 +31,7 @@ type ServerConnection struct {
config *ssh.ClientConfig
handler handlers.Handler
commands []string
+ authKeyPath string
hostKeyCallback client.HostKeyCallback
throttlingDone bool
}
@@ -38,7 +41,7 @@ var _ Connector = (*ServerConnection)(nil)
// NewServerConnection returns a new DTail SSH server connection.
func NewServerConnection(server string, userName string,
authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback,
- handler handlers.Handler, commands []string) *ServerConnection {
+ handler handlers.Handler, commands []string, authKeyPath string) *ServerConnection {
dlog.Client.Debug(server, "Creating new connection", server, handler, commands)
sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond
@@ -51,6 +54,7 @@ func NewServerConnection(server string, userName string,
server: server,
handler: handler,
commands: commands,
+ authKeyPath: resolveAuthKeyPath(authKeyPath),
config: &ssh.ClientConfig{
User: userName,
Auth: authMethods,
@@ -224,6 +228,7 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
dlog.Client.Debug(err)
}
}
+ c.sendAuthKeyRegistrationCommand()
if !c.throttlingDone {
dlog.Client.Debug(c.server, "Unthrottling connection (2)",
@@ -236,3 +241,55 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
c.handler.Shutdown()
return nil
}
+
+func resolveAuthKeyPath(authKeyPath string) string {
+ if strings.TrimSpace(authKeyPath) != "" {
+ return authKeyPath
+ }
+ return os.Getenv("HOME") + "/.ssh/id_rsa"
+}
+
+func (c *ServerConnection) sendAuthKeyRegistrationCommand() {
+ authKeyPubPath := c.authKeyPath + ".pub"
+ authKeyPubBytes, err := os.ReadFile(authKeyPubPath)
+ if err != nil {
+ dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, unable to read public key", authKeyPubPath, err)
+ return
+ }
+
+ authKeyBase64, err := extractAuthKeyBase64(authKeyPubBytes)
+ if err != nil {
+ dlog.Client.Debug(c.server, "Skipping AUTHKEY registration, invalid public key file", authKeyPubPath, err)
+ return
+ }
+
+ if err := c.handler.SendMessage("AUTHKEY " + authKeyBase64); err != nil {
+ dlog.Client.Debug(c.server, "Unable to send AUTHKEY registration command", err)
+ return
+ }
+ dlog.Client.Debug(c.server, "Sent AUTHKEY registration command", authKeyPubPath)
+}
+
+func extractAuthKeyBase64(authKeyPubBytes []byte) (string, error) {
+ authKeyPubContent := string(authKeyPubBytes)
+ for _, line := range strings.Split(authKeyPubContent, "\n") {
+ trimmedLine := strings.TrimSpace(line)
+ if trimmedLine == "" || strings.HasPrefix(trimmedLine, "#") {
+ continue
+ }
+
+ fields := strings.Fields(trimmedLine)
+ if len(fields) < 2 {
+ return "", fmt.Errorf("expected authorized key format '<type> <base64-key> [comment]'")
+ }
+
+ authKeyBase64 := strings.TrimSpace(fields[1])
+ if _, err := base64.StdEncoding.DecodeString(authKeyBase64); err != nil {
+ return "", fmt.Errorf("invalid base64 public key: %w", err)
+ }
+
+ return authKeyBase64, nil
+ }
+
+ return "", fmt.Errorf("no public key found")
+}
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go
new file mode 100644
index 0000000..8ab126b
--- /dev/null
+++ b/internal/clients/connectors/serverconnection_test.go
@@ -0,0 +1,112 @@
+package connectors
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/io/dlog"
+)
+
+func TestExtractAuthKeyBase64(t *testing.T) {
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ dlog.Client = originalLogger
+ })
+
+ t.Run("valid authorized key line", func(t *testing.T) {
+ pubKey := []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA user@host\n")
+
+ got, err := extractAuthKeyBase64(pubKey)
+ if err != nil {
+ t.Fatalf("Expected valid key, got error: %v", err)
+ }
+ if got != "AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" {
+ t.Fatalf("Unexpected base64 payload: %s", got)
+ }
+ })
+
+ t.Run("invalid key format", func(t *testing.T) {
+ _, err := extractAuthKeyBase64([]byte("not-a-valid-authorized-key-line"))
+ if err == nil {
+ t.Fatalf("Expected parse error for invalid key format")
+ }
+ })
+
+ t.Run("invalid base64 payload", func(t *testing.T) {
+ _, err := extractAuthKeyBase64([]byte("ssh-ed25519 !!! not-valid\n"))
+ if err == nil {
+ t.Fatalf("Expected error for invalid base64 payload")
+ }
+ })
+}
+
+func TestSendAuthKeyRegistrationCommand(t *testing.T) {
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ dlog.Client = originalLogger
+ })
+
+ tempDir := t.TempDir()
+ privateKeyPath := filepath.Join(tempDir, "id_rsa")
+ publicKeyPath := privateKeyPath + ".pub"
+ if err := os.WriteFile(publicKeyPath,
+ []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA user@host\n"), 0600); err != nil {
+ t.Fatalf("Unable to write public key test file: %v", err)
+ }
+
+ handler := &mockHandler{}
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: handler,
+ authKeyPath: privateKeyPath,
+ }
+
+ conn.sendAuthKeyRegistrationCommand()
+
+ if len(handler.commands) != 1 {
+ t.Fatalf("Expected one AUTHKEY command, got %d", len(handler.commands))
+ }
+ expected := "AUTHKEY AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
+ if handler.commands[0] != expected {
+ t.Fatalf("Unexpected AUTHKEY command.\nexpected: %s\ngot: %s", expected, handler.commands[0])
+ }
+}
+
+type mockHandler struct {
+ commands []string
+}
+
+var _ handlers.Handler = (*mockHandler)(nil)
+
+func (m *mockHandler) SendMessage(command string) error {
+ m.commands = append(m.commands, command)
+ return nil
+}
+
+func (m *mockHandler) Server() string {
+ return "mock"
+}
+
+func (m *mockHandler) Status() int {
+ return 0
+}
+
+func (m *mockHandler) Shutdown() {}
+
+func (m *mockHandler) Done() <-chan struct{} {
+ ch := make(chan struct{})
+ close(ch)
+ return ch
+}
+
+func (m *mockHandler) Read(_ []byte) (int, error) {
+ return 0, nil
+}
+
+func (m *mockHandler) Write(p []byte) (int, error) {
+ return len(p), nil
+}
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go
index 1a500dc..1debb98 100644
--- a/internal/clients/handlers/basehandler.go
+++ b/internal/clients/handlers/basehandler.go
@@ -90,6 +90,9 @@ func (h *baseHandler) handleMessage(message string) {
h.handleHiddenMessage(message)
return
}
+ if h.handleAuthKeyMessage(message) {
+ return
+ }
// Add newline only if the message doesn't already end with one
if len(message) > 0 && message[len(message)-1] == '\n' {
@@ -99,6 +102,48 @@ func (h *baseHandler) handleMessage(message string) {
}
}
+func (h *baseHandler) handleAuthKeyMessage(message string) bool {
+ isAuthKeyMessage, authKeyOK, authKeyDetail := parseAuthKeyMessage(message)
+ if !isAuthKeyMessage {
+ return false
+ }
+
+ if authKeyOK {
+ dlog.Client.Debug(h.server, "AUTHKEY registration accepted by server")
+ return true
+ }
+
+ if authKeyDetail == "" {
+ dlog.Client.Warn(h.server, "AUTHKEY registration failed")
+ return true
+ }
+
+ dlog.Client.Warn(h.server, "AUTHKEY registration failed", authKeyDetail)
+ return true
+}
+
+func parseAuthKeyMessage(message string) (isAuthKeyMessage bool, ok bool, detail string) {
+ if message == "" {
+ return false, false, ""
+ }
+
+ payload := strings.TrimSpace(message)
+ parts := strings.Split(payload, protocol.FieldDelimiter)
+ if len(parts) > 0 {
+ payload = strings.TrimSpace(parts[len(parts)-1])
+ }
+
+ switch {
+ case payload == "AUTHKEY OK":
+ return true, true, ""
+ case strings.HasPrefix(payload, "AUTHKEY ERR"):
+ detail := strings.TrimSpace(strings.TrimPrefix(payload, "AUTHKEY ERR"))
+ return true, false, detail
+ default:
+ return false, false, ""
+ }
+}
+
// Handle messages received from server which are not meant to be displayed
// to the end user.
func (h *baseHandler) handleHiddenMessage(message string) {
diff --git a/internal/clients/handlers/basehandler_test.go b/internal/clients/handlers/basehandler_test.go
new file mode 100644
index 0000000..996cd23
--- /dev/null
+++ b/internal/clients/handlers/basehandler_test.go
@@ -0,0 +1,59 @@
+package handlers
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/mimecast/dtail/internal/protocol"
+)
+
+func TestParseAuthKeyMessage(t *testing.T) {
+ tests := []struct {
+ name string
+ message string
+ wantAuth bool
+ wantOK bool
+ wantInfo string
+ }{
+ {
+ name: "server formatted success",
+ message: fmt.Sprintf("SERVER%s%s%sAUTHKEY OK\n", protocol.FieldDelimiter, "host1", protocol.FieldDelimiter),
+ wantAuth: true,
+ wantOK: true,
+ },
+ {
+ name: "server formatted error",
+ message: fmt.Sprintf("SERVER%s%s%sAUTHKEY ERR feature disabled\n", protocol.FieldDelimiter, "host1", protocol.FieldDelimiter),
+ wantAuth: true,
+ wantOK: false,
+ wantInfo: "feature disabled",
+ },
+ {
+ name: "plain response success",
+ message: "AUTHKEY OK",
+ wantAuth: true,
+ wantOK: true,
+ },
+ {
+ name: "not an authkey message",
+ message: fmt.Sprintf("SERVER%s%s%ssome other message", protocol.FieldDelimiter, "host1", protocol.FieldDelimiter),
+ wantAuth: false,
+ wantOK: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ gotAuth, gotOK, gotInfo := parseAuthKeyMessage(tc.message)
+ if gotAuth != tc.wantAuth {
+ t.Fatalf("Unexpected auth marker: got %v want %v", gotAuth, tc.wantAuth)
+ }
+ if gotOK != tc.wantOK {
+ t.Fatalf("Unexpected ok marker: got %v want %v", gotOK, tc.wantOK)
+ }
+ if gotInfo != tc.wantInfo {
+ t.Fatalf("Unexpected info: got %q want %q", gotInfo, tc.wantInfo)
+ }
+ })
+ }
+}