summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-02-03 17:09:18 +0200
committerPaul Buetow <pbuetow@mimecast.com>2026-02-03 17:09:34 +0200
commitd89b9e6760e2aadf9779faa6f23678f67c731e1e (patch)
tree5e5136a70a0fd2f315c4751c31629fd97de4ece9
parent4cbd559c5d66a82358029dc4b00f5174c94c8ebc (diff)
Add SSH agent key selection and fix MapReduce outfile handling
This commit adds two major features and fixes: 1. SSH Agent Key Selection: - Add --agentKeyIndex flag to select specific SSH agent key (0-based) - Solves "too many authentication failures" with multiple SSH keys - Default -1 uses all keys (backwards compatible) - Available in dtail, dcat, dgrep, dmap commands 2. MapReduce Outfile Fixes: - CSV files now written at every interval, not just on exit - Proper signal handling (SIGTERM/SIGINT) with graceful shutdown - 5-second grace period for cleanup before force exit - Fixes issue where outfile remained as .tmp during execution Usage: dtail --servers host --agentKeyIndex 0 --query '...' outfile results.csv This is particularly useful with YubiKey/hardware tokens where many keys are loaded in the SSH agent, and for monitoring MapReduce results in real-time as they're computed. Co-authored-by: Cursor <cursoragent@cursor.com>
-rw-r--r--cmd/dcat/main.go1
-rw-r--r--cmd/dgrep/main.go1
-rw-r--r--cmd/dmap/main.go6
-rw-r--r--cmd/dtail/main.go3
-rw-r--r--internal/clients/baseclient.go2
-rw-r--r--internal/clients/maprclient.go8
-rw-r--r--internal/config/args.go2
-rw-r--r--internal/io/signal/signal.go45
-rw-r--r--internal/mapr/groupsetresult.go7
-rw-r--r--internal/ssh/client/authmethods.go8
-rw-r--r--internal/ssh/ssh.go30
-rw-r--r--internal/version/version.go2
12 files changed, 103 insertions, 12 deletions
diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go
index 0c66a98..e2736d6 100644
--- a/cmd/dcat/main.go
+++ b/cmd/dcat/main.go
@@ -36,6 +36,7 @@ func main() {
flag.BoolVar(&displayVersion, "version", false, "Display version")
flag.IntVar(&args.ConnectionsPerCPU, "cpc", config.DefaultConnectionsPerCPU,
"How many connections established per CPU core concurrently")
+ flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)")
flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port")
flag.StringVar(&args.ConfigFile, "cfg", "", "Config file path")
flag.StringVar(&args.Discovery, "discovery", "", "Server discovery method")
diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go
index 14cfb0c..c0a91eb 100644
--- a/cmd/dgrep/main.go
+++ b/cmd/dgrep/main.go
@@ -40,6 +40,7 @@ func main() {
flag.IntVar(&args.LContext.AfterContext, "after", 0, "Print lines of trailing context after matching lines")
flag.IntVar(&args.LContext.BeforeContext, "before", 0, "Print lines of leading context before matching lines")
flag.IntVar(&args.LContext.MaxCount, "max", 0, "Stop reading file after NUM matching lines")
+ flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)")
flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port")
flag.StringVar(&args.ConfigFile, "cfg", "", "Config file path")
flag.StringVar(&args.Discovery, "discovery", "", "Server discovery method")
diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go
index 7500ea6..ea5f020 100644
--- a/cmd/dmap/main.go
+++ b/cmd/dmap/main.go
@@ -28,7 +28,8 @@ func main() {
var profileFlags profiling.Flags
args := config.Args{
- Mode: omode.MapClient,
+ Mode: omode.MapClient,
+ SSHAgentKeyIndex: -1,
}
userName := user.Name()
@@ -39,6 +40,7 @@ func main() {
flag.BoolVar(&displayVersion, "version", false, "Display version")
flag.IntVar(&args.ConnectionsPerCPU, "cpc", config.DefaultConnectionsPerCPU,
"How many connections established per CPU core concurrently")
+ flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)")
flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port")
flag.IntVar(&args.Timeout, "timeout", 0, "Max time dtail server will collect data until disconnection")
flag.StringVar(&args.ConfigFile, "cfg", "", "Config file path")
@@ -89,7 +91,7 @@ func main() {
dlog.Client.FatalPanic(err)
}
- status := client.Start(ctx, signal.InterruptCh(ctx))
+ status := client.Start(ctx, signal.InterruptChWithCancel(ctx, cancel))
// Log final metrics if profiling is enabled
if profileFlags.Enabled() {
diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go
index 3d363cb..188f518 100644
--- a/cmd/dtail/main.go
+++ b/cmd/dtail/main.go
@@ -51,6 +51,7 @@ func main() {
flag.IntVar(&args.LContext.AfterContext, "after", 0, "Print lines of trailing context after matching lines")
flag.IntVar(&args.LContext.BeforeContext, "before", 0, "Print lines of leading context before matching lines")
flag.IntVar(&args.LContext.MaxCount, "max", 0, "Stop reading file after NUM matching lines")
+ flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)")
flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port")
flag.IntVar(&args.Timeout, "timeout", 0, "Max time dtail server will collect data until disconnection")
flag.IntVar(&shutdownAfter, "shutdownAfter", 3600*24, "Shutdown after so many seconds")
@@ -137,7 +138,7 @@ func main() {
}
}
- status := client.Start(ctx, signal.InterruptCh(ctx))
+ status := client.Start(ctx, signal.InterruptChWithCancel(ctx, cancel))
// Log final metrics if profiling is enabled
if profileFlags.Enabled() {
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 013f2f2..e7a13f5 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -54,7 +54,7 @@ func (c *baseClient) init() {
}
c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(
c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts,
- c.throttleCh, c.Args.SSHPrivateKeyFilePath)
+ c.throttleCh, c.Args.SSHPrivateKeyFilePath, c.Args.SSHAgentKeyIndex)
}
func (c *baseClient) makeConnections(maker maker) {
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index 226f76c..cfbffee 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -99,9 +99,12 @@ func (c *MaprClient) Start(ctx context.Context, statsCh <-chan string) (status i
go c.periodicReportResults(ctx)
status = c.baseClient.Start(ctx, statsCh)
+
+ // Always write final result for cumulative mode (includes outfile case)
if c.cumulative {
- dlog.Client.Debug("Received final mapreduce result")
+ dlog.Client.Debug("Writing final mapreduce result")
c.reportResults(true)
+ dlog.Client.Debug("Final result written")
}
return
@@ -210,13 +213,16 @@ func (c *MaprClient) printResults() {
}
func (c *MaprClient) writeResultsToOutfile(finalResult bool) {
+ dlog.Client.Debug("writeResultsToOutfile called", "finalResult", finalResult, "cumulative", c.cumulative)
if c.cumulative {
if err := c.globalGroup.WriteResult(c.query, finalResult); err != nil {
dlog.Client.FatalPanic(err)
}
+ dlog.Client.Debug("WriteResult completed for cumulative mode")
return
}
if err := c.globalGroup.SwapOut().WriteResult(c.query, true); err != nil {
dlog.Client.FatalPanic(err)
}
+ dlog.Client.Debug("WriteResult completed for non-cumulative mode")
}
diff --git a/internal/config/args.go b/internal/config/args.go
index 87ef393..a026e1c 100644
--- a/internal/config/args.go
+++ b/internal/config/args.go
@@ -28,6 +28,7 @@ type Args struct {
Quiet bool
RegexInvert bool
RegexStr string
+ SSHAgentKeyIndex int
SSHAuthMethods []gossh.AuthMethod
SSHBindAddress string
SSHHostKeyCallback gossh.HostKeyCallback
@@ -60,6 +61,7 @@ func (a *Args) String() string {
sb.WriteString(fmt.Sprintf("%s:%v,", "Quiet", a.Quiet))
sb.WriteString(fmt.Sprintf("%s:%v,", "RegexInvert", a.RegexInvert))
sb.WriteString(fmt.Sprintf("%s:%v,", "RegexStr", a.RegexStr))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAgentKeyIndex", a.SSHAgentKeyIndex))
sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAuthMethods", a.SSHAuthMethods))
sb.WriteString(fmt.Sprintf("%s:%v,", "SSHBindAddress", a.SSHBindAddress))
sb.WriteString(fmt.Sprintf("%s:%v,", "SSHHostKeyCallback", a.SSHHostKeyCallback))
diff --git a/internal/io/signal/signal.go b/internal/io/signal/signal.go
index c01e82a..e94de42 100644
--- a/internal/io/signal/signal.go
+++ b/internal/io/signal/signal.go
@@ -10,7 +10,52 @@ import (
"github.com/mimecast/dtail/internal/config"
)
+// InterruptChWithCancel returns a channel for "please print stats" signalling.
+// It accepts a cancel function to properly shutdown when termination signals are received.
+func InterruptChWithCancel(ctx context.Context, cancel context.CancelFunc) <-chan string {
+ sigIntCh := make(chan os.Signal, 10)
+ gosignal.Notify(sigIntCh, os.Interrupt)
+ sigOtherCh := make(chan os.Signal, 10)
+ gosignal.Notify(sigOtherCh, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT)
+ statsCh := make(chan string)
+
+ go func() {
+ for {
+ select {
+ case <-sigIntCh:
+ select {
+ case statsCh <- "Hint: Hit Ctrl+C again to exit":
+ select {
+ case <-sigIntCh:
+ cancel()
+ // Wait longer to allow MapReduce cleanup, then force exit if still running
+ go func() {
+ time.Sleep(5 * time.Second)
+ os.Exit(0)
+ }()
+ case <-time.After(time.Second * time.Duration(config.InterruptTimeoutS)):
+ }
+ default:
+ // Stats already printed.
+ }
+ case <-sigOtherCh:
+ // Cancel context to allow graceful shutdown (MapReduce outfile cleanup, etc.)
+ cancel()
+ // Wait longer to allow MapReduce cleanup, then force exit if still running
+ go func() {
+ time.Sleep(5 * time.Second)
+ os.Exit(0)
+ }()
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+ return statsCh
+}
+
// InterruptCh returns a channel for "please print stats" signalling.
+// Deprecated: Use InterruptChWithCancel for proper cleanup on termination signals.
func InterruptCh(ctx context.Context) <-chan string {
sigIntCh := make(chan os.Signal, 10)
gosignal.Notify(sigIntCh, os.Interrupt)
diff --git a/internal/mapr/groupsetresult.go b/internal/mapr/groupsetresult.go
index 47bdab8..26e4b12 100644
--- a/internal/mapr/groupsetresult.go
+++ b/internal/mapr/groupsetresult.go
@@ -248,12 +248,17 @@ func (g *GroupSet) resultWriteUnformatted(query *Query, rows []result, fd *os.Fi
}
}
- if !query.Outfile.AppendMode && finalResult {
+ // Always rename .tmp to .csv after writing (not just on final result)
+ // This ensures the .csv file is updated at every interval
+ if !query.Outfile.AppendMode {
tmpOutfile := fmt.Sprintf("%s.tmp", query.Outfile.FilePath)
+ dlog.Common.Debug("Renaming outfile", tmpOutfile, "to", query.Outfile.FilePath)
if err := os.Rename(tmpOutfile, query.Outfile.FilePath); err != nil {
+ dlog.Common.Error("Failed to rename outfile", tmpOutfile, "error", err)
os.Remove(tmpOutfile)
return err
}
+ dlog.Common.Info("Successfully renamed outfile to", query.Outfile.FilePath)
}
return nil
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index 6128018..1a4cb3f 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -16,7 +16,7 @@ const addedPathStr string = "Added path to list of auth methods, not adding furt
// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{},
- privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+ privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) {
if len(sshAuthMethods) > 0 {
simpleCallback, err := NewSimpleCallback()
@@ -25,7 +25,7 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
}
return sshAuthMethods, simpleCallback
}
- return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath)
+ return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex)
}
func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod {
@@ -44,7 +44,7 @@ func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod {
}
func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
- privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+ privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) {
var sshAuthMethods []gossh.AuthMethod
knownHostsFile := fmt.Sprintf("%s/.ssh/known_hosts", os.Getenv("HOME"))
@@ -75,7 +75,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
}
// Second, try SSH Agent
- authMethod, err := ssh.Agent()
+ authMethod, err := ssh.AgentWithKeyIndex(agentKeyIndex)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 32e01b3..41cce05 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -43,6 +43,12 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
// Agent used for SSH auth.
func Agent() (gossh.AuthMethod, error) {
+ return AgentWithKeyIndex(-1)
+}
+
+// AgentWithKeyIndex used for SSH auth with a specific key index from the agent.
+// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used.
+func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
return nil, fmt.Errorf("failed to connect to SSH agent: %w", err)
@@ -55,7 +61,29 @@ func Agent() (gossh.AuthMethod, error) {
for i, key := range keys {
dlog.Common.Debug("Public key", i, key)
}
- return gossh.PublicKeysCallback(agentClient.Signers), nil
+
+ // If no specific key index requested, use all keys (backwards compatible default)
+ if keyIndex < 0 {
+ return gossh.PublicKeysCallback(agentClient.Signers), nil
+ }
+
+ // Use only the specified key index (0-based)
+ if keyIndex >= len(keys) {
+ return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys))
+ }
+
+ dlog.Common.Debug("Using SSH agent key at index", keyIndex)
+ return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) {
+ signers, err := agentClient.Signers()
+ if err != nil {
+ return nil, err
+ }
+ if keyIndex >= len(signers) {
+ return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers))
+ }
+ // Return only the specified signer
+ return []gossh.Signer{signers[keyIndex]}, nil
+ }), nil
}
// EnterKeyPhrase is required to read phrase protected private keys.
diff --git a/internal/version/version.go b/internal/version/version.go
index c8ad394..59a41bc 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -13,7 +13,7 @@ const (
// Name of DTail.
Name string = "DTail"
// Version of DTail.
- Version string = "4.3.2"
+ Version string = "4.3.2-cb"
// Additional information for DTail
Additional string = "Have a lot of fun!"
)