diff options
| -rw-r--r-- | cmd/dcat/main.go | 3 | ||||
| -rw-r--r-- | cmd/dgrep/main.go | 3 | ||||
| -rw-r--r-- | cmd/dmap/main.go | 3 | ||||
| -rw-r--r-- | cmd/dserver/main.go | 5 | ||||
| -rw-r--r-- | cmd/dtail/main.go | 3 | ||||
| -rw-r--r-- | cmd/dtailhealth/main.go | 2 | ||||
| -rw-r--r-- | internal/clients/baseclient.go | 10 | ||||
| -rw-r--r-- | internal/clients/catclient.go | 1 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 32 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection_test.go | 88 | ||||
| -rw-r--r-- | internal/clients/connectors/serverless.go | 47 | ||||
| -rw-r--r-- | internal/clients/grepclient.go | 1 | ||||
| -rw-r--r-- | internal/clients/healthclient.go | 1 | ||||
| -rw-r--r-- | internal/clients/maprclient.go | 14 | ||||
| -rw-r--r-- | internal/clients/runtime_boundary.go | 212 | ||||
| -rw-r--r-- | internal/clients/runtime_boundary_test.go | 89 | ||||
| -rw-r--r-- | internal/clients/stats.go | 23 | ||||
| -rw-r--r-- | internal/clients/tailclient.go | 1 | ||||
| -rw-r--r-- | internal/mapr/globalgroupset.go | 4 | ||||
| -rw-r--r-- | internal/mapr/groupsetresult.go | 101 | ||||
| -rw-r--r-- | internal/mapr/groupsetresult_renderer_test.go | 113 | ||||
| -rw-r--r-- | internal/mapr/result_renderer.go | 34 | ||||
| -rw-r--r-- | internal/version/version.go | 13 |
23 files changed, 654 insertions, 149 deletions
diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go index df580c4..64c6f10 100644 --- a/cmd/dcat/main.go +++ b/cmd/dcat/main.go @@ -54,7 +54,8 @@ func main() { config.Setup(source.Client, &args, flag.Args()) if displayVersion { - version.PrintAndExit() + runtimeCfg := config.CurrentRuntime() + version.PrintAndExit(runtimeCfg.Client != nil && runtimeCfg.Client.TermColorsEnable) } runtime := cli.NewClientRuntime(context.Background(), profileFlags, "dcat") diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go index a36a6ae..aadf5c7 100644 --- a/cmd/dgrep/main.go +++ b/cmd/dgrep/main.go @@ -60,7 +60,8 @@ func main() { config.Setup(source.Client, &args, flag.Args()) if displayVersion { - version.PrintAndExit() + runtimeCfg := config.CurrentRuntime() + version.PrintAndExit(runtimeCfg.Client != nil && runtimeCfg.Client.TermColorsEnable) } runtime := cli.NewClientRuntime(context.Background(), profileFlags, "dgrep") diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go index 33ace65..e6665ee 100644 --- a/cmd/dmap/main.go +++ b/cmd/dmap/main.go @@ -60,7 +60,8 @@ func main() { config.Setup(source.Client, &args, flag.Args()) if displayVersion { - version.PrintAndExit() + runtimeCfg := config.CurrentRuntime() + version.PrintAndExit(runtimeCfg.Client != nil && runtimeCfg.Client.TermColorsEnable) } runtime := cli.NewClientRuntime(context.Background(), profileFlags, "dmap") diff --git a/cmd/dserver/main.go b/cmd/dserver/main.go index 7c13b29..b7ad091 100644 --- a/cmd/dserver/main.go +++ b/cmd/dserver/main.go @@ -46,9 +46,10 @@ func main() { config.Setup(source.Server, &args, flag.Args()) if displayVersion { - version.PrintAndExit() + runtimeCfg := config.CurrentRuntime() + version.PrintAndExit(runtimeCfg.Client != nil && runtimeCfg.Client.TermColorsEnable) } - version.Print() + version.Print(false) ctx, cancel := context.WithCancel(context.Background()) if shutdownAfter > 0 { diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go index 4ea27df..d3fe931 100644 --- a/cmd/dtail/main.go +++ b/cmd/dtail/main.go @@ -76,7 +76,8 @@ func main() { } config.Setup(source.Client, &args, flag.Args()) if displayVersion { - version.PrintAndExit() + runtimeCfg := config.CurrentRuntime() + version.PrintAndExit(runtimeCfg.Client != nil && runtimeCfg.Client.TermColorsEnable) } if !args.Plain { if displayWideColorTable { diff --git a/cmd/dtailhealth/main.go b/cmd/dtailhealth/main.go index f1fbf6f..821684f 100644 --- a/cmd/dtailhealth/main.go +++ b/cmd/dtailhealth/main.go @@ -33,7 +33,7 @@ func main() { flag.Parse() if displayVersion { - version.PrintAndExit() + version.PrintAndExit(false) } config.Setup(source.HealthCheck, &args, flag.Args()) diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index adc2b77..7358cb3 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -25,6 +25,7 @@ const ( // This is the main client data structure. type baseClient struct { config.Args + runtime *clientRuntimeBoundary // To display client side stats stats *stats // We have one connection per remote server. @@ -45,6 +46,9 @@ type baseClient struct { func (c *baseClient) init() { dlog.Client.Debug("Initiating base client", c.Args.String()) + if c.runtime == nil { + c.runtime = newClientRuntimeBoundary(config.CurrentRuntime()) + } flag := regex.Default if c.Args.RegexInvert { @@ -73,7 +77,7 @@ func (c *baseClient) makeConnections(maker maker) { c.sshAuthMethods, c.hostKeyCallback)) } - c.stats = newTailStats(len(c.connections)) + c.stats = newTailStats(len(c.connections), c.runtime.output, c.runtime.InterruptPause()) } func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status int) { @@ -200,9 +204,9 @@ func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMe hostKeyCallback client.HostKeyCallback) connectors.Connector { if c.Args.Serverless { return connectors.NewServerless(c.UserName, c.maker.makeHandler(server), - c.maker.makeCommands()) + c.maker.makeCommands(), c.runtime) } return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands(), - c.Args.SSHPrivateKeyFilePath, c.Args.NoAuthKey) + c.Args.SSHPrivateKeyFilePath, c.Args.NoAuthKey, c.runtime) } diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go index bd65560..0a62b9d 100644 --- a/internal/clients/catclient.go +++ b/internal/clients/catclient.go @@ -29,6 +29,7 @@ func NewCatClient(args config.Args) (*CatClient, error) { Args: args, throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, + runtime: newClientRuntimeBoundary(config.CurrentRuntime()), }, } diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 649fe30..1136bf9 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -12,13 +12,23 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/ssh/client" "golang.org/x/crypto/ssh" ) +// SSHSettings provides the connection settings needed by ServerConnection. +type SSHSettings interface { + SSHPort() int + SSHConnectTimeout() time.Duration +} + +const ( + defaultSSHConnectTimeout = 2 * time.Second + defaultSSHPort = 2222 +) + // ServerConnection represents a connection to a single remote dtail server via // SSH protocol. type ServerConnection struct { @@ -43,12 +53,20 @@ var _ Connector = (*ServerConnection)(nil) func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, handler handlers.Handler, commands []string, authKeyPath string, - authKeyDisabled bool) *ServerConnection { + authKeyDisabled bool, settings SSHSettings) *ServerConnection { dlog.Client.Debug(server, "Creating new connection", server, handler, commands) - sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond + sshConnectTimeout := defaultSSHConnectTimeout + defaultPort := defaultSSHPort + if settings != nil { + sshConnectTimeout = settings.SSHConnectTimeout() + defaultPort = settings.SSHPort() + } if sshConnectTimeout <= 0 { - sshConnectTimeout = 2 * time.Second + sshConnectTimeout = defaultSSHConnectTimeout + } + if defaultPort <= 0 { + defaultPort = defaultSSHPort } c := ServerConnection{ @@ -66,7 +84,7 @@ func NewServerConnection(server string, userName string, }, } - c.initServerPort() + c.initServerPort(defaultPort) return &c } @@ -77,11 +95,11 @@ func (c *ServerConnection) Server() string { return c.server } func (c *ServerConnection) Handler() handlers.Handler { return c.handler } // Attempt to parse the server port address from the provided server FQDN. -func (c *ServerConnection) initServerPort() { +func (c *ServerConnection) initServerPort(defaultPort int) { parts := strings.Split(c.server, ":") if len(parts) == 1 { c.hostname = c.server - c.port = config.Common.SSHPort + c.port = defaultPort return } diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go index 8ab126b..227a1e9 100644 --- a/internal/clients/connectors/serverconnection_test.go +++ b/internal/clients/connectors/serverconnection_test.go @@ -1,12 +1,16 @@ package connectors import ( + "context" "os" "path/filepath" "testing" + "time" "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" + + "golang.org/x/crypto/ssh" ) func TestExtractAuthKeyBase64(t *testing.T) { @@ -76,6 +80,90 @@ func TestSendAuthKeyRegistrationCommand(t *testing.T) { } } +func TestNewServerConnectionUsesInjectedSettings(t *testing.T) { + resetClientLogger(t) + + conn := NewServerConnection( + "srv1", + "user", + nil, + testHostKeyCallback{}, + &mockHandler{}, + nil, + "", + false, + testSSHSettings{port: 3022, timeout: 5 * time.Second}, + ) + + if conn.hostname != "srv1" { + t.Fatalf("Expected hostname srv1, got %q", conn.hostname) + } + if conn.port != 3022 { + t.Fatalf("Expected injected port 3022, got %d", conn.port) + } + if conn.config.Timeout != 5*time.Second { + t.Fatalf("Expected injected timeout 5s, got %v", conn.config.Timeout) + } +} + +func TestNewServerConnectionFallsBackToDefaults(t *testing.T) { + resetClientLogger(t) + + conn := NewServerConnection( + "srv1", + "user", + nil, + testHostKeyCallback{}, + &mockHandler{}, + nil, + "", + false, + testSSHSettings{}, + ) + + if conn.port != defaultSSHPort { + t.Fatalf("Expected default port %d, got %d", defaultSSHPort, conn.port) + } + if conn.config.Timeout != defaultSSHConnectTimeout { + t.Fatalf("Expected default timeout %v, got %v", defaultSSHConnectTimeout, conn.config.Timeout) + } +} + +type testSSHSettings struct { + port int + timeout time.Duration +} + +func (s testSSHSettings) SSHPort() int { + return s.port +} + +func (s testSSHSettings) SSHConnectTimeout() time.Duration { + return s.timeout +} + +type testHostKeyCallback struct{} + +func (testHostKeyCallback) Wrap() ssh.HostKeyCallback { + return ssh.InsecureIgnoreHostKey() +} + +func (testHostKeyCallback) Untrusted(string) bool { + return false +} + +func (testHostKeyCallback) PromptAddHosts(context.Context) {} + +func resetClientLogger(t *testing.T) { + t.Helper() + + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + dlog.Client = originalLogger + }) +} + type mockHandler struct { commands []string } diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go index cedf37f..daf825f 100644 --- a/internal/clients/connectors/serverless.go +++ b/internal/clients/connectors/serverless.go @@ -5,31 +5,35 @@ import ( "io" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" serverHandlers "github.com/mimecast/dtail/internal/server/handlers" - sshserver "github.com/mimecast/dtail/internal/ssh/server" - user "github.com/mimecast/dtail/internal/user/server" ) +// ServerlessHandlerFactory creates the in-process server-side handler used by serverless mode. +type ServerlessHandlerFactory interface { + NewServerlessHandler(userName string) (serverHandlers.Handler, error) +} + // Serverless creates a server object directly without TCP. type Serverless struct { - handler handlers.Handler - commands []string - userName string + handler handlers.Handler + commands []string + userName string + handlerFactory ServerlessHandlerFactory } var _ Connector = (*Serverless)(nil) // NewServerless starts a new serverless session. func NewServerless(userName string, handler handlers.Handler, - commands []string) *Serverless { + commands []string, handlerFactory ServerlessHandlerFactory) *Serverless { dlog.Client.Debug("Creating new serverless connector", handler, commands) return &Serverless{ - userName: userName, - handler: handler, - commands: commands, + userName: userName, + handler: handler, + commands: commands, + handlerFactory: handlerFactory, } } @@ -60,31 +64,14 @@ func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc, func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) error { dlog.Client.Debug("Creating server handler for a serverless session") - var permissionLookup user.PermissionLookup - if config.Server != nil { - permissionLookup = config.Server.UserPermissions + if s.handlerFactory == nil { + return io.ErrClosedPipe } - user, err := user.New(s.userName, s.Server(), permissionLookup) + serverHandler, err := s.handlerFactory.NewServerlessHandler(s.userName) if err != nil { return err } - var serverHandler serverHandlers.Handler - switch s.userName { - case config.HealthUser: - dlog.Client.Debug("Creating serverless health handler") - serverHandler = serverHandlers.NewHealthHandler(user) - default: - dlog.Client.Debug("Creating serverless server handler") - serverHandler = serverHandlers.NewServerHandler( - user, - make(chan struct{}, config.Server.MaxConcurrentCats), - make(chan struct{}, config.Server.MaxConcurrentTails), - config.Server, - sshserver.AuthKeys(), - ) - } - terminate := func() { dlog.Client.Debug("Terminating serverless connection") serverHandler.Shutdown() diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go index 7521c67..f0f08d4 100644 --- a/internal/clients/grepclient.go +++ b/internal/clients/grepclient.go @@ -30,6 +30,7 @@ func NewGrepClient(args config.Args) (*GrepClient, error) { Args: args, throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, + runtime: newClientRuntimeBoundary(config.CurrentRuntime()), }, } diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index f3ba81f..f699912 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -28,6 +28,7 @@ func NewHealthClient(args config.Args) (*HealthClient, error) { Args: args, throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, + runtime: newClientRuntimeBoundary(config.CurrentRuntime()), }, } diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index 95b3a9c..2757229 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -9,7 +9,6 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/mapr" @@ -73,6 +72,7 @@ func NewMaprClient(args config.Args, maprClientMode MaprClientMode) (*MaprClient Args: args, throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: retry, + runtime: newClientRuntimeBoundary(config.CurrentRuntime()), }, query: query, cumulative: cumulative, @@ -201,9 +201,9 @@ func (c *MaprClient) printResults() error { } if c.cumulative { - result, numRows, err = c.globalGroup.Result(c.query, rowsLimit) + result, numRows, err = c.globalGroup.Result(c.query, rowsLimit, c.runtime.output.MaprResultRenderer()) } else { - result, numRows, err = c.globalGroup.SwapOut().Result(c.query, rowsLimit) + result, numRows, err = c.globalGroup.SwapOut().Result(c.query, rowsLimit, c.runtime.output.MaprResultRenderer()) } if err != nil { return fmt.Errorf("unable to render mapreduce result: %w", err) @@ -220,13 +220,7 @@ func (c *MaprClient) printResults() error { return nil } - rawQuery := c.query.RawQuery - if config.Client.TermColorsEnable { - rawQuery = color.PaintStrWithAttr(rawQuery, - config.Client.TermColors.MaprTable.RawQueryFg, - config.Client.TermColors.MaprTable.RawQueryBg, - config.Client.TermColors.MaprTable.RawQueryAttr) - } + rawQuery := c.runtime.output.PaintMaprRawQuery(c.query.RawQuery) dlog.Client.Raw(fmt.Sprintf("%s\n", rawQuery)) if rowsLimit > 0 && numRows > rowsLimit { diff --git a/internal/clients/runtime_boundary.go b/internal/clients/runtime_boundary.go new file mode 100644 index 0000000..fe58fde --- /dev/null +++ b/internal/clients/runtime_boundary.go @@ -0,0 +1,212 @@ +package clients + +import ( + "fmt" + "strings" + "time" + + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/mapr" + serverHandlers "github.com/mimecast/dtail/internal/server/handlers" + sshserver "github.com/mimecast/dtail/internal/ssh/server" + user "github.com/mimecast/dtail/internal/user/server" +) + +type clientRuntimeBoundary struct { + sshPort int + sshConnectTimeout time.Duration + interruptPause time.Duration + serverCfg *config.ServerConfig + output *clientOutputFormatter +} + +func newClientRuntimeBoundary(cfg config.RuntimeConfig) *clientRuntimeBoundary { + sshPort := 2222 + sshConnectTimeout := 2 * time.Second + if cfg.Common != nil { + if cfg.Common.SSHPort > 0 { + sshPort = cfg.Common.SSHPort + } + if cfg.Common.SSHConnectTimeoutMs > 0 { + sshConnectTimeout = time.Duration(cfg.Common.SSHConnectTimeoutMs) * time.Millisecond + } + } + + return &clientRuntimeBoundary{ + sshPort: sshPort, + sshConnectTimeout: sshConnectTimeout, + interruptPause: time.Second * time.Duration(config.InterruptTimeoutS), + serverCfg: cfg.Server, + output: newClientOutputFormatter(cfg.Client), + } +} + +func (r *clientRuntimeBoundary) SSHPort() int { + return r.sshPort +} + +func (r *clientRuntimeBoundary) SSHConnectTimeout() time.Duration { + return r.sshConnectTimeout +} + +func (r *clientRuntimeBoundary) InterruptPause() time.Duration { + if r == nil || r.interruptPause <= 0 { + return time.Second * time.Duration(config.InterruptTimeoutS) + } + return r.interruptPause +} + +func (r *clientRuntimeBoundary) NewServerlessHandler(userName string) (serverHandlers.Handler, error) { + var permissionLookup user.PermissionLookup + if r.serverCfg != nil { + permissionLookup = r.serverCfg.UserPermissions + } + + serverUser, err := user.New(userName, "local(serverless)", permissionLookup) + if err != nil { + return nil, err + } + + switch userName { + case config.HealthUser: + return serverHandlers.NewHealthHandler(serverUser), nil + default: + if r.serverCfg == nil { + return nil, fmt.Errorf("missing serverless server config") + } + return serverHandlers.NewServerHandler( + serverUser, + make(chan struct{}, positiveOrDefault(r.serverCfg.MaxConcurrentCats, 2)), + make(chan struct{}, positiveOrDefault(r.serverCfg.MaxConcurrentTails, 50)), + r.serverCfg, + sshserver.AuthKeys(), + ), nil + } +} + +func positiveOrDefault(value, fallback int) int { + if value <= 0 { + return fallback + } + return value +} + +type interruptMessageFormatter interface { + FormatInterruptMessage(index int, message string) string +} + +type clientOutputFormatter struct { + interruptEnabled bool + interruptStyle textStyle + rawQueryEnabled bool + rawQueryStyle textStyle + maprRenderer mapr.ResultRenderer +} + +func newClientOutputFormatter(clientCfg *config.ClientConfig) *clientOutputFormatter { + formatter := &clientOutputFormatter{ + maprRenderer: mapr.PlainResultRenderer(), + } + if clientCfg == nil || !clientCfg.TermColorsEnable { + return formatter + } + + formatter.interruptEnabled = true + formatter.rawQueryEnabled = true + formatter.interruptStyle = textStyle{ + fg: clientCfg.TermColors.Client.ClientFg, + bg: clientCfg.TermColors.Client.ClientBg, + attr: clientCfg.TermColors.Client.ClientAttr, + } + formatter.rawQueryStyle = textStyle{ + fg: clientCfg.TermColors.MaprTable.RawQueryFg, + bg: clientCfg.TermColors.MaprTable.RawQueryBg, + attr: clientCfg.TermColors.MaprTable.RawQueryAttr, + } + formatter.maprRenderer = maprTerminalRenderer{ + header: textStyle{ + fg: clientCfg.TermColors.MaprTable.HeaderFg, + bg: clientCfg.TermColors.MaprTable.HeaderBg, + attr: clientCfg.TermColors.MaprTable.HeaderAttr, + }, + headerDelimiter: textStyle{ + fg: clientCfg.TermColors.MaprTable.HeaderDelimiterFg, + bg: clientCfg.TermColors.MaprTable.HeaderDelimiterBg, + attr: clientCfg.TermColors.MaprTable.HeaderDelimiterAttr, + }, + headerSortAttr: clientCfg.TermColors.MaprTable.HeaderSortKeyAttr, + headerGroupAttr: clientCfg.TermColors.MaprTable.HeaderGroupKeyAttr, + data: textStyle{ + fg: clientCfg.TermColors.MaprTable.DataFg, + bg: clientCfg.TermColors.MaprTable.DataBg, + attr: clientCfg.TermColors.MaprTable.DataAttr, + }, + dataDelimiter: textStyle{ + fg: clientCfg.TermColors.MaprTable.DelimiterFg, + bg: clientCfg.TermColors.MaprTable.DelimiterBg, + attr: clientCfg.TermColors.MaprTable.DelimiterAttr, + }, + } + + return formatter +} + +func (f *clientOutputFormatter) FormatInterruptMessage(index int, message string) string { + if index > 0 && f.interruptEnabled { + return color.PaintStrWithAttr(message, f.interruptStyle.fg, f.interruptStyle.bg, f.interruptStyle.attr) + } + return " " + message +} + +func (f *clientOutputFormatter) PaintMaprRawQuery(rawQuery string) string { + if !f.rawQueryEnabled { + return rawQuery + } + return color.PaintStrWithAttr(rawQuery, f.rawQueryStyle.fg, f.rawQueryStyle.bg, f.rawQueryStyle.attr) +} + +func (f *clientOutputFormatter) MaprResultRenderer() mapr.ResultRenderer { + if f == nil || f.maprRenderer == nil { + return mapr.PlainResultRenderer() + } + return f.maprRenderer +} + +type textStyle struct { + fg color.FgColor + bg color.BgColor + attr color.Attribute +} + +type maprTerminalRenderer struct { + header textStyle + headerDelimiter textStyle + headerSortAttr color.Attribute + headerGroupAttr color.Attribute + data textStyle + dataDelimiter textStyle +} + +func (r maprTerminalRenderer) WriteHeaderEntry(sb *strings.Builder, text string, isSortKey, isGroupKey bool) { + attrs := []color.Attribute{r.header.attr} + if isSortKey { + attrs = append(attrs, r.headerSortAttr) + } + if isGroupKey { + attrs = append(attrs, r.headerGroupAttr) + } + color.PaintWithAttrs(sb, text, r.header.fg, r.header.bg, attrs) +} + +func (r maprTerminalRenderer) WriteHeaderDelimiter(sb *strings.Builder, text string) { + color.PaintWithAttr(sb, text, r.headerDelimiter.fg, r.headerDelimiter.bg, r.headerDelimiter.attr) +} + +func (r maprTerminalRenderer) WriteDataEntry(sb *strings.Builder, text string) { + color.PaintWithAttr(sb, text, r.data.fg, r.data.bg, r.data.attr) +} + +func (r maprTerminalRenderer) WriteDataDelimiter(sb *strings.Builder, text string) { + color.PaintWithAttr(sb, text, r.dataDelimiter.fg, r.dataDelimiter.bg, r.dataDelimiter.attr) +} diff --git a/internal/clients/runtime_boundary_test.go b/internal/clients/runtime_boundary_test.go new file mode 100644 index 0000000..9947865 --- /dev/null +++ b/internal/clients/runtime_boundary_test.go @@ -0,0 +1,89 @@ +package clients + +import ( + "strings" + "testing" + "time" + + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" +) + +func TestNewClientRuntimeBoundaryDefaults(t *testing.T) { + runtime := newClientRuntimeBoundary(config.RuntimeConfig{}) + + if runtime.SSHPort() != 2222 { + t.Fatalf("Expected default SSH port 2222, got %d", runtime.SSHPort()) + } + if runtime.SSHConnectTimeout() != 2*time.Second { + t.Fatalf("Expected default timeout 2s, got %v", runtime.SSHConnectTimeout()) + } + if runtime.InterruptPause() != 3*time.Second { + t.Fatalf("Expected default interrupt pause 3s, got %v", runtime.InterruptPause()) + } + if got := runtime.output.PaintMaprRawQuery("select 1"); got != "select 1" { + t.Fatalf("Expected plain raw query output, got %q", got) + } +} + +func TestNewClientRuntimeBoundaryUsesConfiguredSSHSettings(t *testing.T) { + runtime := newClientRuntimeBoundary(config.RuntimeConfig{ + Common: &config.CommonConfig{ + SSHPort: 4022, + SSHConnectTimeoutMs: 4500, + }, + }) + + if runtime.SSHPort() != 4022 { + t.Fatalf("Expected configured SSH port 4022, got %d", runtime.SSHPort()) + } + if runtime.SSHConnectTimeout() != 4500*time.Millisecond { + t.Fatalf("Expected configured timeout 4.5s, got %v", runtime.SSHConnectTimeout()) + } +} + +func TestClientOutputFormatterColorModes(t *testing.T) { + plain := newClientOutputFormatter(nil) + if got := plain.FormatInterruptMessage(1, "hello"); got != " hello" { + t.Fatalf("Expected plain interrupt message, got %q", got) + } + if got := plain.PaintMaprRawQuery("select 1"); got != "select 1" { + t.Fatalf("Expected plain raw query, got %q", got) + } + + cfg := &config.ClientConfig{TermColorsEnable: true} + cfg.TermColors.Client.ClientFg = color.FgBlack + cfg.TermColors.Client.ClientBg = color.BgYellow + cfg.TermColors.Client.ClientAttr = color.AttrBold + cfg.TermColors.MaprTable.RawQueryFg = color.FgCyan + cfg.TermColors.MaprTable.RawQueryBg = color.BgBlack + cfg.TermColors.MaprTable.RawQueryAttr = color.AttrUnderline + cfg.TermColors.MaprTable.HeaderFg = color.FgWhite + cfg.TermColors.MaprTable.HeaderBg = color.BgBlue + cfg.TermColors.MaprTable.HeaderAttr = color.AttrBold + cfg.TermColors.MaprTable.HeaderDelimiterFg = color.FgWhite + cfg.TermColors.MaprTable.HeaderDelimiterBg = color.BgBlue + cfg.TermColors.MaprTable.HeaderDelimiterAttr = color.AttrDim + cfg.TermColors.MaprTable.HeaderSortKeyAttr = color.AttrUnderline + cfg.TermColors.MaprTable.HeaderGroupKeyAttr = color.AttrReverse + cfg.TermColors.MaprTable.DataFg = color.FgWhite + cfg.TermColors.MaprTable.DataBg = color.BgBlue + cfg.TermColors.MaprTable.DataAttr = color.AttrNone + cfg.TermColors.MaprTable.DelimiterFg = color.FgWhite + cfg.TermColors.MaprTable.DelimiterBg = color.BgBlue + cfg.TermColors.MaprTable.DelimiterAttr = color.AttrDim + + colored := newClientOutputFormatter(cfg) + if got := colored.FormatInterruptMessage(0, "hello"); got != " hello" { + t.Fatalf("Expected first interrupt line to stay plain, got %q", got) + } + if got := colored.FormatInterruptMessage(1, "hello"); !strings.Contains(got, "\x1b[") { + t.Fatalf("Expected colored interrupt output, got %q", got) + } + if got := colored.PaintMaprRawQuery("select 1"); !strings.Contains(got, "\x1b[") { + t.Fatalf("Expected colored raw query output, got %q", got) + } + if colored.MaprResultRenderer() == nil { + t.Fatal("Expected non-nil mapreduce result renderer") + } +} diff --git a/internal/clients/stats.go b/internal/clients/stats.go index 5880fd1..1ce04e6 100644 --- a/internal/clients/stats.go +++ b/internal/clients/stats.go @@ -8,8 +8,6 @@ import ( "sync" "time" - "github.com/mimecast/dtail/internal/color" - "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/protocol" ) @@ -24,13 +22,22 @@ type stats struct { connected int // To synchronize concurrent access. mutex sync.Mutex + // Formats interrupt-driven stats output. + formatter interruptMessageFormatter + // Controls how long interrupt output remains visible. + interruptPause time.Duration } -func newTailStats(servers int) *stats { +func newTailStats(servers int, formatter interruptMessageFormatter, interruptPause time.Duration) *stats { + if interruptPause <= 0 { + interruptPause = 3 * time.Second + } return &stats{ servers: servers, connectionsEstCh: make(chan struct{}, servers), connected: 0, + formatter: formatter, + interruptPause: interruptPause, } } @@ -84,17 +91,13 @@ func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, func (s *stats) printStatsDueInterrupt(messages []string) { dlog.Client.Pause() for i, message := range messages { - if i > 0 && config.Client.TermColorsEnable { - fmt.Println(color.PaintStrWithAttr(message, - config.Client.TermColors.Client.ClientFg, - config.Client.TermColors.Client.ClientBg, - config.Client.TermColors.Client.ClientAttr, - )) + if s.formatter != nil { + fmt.Println(s.formatter.FormatInterruptMessage(i, message)) continue } fmt.Printf(" %s\n", message) } - time.Sleep(time.Second * time.Duration(config.InterruptTimeoutS)) + time.Sleep(s.interruptPause) dlog.Client.Resume() } diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go index 35c01d4..fff6646 100644 --- a/internal/clients/tailclient.go +++ b/internal/clients/tailclient.go @@ -24,6 +24,7 @@ func NewTailClient(args config.Args) (*TailClient, error) { Args: args, throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: true, + runtime: newClientRuntimeBoundary(config.CurrentRuntime()), }, } diff --git a/internal/mapr/globalgroupset.go b/internal/mapr/globalgroupset.go index 2b12898..4f1e5ba 100644 --- a/internal/mapr/globalgroupset.go +++ b/internal/mapr/globalgroupset.go @@ -86,8 +86,8 @@ func (g *GlobalGroupSet) WriteResult(query *Query, finalResult bool) error { } // Result returns the result of the mapreduce aggregation as a string. -func (g *GlobalGroupSet) Result(query *Query, rowsLimit int) (string, int, error) { +func (g *GlobalGroupSet) Result(query *Query, rowsLimit int, renderer ResultRenderer) (string, int, error) { g.semaphore <- struct{}{} defer func() { <-g.semaphore }() - return g.GroupSet.Result(query, rowsLimit) + return g.GroupSet.Result(query, rowsLimit, renderer) } diff --git a/internal/mapr/groupsetresult.go b/internal/mapr/groupsetresult.go index 26e4b12..c87c22f 100644 --- a/internal/mapr/groupsetresult.go +++ b/internal/mapr/groupsetresult.go @@ -6,15 +6,13 @@ import ( "os" "strings" - "github.com/mimecast/dtail/internal/color" - "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/pool" "github.com/mimecast/dtail/internal/protocol" ) // Result returns a nicely formated result of the query from the group set. -func (g *GroupSet) Result(query *Query, rowsLimit int) (string, int, error) { +func (g *GroupSet) Result(query *Query, rowsLimit int, renderer ResultRenderer) (string, int, error) { rows, columnWidths, err := g.result(query, true) if err != nil { return "", 0, err @@ -27,97 +25,69 @@ func (g *GroupSet) Result(query *Query, rowsLimit int) (string, int, error) { sb := pool.BuilderBuffer.Get().(*strings.Builder) defer pool.RecycleBuilderBuffer(sb) - g.resultWriteFormattedHeader(query, sb, lastColumn, rowsLimit, columnWidths) - g.resultWriteFormattedHeaderRowSeparator(query, sb, lastColumn, columnWidths) - g.resultWriteFormattedData(query, sb, lastColumn, rowsLimit, columnWidths, rows) + if renderer == nil { + renderer = PlainResultRenderer() + } + + g.resultWriteFormattedHeader(query, renderer, sb, lastColumn, rowsLimit, columnWidths) + g.resultWriteFormattedHeaderRowSeparator(query, renderer, sb, lastColumn, columnWidths) + g.resultWriteFormattedData(query, renderer, sb, lastColumn, rowsLimit, columnWidths, rows) return sb.String(), len(rows), nil } // Write a nicely formatted header for the result data. -func (g *GroupSet) resultWriteFormattedHeader(query *Query, sb *strings.Builder, +func (g *GroupSet) resultWriteFormattedHeader(query *Query, renderer ResultRenderer, sb *strings.Builder, lastColumn, rowsLimit int, columnWidths []int) { for i, sc := range query.Select { format := fmt.Sprintf(" %%%ds ", columnWidths[i]) str := fmt.Sprintf(format, sc.FieldStorage) - g.resultWriteFormattedHeaderEntry(query, sb, sc, str) + g.resultWriteFormattedHeaderEntry(query, renderer, sb, sc, str) if i == lastColumn { continue } - g.resultWriteFormattedHeaderEntrySeparator(query, sb) + g.resultWriteFormattedHeaderEntrySeparator(renderer, sb) } sb.WriteString("\n") } -func (g *GroupSet) resultWriteFormattedHeaderEntry(query *Query, sb *strings.Builder, +func (g *GroupSet) resultWriteFormattedHeaderEntry(query *Query, renderer ResultRenderer, sb *strings.Builder, sc selectCondition, str string) { - if config.Client.TermColorsEnable { - attrs := []color.Attribute{config.Client.TermColors.MaprTable.HeaderAttr} - if sc.FieldStorage == query.OrderBy { - attrs = append(attrs, config.Client.TermColors.MaprTable.HeaderSortKeyAttr) - } - for _, groupBy := range query.GroupBy { - if sc.FieldStorage == groupBy { - attrs = append(attrs, config.Client.TermColors.MaprTable.HeaderGroupKeyAttr) - break - } + isGroupKey := false + for _, groupBy := range query.GroupBy { + if sc.FieldStorage == groupBy { + isGroupKey = true + break } - color.PaintWithAttrs(sb, str, - config.Client.TermColors.MaprTable.HeaderFg, - config.Client.TermColors.MaprTable.HeaderBg, - attrs) - - } else { - sb.WriteString(str) } + renderer.WriteHeaderEntry(sb, str, sc.FieldStorage == query.OrderBy, isGroupKey) } -func (g *GroupSet) resultWriteFormattedHeaderEntrySeparator(query *Query, sb *strings.Builder) { - if config.Client.TermColorsEnable { - color.PaintWithAttr(sb, protocol.FieldDelimiter, - config.Client.TermColors.MaprTable.HeaderDelimiterFg, - config.Client.TermColors.MaprTable.HeaderDelimiterBg, - config.Client.TermColors.MaprTable.HeaderDelimiterAttr) - } else { - sb.WriteString(protocol.FieldDelimiter) - } +func (g *GroupSet) resultWriteFormattedHeaderEntrySeparator(renderer ResultRenderer, sb *strings.Builder) { + renderer.WriteHeaderDelimiter(sb, protocol.FieldDelimiter) } // This writes a nicely formatted line separating the header and the data. -func (g *GroupSet) resultWriteFormattedHeaderRowSeparator(query *Query, sb *strings.Builder, +func (g *GroupSet) resultWriteFormattedHeaderRowSeparator(query *Query, renderer ResultRenderer, sb *strings.Builder, lastColumn int, columnWidths []int) { for i := 0; i < len(query.Select); i++ { str := fmt.Sprintf("-%s-", strings.Repeat("-", columnWidths[i])) - if config.Client.TermColorsEnable { - color.PaintWithAttr(sb, str, - config.Client.TermColors.MaprTable.HeaderDelimiterFg, - config.Client.TermColors.MaprTable.HeaderDelimiterBg, - config.Client.TermColors.MaprTable.HeaderDelimiterAttr) - } else { - sb.WriteString(str) - } + renderer.WriteHeaderDelimiter(sb, str) if i == lastColumn { continue } - if config.Client.TermColorsEnable { - color.PaintWithAttr(sb, protocol.FieldDelimiter, - config.Client.TermColors.MaprTable.HeaderDelimiterFg, - config.Client.TermColors.MaprTable.HeaderDelimiterBg, - config.Client.TermColors.MaprTable.HeaderDelimiterAttr) - } else { - sb.WriteString(protocol.FieldDelimiter) - } + renderer.WriteHeaderDelimiter(sb, protocol.FieldDelimiter) } sb.WriteString("\n") } // Write the result data nicely formatted. -func (g *GroupSet) resultWriteFormattedData(query *Query, sb *strings.Builder, +func (g *GroupSet) resultWriteFormattedData(query *Query, renderer ResultRenderer, sb *strings.Builder, lastColumn, rowsLimit int, columnWidths []int, rows []result) { for i, r := range rows { @@ -125,37 +95,22 @@ func (g *GroupSet) resultWriteFormattedData(query *Query, sb *strings.Builder, break } for j, value := range r.values { - g.resultWriteFormattedDataEntry(query, sb, columnWidths, j, value) + g.resultWriteFormattedDataEntry(renderer, sb, columnWidths, j, value) if j == lastColumn { continue } - // Now, write the data entry separator. - if config.Client.TermColorsEnable { - color.PaintWithAttr(sb, protocol.FieldDelimiter, - config.Client.TermColors.MaprTable.DelimiterFg, - config.Client.TermColors.MaprTable.DelimiterBg, - config.Client.TermColors.MaprTable.DelimiterAttr) - } else { - sb.WriteString(protocol.FieldDelimiter) - } + renderer.WriteDataDelimiter(sb, protocol.FieldDelimiter) } sb.WriteString("\n") } } -func (g *GroupSet) resultWriteFormattedDataEntry(query *Query, sb *strings.Builder, +func (g *GroupSet) resultWriteFormattedDataEntry(renderer ResultRenderer, sb *strings.Builder, columnWidths []int, j int, value string) { format := fmt.Sprintf(" %%%ds ", columnWidths[j]) str := fmt.Sprintf(format, value) - if config.Client.TermColorsEnable { - color.PaintWithAttr(sb, str, - config.Client.TermColors.MaprTable.DataFg, - config.Client.TermColors.MaprTable.DataBg, - config.Client.TermColors.MaprTable.DataAttr) - } else { - sb.WriteString(str) - } + renderer.WriteDataEntry(sb, str) } func (*GroupSet) writeQueryFile(query *Query) error { diff --git a/internal/mapr/groupsetresult_renderer_test.go b/internal/mapr/groupsetresult_renderer_test.go new file mode 100644 index 0000000..53f45d5 --- /dev/null +++ b/internal/mapr/groupsetresult_renderer_test.go @@ -0,0 +1,113 @@ +package mapr + +import ( + "strings" + "testing" +) + +func TestGroupSetResultUsesProvidedRenderer(t *testing.T) { + query, err := NewQuery("select host,count(value) from stats group by host order by count(value)") + if err != nil { + t.Fatalf("Unable to parse query: %v", err) + } + + groupSet := NewGroupSet() + set := groupSet.GetSet("host-a") + if err := set.Aggregate("host", Last, "host-a", false); err != nil { + t.Fatalf("Unable to aggregate host field: %v", err) + } + if err := set.Aggregate("count(value)", Count, "", false); err != nil { + t.Fatalf("Unable to aggregate count field: %v", err) + } + + renderer := &recordingRenderer{} + result, numRows, err := groupSet.Result(query, 10, renderer) + if err != nil { + t.Fatalf("Unable to render result: %v", err) + } + if numRows != 1 { + t.Fatalf("Expected one row, got %d", numRows) + } + if len(renderer.headerCalls) != 2 { + t.Fatalf("Expected two header calls, got %d", len(renderer.headerCalls)) + } + if renderer.headerCalls[0].isSortKey || !renderer.headerCalls[0].isGroupKey { + t.Fatalf("Unexpected flags for group key header: %+v", renderer.headerCalls[0]) + } + if !renderer.headerCalls[1].isSortKey || renderer.headerCalls[1].isGroupKey { + t.Fatalf("Unexpected flags for sort key header: %+v", renderer.headerCalls[1]) + } + if len(renderer.headerDelimiters) == 0 { + t.Fatal("Expected header delimiters to be rendered") + } + if len(renderer.dataDelimiters) == 0 { + t.Fatal("Expected data delimiters to be rendered") + } + if !strings.Contains(result, "host-a") || !strings.Contains(result, "1") { + t.Fatalf("Expected rendered output to contain row data, got %q", result) + } +} + +func TestGroupSetResultFallsBackToPlainRenderer(t *testing.T) { + query, err := NewQuery("select count(value) from stats") + if err != nil { + t.Fatalf("Unable to parse query: %v", err) + } + + groupSet := NewGroupSet() + set := groupSet.GetSet("") + if err := set.Aggregate("count(value)", Count, "", false); err != nil { + t.Fatalf("Unable to aggregate count field: %v", err) + } + + result, numRows, err := groupSet.Result(query, 10, nil) + if err != nil { + t.Fatalf("Unable to render result with nil renderer: %v", err) + } + if numRows != 1 { + t.Fatalf("Expected one row, got %d", numRows) + } + if !strings.Contains(result, "count(value)") || !strings.Contains(result, "1") { + t.Fatalf("Expected plain rendered output, got %q", result) + } + if strings.Contains(result, "\x1b[") { + t.Fatalf("Expected plain output without ANSI escapes, got %q", result) + } +} + +type recordingRenderer struct { + headerCalls []headerCall + headerDelimiters []string + dataEntries []string + dataDelimiters []string +} + +type headerCall struct { + text string + isSortKey bool + isGroupKey bool +} + +func (r *recordingRenderer) WriteHeaderEntry(sb *strings.Builder, text string, isSortKey, isGroupKey bool) { + r.headerCalls = append(r.headerCalls, headerCall{ + text: text, + isSortKey: isSortKey, + isGroupKey: isGroupKey, + }) + sb.WriteString(text) +} + +func (r *recordingRenderer) WriteHeaderDelimiter(sb *strings.Builder, text string) { + r.headerDelimiters = append(r.headerDelimiters, text) + sb.WriteString(text) +} + +func (r *recordingRenderer) WriteDataEntry(sb *strings.Builder, text string) { + r.dataEntries = append(r.dataEntries, text) + sb.WriteString(text) +} + +func (r *recordingRenderer) WriteDataDelimiter(sb *strings.Builder, text string) { + r.dataDelimiters = append(r.dataDelimiters, text) + sb.WriteString(text) +} diff --git a/internal/mapr/result_renderer.go b/internal/mapr/result_renderer.go new file mode 100644 index 0000000..e03ed9a --- /dev/null +++ b/internal/mapr/result_renderer.go @@ -0,0 +1,34 @@ +package mapr + +import "strings" + +// ResultRenderer formats terminal table output for mapreduce results. +type ResultRenderer interface { + WriteHeaderEntry(sb *strings.Builder, text string, isSortKey, isGroupKey bool) + WriteHeaderDelimiter(sb *strings.Builder, text string) + WriteDataEntry(sb *strings.Builder, text string) + WriteDataDelimiter(sb *strings.Builder, text string) +} + +type plainResultRenderer struct{} + +// PlainResultRenderer returns a renderer that writes uncolored terminal output. +func PlainResultRenderer() ResultRenderer { + return plainResultRenderer{} +} + +func (plainResultRenderer) WriteHeaderEntry(sb *strings.Builder, text string, _, _ bool) { + sb.WriteString(text) +} + +func (plainResultRenderer) WriteHeaderDelimiter(sb *strings.Builder, text string) { + sb.WriteString(text) +} + +func (plainResultRenderer) WriteDataEntry(sb *strings.Builder, text string) { + sb.WriteString(text) +} + +func (plainResultRenderer) WriteDataDelimiter(sb *strings.Builder, text string) { + sb.WriteString(text) +} diff --git a/internal/version/version.go b/internal/version/version.go index 05bd28f..bffdccd 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -5,7 +5,6 @@ import ( "os" "github.com/mimecast/dtail/internal/color" - "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/protocol" ) @@ -25,8 +24,8 @@ func String() string { } // PaintedString is a prettier string representation of the DTail version. -func PaintedString() string { - if !config.Client.TermColorsEnable { +func PaintedString(colorsEnabled bool) string { + if !colorsEnabled { return String() } @@ -43,12 +42,12 @@ func PaintedString() string { } // Print the version. -func Print() { - fmt.Println(PaintedString()) +func Print(colorsEnabled bool) { + fmt.Println(PaintedString(colorsEnabled)) } // PrintAndExit prints the program version and exists. -func PrintAndExit() { - Print() +func PrintAndExit(colorsEnabled bool) { + Print(colorsEnabled) os.Exit(0) } |
