summaryrefslogtreecommitdiff
path: root/internal/clients
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients')
-rw-r--r--internal/clients/baseclient.go10
-rw-r--r--internal/clients/catclient.go1
-rw-r--r--internal/clients/connectors/serverconnection.go32
-rw-r--r--internal/clients/connectors/serverconnection_test.go88
-rw-r--r--internal/clients/connectors/serverless.go47
-rw-r--r--internal/clients/grepclient.go1
-rw-r--r--internal/clients/healthclient.go1
-rw-r--r--internal/clients/maprclient.go14
-rw-r--r--internal/clients/runtime_boundary.go212
-rw-r--r--internal/clients/runtime_boundary_test.go89
-rw-r--r--internal/clients/stats.go23
-rw-r--r--internal/clients/tailclient.go1
12 files changed, 459 insertions, 60 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index adc2b77..7358cb3 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -25,6 +25,7 @@ const (
// This is the main client data structure.
type baseClient struct {
config.Args
+ runtime *clientRuntimeBoundary
// To display client side stats
stats *stats
// We have one connection per remote server.
@@ -45,6 +46,9 @@ type baseClient struct {
func (c *baseClient) init() {
dlog.Client.Debug("Initiating base client", c.Args.String())
+ if c.runtime == nil {
+ c.runtime = newClientRuntimeBoundary(config.CurrentRuntime())
+ }
flag := regex.Default
if c.Args.RegexInvert {
@@ -73,7 +77,7 @@ func (c *baseClient) makeConnections(maker maker) {
c.sshAuthMethods, c.hostKeyCallback))
}
- c.stats = newTailStats(len(c.connections))
+ c.stats = newTailStats(len(c.connections), c.runtime.output, c.runtime.InterruptPause())
}
func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status int) {
@@ -200,9 +204,9 @@ func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMe
hostKeyCallback client.HostKeyCallback) connectors.Connector {
if c.Args.Serverless {
return connectors.NewServerless(c.UserName, c.maker.makeHandler(server),
- c.maker.makeCommands())
+ c.maker.makeCommands(), c.runtime)
}
return connectors.NewServerConnection(server, c.UserName, sshAuthMethods,
hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands(),
- c.Args.SSHPrivateKeyFilePath, c.Args.NoAuthKey)
+ c.Args.SSHPrivateKeyFilePath, c.Args.NoAuthKey, c.runtime)
}
diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go
index bd65560..0a62b9d 100644
--- a/internal/clients/catclient.go
+++ b/internal/clients/catclient.go
@@ -29,6 +29,7 @@ func NewCatClient(args config.Args) (*CatClient, error) {
Args: args,
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: false,
+ runtime: newClientRuntimeBoundary(config.CurrentRuntime()),
},
}
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 649fe30..1136bf9 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -12,13 +12,23 @@ import (
"time"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/ssh/client"
"golang.org/x/crypto/ssh"
)
+// SSHSettings provides the connection settings needed by ServerConnection.
+type SSHSettings interface {
+ SSHPort() int
+ SSHConnectTimeout() time.Duration
+}
+
+const (
+ defaultSSHConnectTimeout = 2 * time.Second
+ defaultSSHPort = 2222
+)
+
// ServerConnection represents a connection to a single remote dtail server via
// SSH protocol.
type ServerConnection struct {
@@ -43,12 +53,20 @@ var _ Connector = (*ServerConnection)(nil)
func NewServerConnection(server string, userName string,
authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback,
handler handlers.Handler, commands []string, authKeyPath string,
- authKeyDisabled bool) *ServerConnection {
+ authKeyDisabled bool, settings SSHSettings) *ServerConnection {
dlog.Client.Debug(server, "Creating new connection", server, handler, commands)
- sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond
+ sshConnectTimeout := defaultSSHConnectTimeout
+ defaultPort := defaultSSHPort
+ if settings != nil {
+ sshConnectTimeout = settings.SSHConnectTimeout()
+ defaultPort = settings.SSHPort()
+ }
if sshConnectTimeout <= 0 {
- sshConnectTimeout = 2 * time.Second
+ sshConnectTimeout = defaultSSHConnectTimeout
+ }
+ if defaultPort <= 0 {
+ defaultPort = defaultSSHPort
}
c := ServerConnection{
@@ -66,7 +84,7 @@ func NewServerConnection(server string, userName string,
},
}
- c.initServerPort()
+ c.initServerPort(defaultPort)
return &c
}
@@ -77,11 +95,11 @@ func (c *ServerConnection) Server() string { return c.server }
func (c *ServerConnection) Handler() handlers.Handler { return c.handler }
// Attempt to parse the server port address from the provided server FQDN.
-func (c *ServerConnection) initServerPort() {
+func (c *ServerConnection) initServerPort(defaultPort int) {
parts := strings.Split(c.server, ":")
if len(parts) == 1 {
c.hostname = c.server
- c.port = config.Common.SSHPort
+ c.port = defaultPort
return
}
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go
index 8ab126b..227a1e9 100644
--- a/internal/clients/connectors/serverconnection_test.go
+++ b/internal/clients/connectors/serverconnection_test.go
@@ -1,12 +1,16 @@
package connectors
import (
+ "context"
"os"
"path/filepath"
"testing"
+ "time"
"github.com/mimecast/dtail/internal/clients/handlers"
"github.com/mimecast/dtail/internal/io/dlog"
+
+ "golang.org/x/crypto/ssh"
)
func TestExtractAuthKeyBase64(t *testing.T) {
@@ -76,6 +80,90 @@ func TestSendAuthKeyRegistrationCommand(t *testing.T) {
}
}
+func TestNewServerConnectionUsesInjectedSettings(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := NewServerConnection(
+ "srv1",
+ "user",
+ nil,
+ testHostKeyCallback{},
+ &mockHandler{},
+ nil,
+ "",
+ false,
+ testSSHSettings{port: 3022, timeout: 5 * time.Second},
+ )
+
+ if conn.hostname != "srv1" {
+ t.Fatalf("Expected hostname srv1, got %q", conn.hostname)
+ }
+ if conn.port != 3022 {
+ t.Fatalf("Expected injected port 3022, got %d", conn.port)
+ }
+ if conn.config.Timeout != 5*time.Second {
+ t.Fatalf("Expected injected timeout 5s, got %v", conn.config.Timeout)
+ }
+}
+
+func TestNewServerConnectionFallsBackToDefaults(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := NewServerConnection(
+ "srv1",
+ "user",
+ nil,
+ testHostKeyCallback{},
+ &mockHandler{},
+ nil,
+ "",
+ false,
+ testSSHSettings{},
+ )
+
+ if conn.port != defaultSSHPort {
+ t.Fatalf("Expected default port %d, got %d", defaultSSHPort, conn.port)
+ }
+ if conn.config.Timeout != defaultSSHConnectTimeout {
+ t.Fatalf("Expected default timeout %v, got %v", defaultSSHConnectTimeout, conn.config.Timeout)
+ }
+}
+
+type testSSHSettings struct {
+ port int
+ timeout time.Duration
+}
+
+func (s testSSHSettings) SSHPort() int {
+ return s.port
+}
+
+func (s testSSHSettings) SSHConnectTimeout() time.Duration {
+ return s.timeout
+}
+
+type testHostKeyCallback struct{}
+
+func (testHostKeyCallback) Wrap() ssh.HostKeyCallback {
+ return ssh.InsecureIgnoreHostKey()
+}
+
+func (testHostKeyCallback) Untrusted(string) bool {
+ return false
+}
+
+func (testHostKeyCallback) PromptAddHosts(context.Context) {}
+
+func resetClientLogger(t *testing.T) {
+ t.Helper()
+
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ dlog.Client = originalLogger
+ })
+}
+
type mockHandler struct {
commands []string
}
diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go
index cedf37f..daf825f 100644
--- a/internal/clients/connectors/serverless.go
+++ b/internal/clients/connectors/serverless.go
@@ -5,31 +5,35 @@ import (
"io"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
serverHandlers "github.com/mimecast/dtail/internal/server/handlers"
- sshserver "github.com/mimecast/dtail/internal/ssh/server"
- user "github.com/mimecast/dtail/internal/user/server"
)
+// ServerlessHandlerFactory creates the in-process server-side handler used by serverless mode.
+type ServerlessHandlerFactory interface {
+ NewServerlessHandler(userName string) (serverHandlers.Handler, error)
+}
+
// Serverless creates a server object directly without TCP.
type Serverless struct {
- handler handlers.Handler
- commands []string
- userName string
+ handler handlers.Handler
+ commands []string
+ userName string
+ handlerFactory ServerlessHandlerFactory
}
var _ Connector = (*Serverless)(nil)
// NewServerless starts a new serverless session.
func NewServerless(userName string, handler handlers.Handler,
- commands []string) *Serverless {
+ commands []string, handlerFactory ServerlessHandlerFactory) *Serverless {
dlog.Client.Debug("Creating new serverless connector", handler, commands)
return &Serverless{
- userName: userName,
- handler: handler,
- commands: commands,
+ userName: userName,
+ handler: handler,
+ commands: commands,
+ handlerFactory: handlerFactory,
}
}
@@ -60,31 +64,14 @@ func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc,
func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) error {
dlog.Client.Debug("Creating server handler for a serverless session")
- var permissionLookup user.PermissionLookup
- if config.Server != nil {
- permissionLookup = config.Server.UserPermissions
+ if s.handlerFactory == nil {
+ return io.ErrClosedPipe
}
- user, err := user.New(s.userName, s.Server(), permissionLookup)
+ serverHandler, err := s.handlerFactory.NewServerlessHandler(s.userName)
if err != nil {
return err
}
- var serverHandler serverHandlers.Handler
- switch s.userName {
- case config.HealthUser:
- dlog.Client.Debug("Creating serverless health handler")
- serverHandler = serverHandlers.NewHealthHandler(user)
- default:
- dlog.Client.Debug("Creating serverless server handler")
- serverHandler = serverHandlers.NewServerHandler(
- user,
- make(chan struct{}, config.Server.MaxConcurrentCats),
- make(chan struct{}, config.Server.MaxConcurrentTails),
- config.Server,
- sshserver.AuthKeys(),
- )
- }
-
terminate := func() {
dlog.Client.Debug("Terminating serverless connection")
serverHandler.Shutdown()
diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go
index 7521c67..f0f08d4 100644
--- a/internal/clients/grepclient.go
+++ b/internal/clients/grepclient.go
@@ -30,6 +30,7 @@ func NewGrepClient(args config.Args) (*GrepClient, error) {
Args: args,
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: false,
+ runtime: newClientRuntimeBoundary(config.CurrentRuntime()),
},
}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index f3ba81f..f699912 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -28,6 +28,7 @@ func NewHealthClient(args config.Args) (*HealthClient, error) {
Args: args,
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: false,
+ runtime: newClientRuntimeBoundary(config.CurrentRuntime()),
},
}
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index 95b3a9c..2757229 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -9,7 +9,6 @@ import (
"time"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/color"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/mapr"
@@ -73,6 +72,7 @@ func NewMaprClient(args config.Args, maprClientMode MaprClientMode) (*MaprClient
Args: args,
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: retry,
+ runtime: newClientRuntimeBoundary(config.CurrentRuntime()),
},
query: query,
cumulative: cumulative,
@@ -201,9 +201,9 @@ func (c *MaprClient) printResults() error {
}
if c.cumulative {
- result, numRows, err = c.globalGroup.Result(c.query, rowsLimit)
+ result, numRows, err = c.globalGroup.Result(c.query, rowsLimit, c.runtime.output.MaprResultRenderer())
} else {
- result, numRows, err = c.globalGroup.SwapOut().Result(c.query, rowsLimit)
+ result, numRows, err = c.globalGroup.SwapOut().Result(c.query, rowsLimit, c.runtime.output.MaprResultRenderer())
}
if err != nil {
return fmt.Errorf("unable to render mapreduce result: %w", err)
@@ -220,13 +220,7 @@ func (c *MaprClient) printResults() error {
return nil
}
- rawQuery := c.query.RawQuery
- if config.Client.TermColorsEnable {
- rawQuery = color.PaintStrWithAttr(rawQuery,
- config.Client.TermColors.MaprTable.RawQueryFg,
- config.Client.TermColors.MaprTable.RawQueryBg,
- config.Client.TermColors.MaprTable.RawQueryAttr)
- }
+ rawQuery := c.runtime.output.PaintMaprRawQuery(c.query.RawQuery)
dlog.Client.Raw(fmt.Sprintf("%s\n", rawQuery))
if rowsLimit > 0 && numRows > rowsLimit {
diff --git a/internal/clients/runtime_boundary.go b/internal/clients/runtime_boundary.go
new file mode 100644
index 0000000..fe58fde
--- /dev/null
+++ b/internal/clients/runtime_boundary.go
@@ -0,0 +1,212 @@
+package clients
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/mimecast/dtail/internal/color"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/mapr"
+ serverHandlers "github.com/mimecast/dtail/internal/server/handlers"
+ sshserver "github.com/mimecast/dtail/internal/ssh/server"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+type clientRuntimeBoundary struct {
+ sshPort int
+ sshConnectTimeout time.Duration
+ interruptPause time.Duration
+ serverCfg *config.ServerConfig
+ output *clientOutputFormatter
+}
+
+func newClientRuntimeBoundary(cfg config.RuntimeConfig) *clientRuntimeBoundary {
+ sshPort := 2222
+ sshConnectTimeout := 2 * time.Second
+ if cfg.Common != nil {
+ if cfg.Common.SSHPort > 0 {
+ sshPort = cfg.Common.SSHPort
+ }
+ if cfg.Common.SSHConnectTimeoutMs > 0 {
+ sshConnectTimeout = time.Duration(cfg.Common.SSHConnectTimeoutMs) * time.Millisecond
+ }
+ }
+
+ return &clientRuntimeBoundary{
+ sshPort: sshPort,
+ sshConnectTimeout: sshConnectTimeout,
+ interruptPause: time.Second * time.Duration(config.InterruptTimeoutS),
+ serverCfg: cfg.Server,
+ output: newClientOutputFormatter(cfg.Client),
+ }
+}
+
+func (r *clientRuntimeBoundary) SSHPort() int {
+ return r.sshPort
+}
+
+func (r *clientRuntimeBoundary) SSHConnectTimeout() time.Duration {
+ return r.sshConnectTimeout
+}
+
+func (r *clientRuntimeBoundary) InterruptPause() time.Duration {
+ if r == nil || r.interruptPause <= 0 {
+ return time.Second * time.Duration(config.InterruptTimeoutS)
+ }
+ return r.interruptPause
+}
+
+func (r *clientRuntimeBoundary) NewServerlessHandler(userName string) (serverHandlers.Handler, error) {
+ var permissionLookup user.PermissionLookup
+ if r.serverCfg != nil {
+ permissionLookup = r.serverCfg.UserPermissions
+ }
+
+ serverUser, err := user.New(userName, "local(serverless)", permissionLookup)
+ if err != nil {
+ return nil, err
+ }
+
+ switch userName {
+ case config.HealthUser:
+ return serverHandlers.NewHealthHandler(serverUser), nil
+ default:
+ if r.serverCfg == nil {
+ return nil, fmt.Errorf("missing serverless server config")
+ }
+ return serverHandlers.NewServerHandler(
+ serverUser,
+ make(chan struct{}, positiveOrDefault(r.serverCfg.MaxConcurrentCats, 2)),
+ make(chan struct{}, positiveOrDefault(r.serverCfg.MaxConcurrentTails, 50)),
+ r.serverCfg,
+ sshserver.AuthKeys(),
+ ), nil
+ }
+}
+
+func positiveOrDefault(value, fallback int) int {
+ if value <= 0 {
+ return fallback
+ }
+ return value
+}
+
+type interruptMessageFormatter interface {
+ FormatInterruptMessage(index int, message string) string
+}
+
+type clientOutputFormatter struct {
+ interruptEnabled bool
+ interruptStyle textStyle
+ rawQueryEnabled bool
+ rawQueryStyle textStyle
+ maprRenderer mapr.ResultRenderer
+}
+
+func newClientOutputFormatter(clientCfg *config.ClientConfig) *clientOutputFormatter {
+ formatter := &clientOutputFormatter{
+ maprRenderer: mapr.PlainResultRenderer(),
+ }
+ if clientCfg == nil || !clientCfg.TermColorsEnable {
+ return formatter
+ }
+
+ formatter.interruptEnabled = true
+ formatter.rawQueryEnabled = true
+ formatter.interruptStyle = textStyle{
+ fg: clientCfg.TermColors.Client.ClientFg,
+ bg: clientCfg.TermColors.Client.ClientBg,
+ attr: clientCfg.TermColors.Client.ClientAttr,
+ }
+ formatter.rawQueryStyle = textStyle{
+ fg: clientCfg.TermColors.MaprTable.RawQueryFg,
+ bg: clientCfg.TermColors.MaprTable.RawQueryBg,
+ attr: clientCfg.TermColors.MaprTable.RawQueryAttr,
+ }
+ formatter.maprRenderer = maprTerminalRenderer{
+ header: textStyle{
+ fg: clientCfg.TermColors.MaprTable.HeaderFg,
+ bg: clientCfg.TermColors.MaprTable.HeaderBg,
+ attr: clientCfg.TermColors.MaprTable.HeaderAttr,
+ },
+ headerDelimiter: textStyle{
+ fg: clientCfg.TermColors.MaprTable.HeaderDelimiterFg,
+ bg: clientCfg.TermColors.MaprTable.HeaderDelimiterBg,
+ attr: clientCfg.TermColors.MaprTable.HeaderDelimiterAttr,
+ },
+ headerSortAttr: clientCfg.TermColors.MaprTable.HeaderSortKeyAttr,
+ headerGroupAttr: clientCfg.TermColors.MaprTable.HeaderGroupKeyAttr,
+ data: textStyle{
+ fg: clientCfg.TermColors.MaprTable.DataFg,
+ bg: clientCfg.TermColors.MaprTable.DataBg,
+ attr: clientCfg.TermColors.MaprTable.DataAttr,
+ },
+ dataDelimiter: textStyle{
+ fg: clientCfg.TermColors.MaprTable.DelimiterFg,
+ bg: clientCfg.TermColors.MaprTable.DelimiterBg,
+ attr: clientCfg.TermColors.MaprTable.DelimiterAttr,
+ },
+ }
+
+ return formatter
+}
+
+func (f *clientOutputFormatter) FormatInterruptMessage(index int, message string) string {
+ if index > 0 && f.interruptEnabled {
+ return color.PaintStrWithAttr(message, f.interruptStyle.fg, f.interruptStyle.bg, f.interruptStyle.attr)
+ }
+ return " " + message
+}
+
+func (f *clientOutputFormatter) PaintMaprRawQuery(rawQuery string) string {
+ if !f.rawQueryEnabled {
+ return rawQuery
+ }
+ return color.PaintStrWithAttr(rawQuery, f.rawQueryStyle.fg, f.rawQueryStyle.bg, f.rawQueryStyle.attr)
+}
+
+func (f *clientOutputFormatter) MaprResultRenderer() mapr.ResultRenderer {
+ if f == nil || f.maprRenderer == nil {
+ return mapr.PlainResultRenderer()
+ }
+ return f.maprRenderer
+}
+
+type textStyle struct {
+ fg color.FgColor
+ bg color.BgColor
+ attr color.Attribute
+}
+
+type maprTerminalRenderer struct {
+ header textStyle
+ headerDelimiter textStyle
+ headerSortAttr color.Attribute
+ headerGroupAttr color.Attribute
+ data textStyle
+ dataDelimiter textStyle
+}
+
+func (r maprTerminalRenderer) WriteHeaderEntry(sb *strings.Builder, text string, isSortKey, isGroupKey bool) {
+ attrs := []color.Attribute{r.header.attr}
+ if isSortKey {
+ attrs = append(attrs, r.headerSortAttr)
+ }
+ if isGroupKey {
+ attrs = append(attrs, r.headerGroupAttr)
+ }
+ color.PaintWithAttrs(sb, text, r.header.fg, r.header.bg, attrs)
+}
+
+func (r maprTerminalRenderer) WriteHeaderDelimiter(sb *strings.Builder, text string) {
+ color.PaintWithAttr(sb, text, r.headerDelimiter.fg, r.headerDelimiter.bg, r.headerDelimiter.attr)
+}
+
+func (r maprTerminalRenderer) WriteDataEntry(sb *strings.Builder, text string) {
+ color.PaintWithAttr(sb, text, r.data.fg, r.data.bg, r.data.attr)
+}
+
+func (r maprTerminalRenderer) WriteDataDelimiter(sb *strings.Builder, text string) {
+ color.PaintWithAttr(sb, text, r.dataDelimiter.fg, r.dataDelimiter.bg, r.dataDelimiter.attr)
+}
diff --git a/internal/clients/runtime_boundary_test.go b/internal/clients/runtime_boundary_test.go
new file mode 100644
index 0000000..9947865
--- /dev/null
+++ b/internal/clients/runtime_boundary_test.go
@@ -0,0 +1,89 @@
+package clients
+
+import (
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/mimecast/dtail/internal/color"
+ "github.com/mimecast/dtail/internal/config"
+)
+
+func TestNewClientRuntimeBoundaryDefaults(t *testing.T) {
+ runtime := newClientRuntimeBoundary(config.RuntimeConfig{})
+
+ if runtime.SSHPort() != 2222 {
+ t.Fatalf("Expected default SSH port 2222, got %d", runtime.SSHPort())
+ }
+ if runtime.SSHConnectTimeout() != 2*time.Second {
+ t.Fatalf("Expected default timeout 2s, got %v", runtime.SSHConnectTimeout())
+ }
+ if runtime.InterruptPause() != 3*time.Second {
+ t.Fatalf("Expected default interrupt pause 3s, got %v", runtime.InterruptPause())
+ }
+ if got := runtime.output.PaintMaprRawQuery("select 1"); got != "select 1" {
+ t.Fatalf("Expected plain raw query output, got %q", got)
+ }
+}
+
+func TestNewClientRuntimeBoundaryUsesConfiguredSSHSettings(t *testing.T) {
+ runtime := newClientRuntimeBoundary(config.RuntimeConfig{
+ Common: &config.CommonConfig{
+ SSHPort: 4022,
+ SSHConnectTimeoutMs: 4500,
+ },
+ })
+
+ if runtime.SSHPort() != 4022 {
+ t.Fatalf("Expected configured SSH port 4022, got %d", runtime.SSHPort())
+ }
+ if runtime.SSHConnectTimeout() != 4500*time.Millisecond {
+ t.Fatalf("Expected configured timeout 4.5s, got %v", runtime.SSHConnectTimeout())
+ }
+}
+
+func TestClientOutputFormatterColorModes(t *testing.T) {
+ plain := newClientOutputFormatter(nil)
+ if got := plain.FormatInterruptMessage(1, "hello"); got != " hello" {
+ t.Fatalf("Expected plain interrupt message, got %q", got)
+ }
+ if got := plain.PaintMaprRawQuery("select 1"); got != "select 1" {
+ t.Fatalf("Expected plain raw query, got %q", got)
+ }
+
+ cfg := &config.ClientConfig{TermColorsEnable: true}
+ cfg.TermColors.Client.ClientFg = color.FgBlack
+ cfg.TermColors.Client.ClientBg = color.BgYellow
+ cfg.TermColors.Client.ClientAttr = color.AttrBold
+ cfg.TermColors.MaprTable.RawQueryFg = color.FgCyan
+ cfg.TermColors.MaprTable.RawQueryBg = color.BgBlack
+ cfg.TermColors.MaprTable.RawQueryAttr = color.AttrUnderline
+ cfg.TermColors.MaprTable.HeaderFg = color.FgWhite
+ cfg.TermColors.MaprTable.HeaderBg = color.BgBlue
+ cfg.TermColors.MaprTable.HeaderAttr = color.AttrBold
+ cfg.TermColors.MaprTable.HeaderDelimiterFg = color.FgWhite
+ cfg.TermColors.MaprTable.HeaderDelimiterBg = color.BgBlue
+ cfg.TermColors.MaprTable.HeaderDelimiterAttr = color.AttrDim
+ cfg.TermColors.MaprTable.HeaderSortKeyAttr = color.AttrUnderline
+ cfg.TermColors.MaprTable.HeaderGroupKeyAttr = color.AttrReverse
+ cfg.TermColors.MaprTable.DataFg = color.FgWhite
+ cfg.TermColors.MaprTable.DataBg = color.BgBlue
+ cfg.TermColors.MaprTable.DataAttr = color.AttrNone
+ cfg.TermColors.MaprTable.DelimiterFg = color.FgWhite
+ cfg.TermColors.MaprTable.DelimiterBg = color.BgBlue
+ cfg.TermColors.MaprTable.DelimiterAttr = color.AttrDim
+
+ colored := newClientOutputFormatter(cfg)
+ if got := colored.FormatInterruptMessage(0, "hello"); got != " hello" {
+ t.Fatalf("Expected first interrupt line to stay plain, got %q", got)
+ }
+ if got := colored.FormatInterruptMessage(1, "hello"); !strings.Contains(got, "\x1b[") {
+ t.Fatalf("Expected colored interrupt output, got %q", got)
+ }
+ if got := colored.PaintMaprRawQuery("select 1"); !strings.Contains(got, "\x1b[") {
+ t.Fatalf("Expected colored raw query output, got %q", got)
+ }
+ if colored.MaprResultRenderer() == nil {
+ t.Fatal("Expected non-nil mapreduce result renderer")
+ }
+}
diff --git a/internal/clients/stats.go b/internal/clients/stats.go
index 5880fd1..1ce04e6 100644
--- a/internal/clients/stats.go
+++ b/internal/clients/stats.go
@@ -8,8 +8,6 @@ import (
"sync"
"time"
- "github.com/mimecast/dtail/internal/color"
- "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/protocol"
)
@@ -24,13 +22,22 @@ type stats struct {
connected int
// To synchronize concurrent access.
mutex sync.Mutex
+ // Formats interrupt-driven stats output.
+ formatter interruptMessageFormatter
+ // Controls how long interrupt output remains visible.
+ interruptPause time.Duration
}
-func newTailStats(servers int) *stats {
+func newTailStats(servers int, formatter interruptMessageFormatter, interruptPause time.Duration) *stats {
+ if interruptPause <= 0 {
+ interruptPause = 3 * time.Second
+ }
return &stats{
servers: servers,
connectionsEstCh: make(chan struct{}, servers),
connected: 0,
+ formatter: formatter,
+ interruptPause: interruptPause,
}
}
@@ -84,17 +91,13 @@ func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{},
func (s *stats) printStatsDueInterrupt(messages []string) {
dlog.Client.Pause()
for i, message := range messages {
- if i > 0 && config.Client.TermColorsEnable {
- fmt.Println(color.PaintStrWithAttr(message,
- config.Client.TermColors.Client.ClientFg,
- config.Client.TermColors.Client.ClientBg,
- config.Client.TermColors.Client.ClientAttr,
- ))
+ if s.formatter != nil {
+ fmt.Println(s.formatter.FormatInterruptMessage(i, message))
continue
}
fmt.Printf(" %s\n", message)
}
- time.Sleep(time.Second * time.Duration(config.InterruptTimeoutS))
+ time.Sleep(s.interruptPause)
dlog.Client.Resume()
}
diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go
index 35c01d4..fff6646 100644
--- a/internal/clients/tailclient.go
+++ b/internal/clients/tailclient.go
@@ -24,6 +24,7 @@ func NewTailClient(args config.Args) (*TailClient, error) {
Args: args,
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: true,
+ runtime: newClientRuntimeBoundary(config.CurrentRuntime()),
},
}