diff options
| author | Paul Buetow <paul@buetow.org> | 2026-02-03 17:09:18 +0200 |
|---|---|---|
| committer | Paul Buetow <pbuetow@mimecast.com> | 2026-02-03 17:09:34 +0200 |
| commit | d89b9e6760e2aadf9779faa6f23678f67c731e1e (patch) | |
| tree | 5e5136a70a0fd2f315c4751c31629fd97de4ece9 | |
| parent | 4cbd559c5d66a82358029dc4b00f5174c94c8ebc (diff) | |
Add SSH agent key selection and fix MapReduce outfile handling
This commit adds two major features and fixes:
1. SSH Agent Key Selection:
- Add --agentKeyIndex flag to select specific SSH agent key (0-based)
- Solves "too many authentication failures" with multiple SSH keys
- Default -1 uses all keys (backwards compatible)
- Available in dtail, dcat, dgrep, dmap commands
2. MapReduce Outfile Fixes:
- CSV files now written at every interval, not just on exit
- Proper signal handling (SIGTERM/SIGINT) with graceful shutdown
- 5-second grace period for cleanup before force exit
- Fixes issue where outfile remained as .tmp during execution
Usage:
dtail --servers host --agentKeyIndex 0 --query '...' outfile results.csv
This is particularly useful with YubiKey/hardware tokens where many
keys are loaded in the SSH agent, and for monitoring MapReduce results
in real-time as they're computed.
Co-authored-by: Cursor <cursoragent@cursor.com>
| -rw-r--r-- | cmd/dcat/main.go | 1 | ||||
| -rw-r--r-- | cmd/dgrep/main.go | 1 | ||||
| -rw-r--r-- | cmd/dmap/main.go | 6 | ||||
| -rw-r--r-- | cmd/dtail/main.go | 3 | ||||
| -rw-r--r-- | internal/clients/baseclient.go | 2 | ||||
| -rw-r--r-- | internal/clients/maprclient.go | 8 | ||||
| -rw-r--r-- | internal/config/args.go | 2 | ||||
| -rw-r--r-- | internal/io/signal/signal.go | 45 | ||||
| -rw-r--r-- | internal/mapr/groupsetresult.go | 7 | ||||
| -rw-r--r-- | internal/ssh/client/authmethods.go | 8 | ||||
| -rw-r--r-- | internal/ssh/ssh.go | 30 | ||||
| -rw-r--r-- | internal/version/version.go | 2 |
12 files changed, 103 insertions, 12 deletions
diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go index 0c66a98..e2736d6 100644 --- a/cmd/dcat/main.go +++ b/cmd/dcat/main.go @@ -36,6 +36,7 @@ func main() { flag.BoolVar(&displayVersion, "version", false, "Display version") flag.IntVar(&args.ConnectionsPerCPU, "cpc", config.DefaultConnectionsPerCPU, "How many connections established per CPU core concurrently") + flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)") flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port") flag.StringVar(&args.ConfigFile, "cfg", "", "Config file path") flag.StringVar(&args.Discovery, "discovery", "", "Server discovery method") diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go index 14cfb0c..c0a91eb 100644 --- a/cmd/dgrep/main.go +++ b/cmd/dgrep/main.go @@ -40,6 +40,7 @@ func main() { flag.IntVar(&args.LContext.AfterContext, "after", 0, "Print lines of trailing context after matching lines") flag.IntVar(&args.LContext.BeforeContext, "before", 0, "Print lines of leading context before matching lines") flag.IntVar(&args.LContext.MaxCount, "max", 0, "Stop reading file after NUM matching lines") + flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)") flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port") flag.StringVar(&args.ConfigFile, "cfg", "", "Config file path") flag.StringVar(&args.Discovery, "discovery", "", "Server discovery method") diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go index 7500ea6..ea5f020 100644 --- a/cmd/dmap/main.go +++ b/cmd/dmap/main.go @@ -28,7 +28,8 @@ func main() { var profileFlags profiling.Flags args := config.Args{ - Mode: omode.MapClient, + Mode: omode.MapClient, + SSHAgentKeyIndex: -1, } userName := user.Name() @@ -39,6 +40,7 @@ func main() { flag.BoolVar(&displayVersion, "version", false, "Display version") flag.IntVar(&args.ConnectionsPerCPU, "cpc", config.DefaultConnectionsPerCPU, "How many connections established per CPU core concurrently") + flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)") flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port") flag.IntVar(&args.Timeout, "timeout", 0, "Max time dtail server will collect data until disconnection") flag.StringVar(&args.ConfigFile, "cfg", "", "Config file path") @@ -89,7 +91,7 @@ func main() { dlog.Client.FatalPanic(err) } - status := client.Start(ctx, signal.InterruptCh(ctx)) + status := client.Start(ctx, signal.InterruptChWithCancel(ctx, cancel)) // Log final metrics if profiling is enabled if profileFlags.Enabled() { diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go index 3d363cb..188f518 100644 --- a/cmd/dtail/main.go +++ b/cmd/dtail/main.go @@ -51,6 +51,7 @@ func main() { flag.IntVar(&args.LContext.AfterContext, "after", 0, "Print lines of trailing context after matching lines") flag.IntVar(&args.LContext.BeforeContext, "before", 0, "Print lines of leading context before matching lines") flag.IntVar(&args.LContext.MaxCount, "max", 0, "Stop reading file after NUM matching lines") + flag.IntVar(&args.SSHAgentKeyIndex, "agentKeyIndex", -1, "SSH agent key index to use (-1 for all keys)") flag.IntVar(&args.SSHPort, "port", config.DefaultSSHPort, "SSH server port") flag.IntVar(&args.Timeout, "timeout", 0, "Max time dtail server will collect data until disconnection") flag.IntVar(&shutdownAfter, "shutdownAfter", 3600*24, "Shutdown after so many seconds") @@ -137,7 +138,7 @@ func main() { } } - status := client.Start(ctx, signal.InterruptCh(ctx)) + status := client.Start(ctx, signal.InterruptChWithCancel(ctx, cancel)) // Log final metrics if profiling is enabled if profileFlags.Enabled() { diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 013f2f2..e7a13f5 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -54,7 +54,7 @@ func (c *baseClient) init() { } c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods( c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, - c.throttleCh, c.Args.SSHPrivateKeyFilePath) + c.throttleCh, c.Args.SSHPrivateKeyFilePath, c.Args.SSHAgentKeyIndex) } func (c *baseClient) makeConnections(maker maker) { diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index 226f76c..cfbffee 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -99,9 +99,12 @@ func (c *MaprClient) Start(ctx context.Context, statsCh <-chan string) (status i go c.periodicReportResults(ctx) status = c.baseClient.Start(ctx, statsCh) + + // Always write final result for cumulative mode (includes outfile case) if c.cumulative { - dlog.Client.Debug("Received final mapreduce result") + dlog.Client.Debug("Writing final mapreduce result") c.reportResults(true) + dlog.Client.Debug("Final result written") } return @@ -210,13 +213,16 @@ func (c *MaprClient) printResults() { } func (c *MaprClient) writeResultsToOutfile(finalResult bool) { + dlog.Client.Debug("writeResultsToOutfile called", "finalResult", finalResult, "cumulative", c.cumulative) if c.cumulative { if err := c.globalGroup.WriteResult(c.query, finalResult); err != nil { dlog.Client.FatalPanic(err) } + dlog.Client.Debug("WriteResult completed for cumulative mode") return } if err := c.globalGroup.SwapOut().WriteResult(c.query, true); err != nil { dlog.Client.FatalPanic(err) } + dlog.Client.Debug("WriteResult completed for non-cumulative mode") } diff --git a/internal/config/args.go b/internal/config/args.go index 87ef393..a026e1c 100644 --- a/internal/config/args.go +++ b/internal/config/args.go @@ -28,6 +28,7 @@ type Args struct { Quiet bool RegexInvert bool RegexStr string + SSHAgentKeyIndex int SSHAuthMethods []gossh.AuthMethod SSHBindAddress string SSHHostKeyCallback gossh.HostKeyCallback @@ -60,6 +61,7 @@ func (a *Args) String() string { sb.WriteString(fmt.Sprintf("%s:%v,", "Quiet", a.Quiet)) sb.WriteString(fmt.Sprintf("%s:%v,", "RegexInvert", a.RegexInvert)) sb.WriteString(fmt.Sprintf("%s:%v,", "RegexStr", a.RegexStr)) + sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAgentKeyIndex", a.SSHAgentKeyIndex)) sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAuthMethods", a.SSHAuthMethods)) sb.WriteString(fmt.Sprintf("%s:%v,", "SSHBindAddress", a.SSHBindAddress)) sb.WriteString(fmt.Sprintf("%s:%v,", "SSHHostKeyCallback", a.SSHHostKeyCallback)) diff --git a/internal/io/signal/signal.go b/internal/io/signal/signal.go index c01e82a..e94de42 100644 --- a/internal/io/signal/signal.go +++ b/internal/io/signal/signal.go @@ -10,7 +10,52 @@ import ( "github.com/mimecast/dtail/internal/config" ) +// InterruptChWithCancel returns a channel for "please print stats" signalling. +// It accepts a cancel function to properly shutdown when termination signals are received. +func InterruptChWithCancel(ctx context.Context, cancel context.CancelFunc) <-chan string { + sigIntCh := make(chan os.Signal, 10) + gosignal.Notify(sigIntCh, os.Interrupt) + sigOtherCh := make(chan os.Signal, 10) + gosignal.Notify(sigOtherCh, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT) + statsCh := make(chan string) + + go func() { + for { + select { + case <-sigIntCh: + select { + case statsCh <- "Hint: Hit Ctrl+C again to exit": + select { + case <-sigIntCh: + cancel() + // Wait longer to allow MapReduce cleanup, then force exit if still running + go func() { + time.Sleep(5 * time.Second) + os.Exit(0) + }() + case <-time.After(time.Second * time.Duration(config.InterruptTimeoutS)): + } + default: + // Stats already printed. + } + case <-sigOtherCh: + // Cancel context to allow graceful shutdown (MapReduce outfile cleanup, etc.) + cancel() + // Wait longer to allow MapReduce cleanup, then force exit if still running + go func() { + time.Sleep(5 * time.Second) + os.Exit(0) + }() + case <-ctx.Done(): + return + } + } + }() + return statsCh +} + // InterruptCh returns a channel for "please print stats" signalling. +// Deprecated: Use InterruptChWithCancel for proper cleanup on termination signals. func InterruptCh(ctx context.Context) <-chan string { sigIntCh := make(chan os.Signal, 10) gosignal.Notify(sigIntCh, os.Interrupt) diff --git a/internal/mapr/groupsetresult.go b/internal/mapr/groupsetresult.go index 47bdab8..26e4b12 100644 --- a/internal/mapr/groupsetresult.go +++ b/internal/mapr/groupsetresult.go @@ -248,12 +248,17 @@ func (g *GroupSet) resultWriteUnformatted(query *Query, rows []result, fd *os.Fi } } - if !query.Outfile.AppendMode && finalResult { + // Always rename .tmp to .csv after writing (not just on final result) + // This ensures the .csv file is updated at every interval + if !query.Outfile.AppendMode { tmpOutfile := fmt.Sprintf("%s.tmp", query.Outfile.FilePath) + dlog.Common.Debug("Renaming outfile", tmpOutfile, "to", query.Outfile.FilePath) if err := os.Rename(tmpOutfile, query.Outfile.FilePath); err != nil { + dlog.Common.Error("Failed to rename outfile", tmpOutfile, "error", err) os.Remove(tmpOutfile) return err } + dlog.Common.Info("Successfully renamed outfile to", query.Outfile.FilePath) } return nil diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 6128018..1a4cb3f 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -16,7 +16,7 @@ const addedPathStr string = "Added path to list of auth methods, not adding furt // InitSSHAuthMethods initialises all known SSH auth methods on the client side. func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, - privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { if len(sshAuthMethods) > 0 { simpleCallback, err := NewSimpleCallback() @@ -25,7 +25,7 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, } return sshAuthMethods, simpleCallback } - return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath) + return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex) } func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { @@ -44,7 +44,7 @@ func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { } func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, - privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { var sshAuthMethods []gossh.AuthMethod knownHostsFile := fmt.Sprintf("%s/.ssh/known_hosts", os.Getenv("HOME")) @@ -75,7 +75,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, } // Second, try SSH Agent - authMethod, err := ssh.Agent() + authMethod, err := ssh.AgentWithKeyIndex(agentKeyIndex) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+ diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 32e01b3..41cce05 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -43,6 +43,12 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { // Agent used for SSH auth. func Agent() (gossh.AuthMethod, error) { + return AgentWithKeyIndex(-1) +} + +// AgentWithKeyIndex used for SSH auth with a specific key index from the agent. +// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used. +func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) { sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) if err != nil { return nil, fmt.Errorf("failed to connect to SSH agent: %w", err) @@ -55,7 +61,29 @@ func Agent() (gossh.AuthMethod, error) { for i, key := range keys { dlog.Common.Debug("Public key", i, key) } - return gossh.PublicKeysCallback(agentClient.Signers), nil + + // If no specific key index requested, use all keys (backwards compatible default) + if keyIndex < 0 { + return gossh.PublicKeysCallback(agentClient.Signers), nil + } + + // Use only the specified key index (0-based) + if keyIndex >= len(keys) { + return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys)) + } + + dlog.Common.Debug("Using SSH agent key at index", keyIndex) + return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) { + signers, err := agentClient.Signers() + if err != nil { + return nil, err + } + if keyIndex >= len(signers) { + return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers)) + } + // Return only the specified signer + return []gossh.Signer{signers[keyIndex]}, nil + }), nil } // EnterKeyPhrase is required to read phrase protected private keys. diff --git a/internal/version/version.go b/internal/version/version.go index c8ad394..59a41bc 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -13,7 +13,7 @@ const ( // Name of DTail. Name string = "DTail" // Version of DTail. - Version string = "4.3.2" + Version string = "4.3.2-cb" // Additional information for DTail Additional string = "Have a lot of fun!" ) |
