diff options
23 files changed, 157 insertions, 119 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 5c3d455..34d3997 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -32,6 +32,8 @@ type ServerConnection struct { throttlingDone bool } +var _ Connector = (*ServerConnection)(nil) + // NewServerConnection returns a new DTail SSH server connection. func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, @@ -135,7 +137,7 @@ func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc, client, err := ssh.Dial("tcp", address, c.config) if err != nil { - return err + return fmt.Errorf("failed to dial SSH connection to %s: %w", address, err) } defer client.Close() @@ -149,7 +151,7 @@ func (c *ServerConnection) session(ctx context.Context, cancel context.CancelFun dlog.Client.Debug(c.server, "Creating SSH session") session, err := client.NewSession() if err != nil { - return err + return fmt.Errorf("failed to create SSH session for %s: %w", c.server, err) } defer session.Close() return c.handle(ctx, cancel, session, throttleCh) @@ -161,14 +163,14 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc dlog.Client.Debug(c.server, "Creating handler for SSH session") stdinPipe, err := session.StdinPipe() if err != nil { - return err + return fmt.Errorf("failed to get SSH session stdin pipe for %s: %w", c.server, err) } stdoutPipe, err := session.StdoutPipe() if err != nil { - return err + return fmt.Errorf("failed to get SSH session stdout pipe for %s: %w", c.server, err) } if err := session.Shell(); err != nil { - return err + return fmt.Errorf("failed to start SSH shell for %s: %w", c.server, err) } go func() { diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go index 7cebf8a..eaaa770 100644 --- a/internal/clients/connectors/serverless.go +++ b/internal/clients/connectors/serverless.go @@ -18,6 +18,8 @@ type Serverless struct { userName string } +var _ Connector = (*Serverless)(nil) + // NewServerless starts a new serverless session. func NewServerless(userName string, handler handlers.Handler, commands []string) *Serverless { diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index 27ac85e..4d29429 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -10,6 +10,8 @@ type ClientHandler struct { baseHandler } +var _ Handler = (*ClientHandler)(nil) + // NewClientHandler creates a new client handler. func NewClientHandler(server string) *ClientHandler { dlog.Client.Debug(server, "Creating new client handler") diff --git a/internal/config/config.go b/internal/config/config.go index ee23829..48a825c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,13 +25,13 @@ const ( DefaultHealthCheckLogger string = "none" ) -// Client holds a DTail client configuration. +// Client holds DTail client configuration. var Client *ClientConfig -// Server holds a DTail server configuration. +// Server holds DTail server configuration. var Server *ServerConfig -// Common holds common configs of both both, client and server. +// Common holds configuration common to both client and server. var Common *CommonConfig // Setup the DTail configuration. diff --git a/internal/io/dlog/dlog.go b/internal/io/dlog/dlog.go index 180d2e4..951bedc 100644 --- a/internal/io/dlog/dlog.go +++ b/internal/io/dlog/dlog.go @@ -31,36 +31,6 @@ var Common *DLog var mutex sync.Mutex var started bool -// Start logger(s). -func Start(ctx context.Context, wg *sync.WaitGroup, sourceProcess source.Source) { - mutex.Lock() - defer mutex.Unlock() - - if started { - Common.FatalPanic("Logger already started") - } - - Client = new(sourceProcess, source.Client) - Server = new(sourceProcess, source.Server) - Common = Client - if sourceProcess == source.Server { - Common = Server - } - - var wg2 sync.WaitGroup - wg2.Add(2) - go Client.start(ctx, &wg2) - go Server.start(ctx, &wg2) - - go rotation(ctx) - go func() { - wg2.Wait() - wg.Done() - }() - - started = true -} - // DLog is the DTail logger. type DLog struct { logger loggers.Logger @@ -94,62 +64,43 @@ func new(sourceProcess, sourcePackage source.Source) *DLog { } } -func (d *DLog) start(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() - var wg2 sync.WaitGroup - wg2.Add(1) - d.logger.Start(ctx, &wg2) - <-ctx.Done() - wg2.Wait() -} +// Start logger(s). +func Start(ctx context.Context, wg *sync.WaitGroup, sourceProcess source.Source) { + mutex.Lock() + defer mutex.Unlock() -func (d *DLog) log(level level, args []interface{}) string { - if d.maxLevel < level { - return "" + if started { + Common.FatalPanic("Logger already started") } - sb := pool.BuilderBuffer.Get().(*strings.Builder) - defer pool.RecycleBuilderBuffer(sb) - now := time.Now() - switch d.sourceProcess { - case source.Client: - sb.WriteString(d.sourcePackage.String()) - sb.WriteString(protocol.FieldDelimiter) - sb.WriteString(d.hostname) - sb.WriteString(protocol.FieldDelimiter) - sb.WriteString(level.String()) - default: - sb.WriteString(level.String()) - sb.WriteString(protocol.FieldDelimiter) - sb.WriteString(now.Format("0102-150405")) + Client = new(sourceProcess, source.Client) + Server = new(sourceProcess, source.Server) + Common = Client + if sourceProcess == source.Server { + Common = Server } - sb.WriteString(protocol.FieldDelimiter) - d.writeArgStrings(sb, args) - message := sb.String() - if !config.Client.TermColorsEnable || !d.logger.SupportsColors() { - d.logger.Log(now, message) - return message - } + var wg2 sync.WaitGroup + wg2.Add(2) + go Client.start(ctx, &wg2) + go Server.start(ctx, &wg2) - d.logger.LogWithColors(now, message, brush.Colorfy(message)) - return message + go rotation(ctx) + go func() { + wg2.Wait() + wg.Done() + }() + + started = true } -func (d *DLog) writeArgStrings(sb *strings.Builder, args []interface{}) { - for i, arg := range args { - if i > 0 { - sb.WriteString(protocol.FieldDelimiter) - } - switch v := arg.(type) { - case string: - sb.WriteString(v) - case error: - sb.WriteString(v.Error()) - default: - sb.WriteString(fmt.Sprintf("%v", v)) - } - } +func (d *DLog) start(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + var wg2 sync.WaitGroup + wg2.Add(1) + d.logger.Start(ctx, &wg2) + <-ctx.Done() + wg2.Wait() } // FatalPanic terminates the process with a fatal error. @@ -278,3 +229,52 @@ func (d *DLog) Pause() { d.logger.Pause() } // Resume the logging. func (d *DLog) Resume() { d.logger.Resume() } + +func (d *DLog) log(level level, args []interface{}) string { + if d.maxLevel < level { + return "" + } + sb := pool.BuilderBuffer.Get().(*strings.Builder) + defer pool.RecycleBuilderBuffer(sb) + now := time.Now() + + switch d.sourceProcess { + case source.Client: + sb.WriteString(d.sourcePackage.String()) + sb.WriteString(protocol.FieldDelimiter) + sb.WriteString(d.hostname) + sb.WriteString(protocol.FieldDelimiter) + sb.WriteString(level.String()) + default: + sb.WriteString(level.String()) + sb.WriteString(protocol.FieldDelimiter) + sb.WriteString(now.Format("0102-150405")) + } + sb.WriteString(protocol.FieldDelimiter) + d.writeArgStrings(sb, args) + + message := sb.String() + if !config.Client.TermColorsEnable || !d.logger.SupportsColors() { + d.logger.Log(now, message) + return message + } + + d.logger.LogWithColors(now, message, brush.Colorfy(message)) + return message +} + +func (d *DLog) writeArgStrings(sb *strings.Builder, args []interface{}) { + for i, arg := range args { + if i > 0 { + sb.WriteString(protocol.FieldDelimiter) + } + switch v := arg.(type) { + case string: + sb.WriteString(v) + case error: + sb.WriteString(v.Error()) + default: + sb.WriteString(fmt.Sprintf("%v", v)) + } + } +} diff --git a/internal/io/dlog/loggers/file.go b/internal/io/dlog/loggers/file.go index 8e567bc..5ac8d9e 100644 --- a/internal/io/dlog/loggers/file.go +++ b/internal/io/dlog/loggers/file.go @@ -32,6 +32,8 @@ type file struct { strategy Strategy } +var _ Logger = (*file)(nil) + func newFile(strategy Strategy) *file { return &file{ bufferCh: make(chan *fileMessageBuf, runtime.NumCPU()*100), diff --git a/internal/io/dlog/loggers/stdout.go b/internal/io/dlog/loggers/stdout.go index b024243..a2575c8 100644 --- a/internal/io/dlog/loggers/stdout.go +++ b/internal/io/dlog/loggers/stdout.go @@ -13,6 +13,8 @@ type stdout struct { mutex sync.Mutex } +var _ Logger = (*stdout)(nil) + func newStdout() *stdout { return &stdout{ pauseCh: make(chan struct{}), diff --git a/internal/mapr/logformat/csv.go b/internal/mapr/logformat/csv.go index ea85ca9..b8f565c 100644 --- a/internal/mapr/logformat/csv.go +++ b/internal/mapr/logformat/csv.go @@ -13,6 +13,8 @@ type csvParser struct { hasHeader bool } +var _ Parser = (*csvParser)(nil) + func newCSVParser(hostname, timeZoneName string, timeZoneOffset int) (*csvParser, error) { defaultParser, err := newDefaultParser(hostname, timeZoneName, timeZoneOffset) if err != nil { diff --git a/internal/mapr/logformat/custom1.go b/internal/mapr/logformat/custom1.go index 7229f3e..05e0867 100644 --- a/internal/mapr/logformat/custom1.go +++ b/internal/mapr/logformat/custom1.go @@ -7,6 +7,8 @@ var ErrCustom1NotImplemented error = errors.New("custom1 log format is not imple // Template for creating a custom log format. type custom1Parser struct{} +var _ Parser = (*custom1Parser)(nil) + func newCustom1Parser(hostname, timeZoneName string, timeZoneOffset int) (*custom1Parser, error) { return &custom1Parser{}, ErrCustom1NotImplemented } diff --git a/internal/mapr/logformat/custom2.go b/internal/mapr/logformat/custom2.go index 262c721..cc8d5b9 100644 --- a/internal/mapr/logformat/custom2.go +++ b/internal/mapr/logformat/custom2.go @@ -7,6 +7,8 @@ var ErrCustom2NotImplemented error = errors.New("custom2 log format is not imple // Template for creating a custom log format. type custom2Parser struct{} +var _ Parser = (*custom2Parser)(nil) + func newCustom2Parser(hostname, timeZoneName string, timeZoneOffset int) (*custom2Parser, error) { return &custom2Parser{}, ErrCustom2NotImplemented } diff --git a/internal/mapr/logformat/generic.go b/internal/mapr/logformat/generic.go index 32d9b4a..1350eff 100644 --- a/internal/mapr/logformat/generic.go +++ b/internal/mapr/logformat/generic.go @@ -4,6 +4,8 @@ type genericParser struct { defaultParser } +var _ Parser = (*genericParser)(nil) + func newGenericParser(hostname, timeZoneName string, timeZoneOffset int) (*genericParser, error) { defaultParser, err := newDefaultParser(hostname, timeZoneName, timeZoneOffset) if err != nil { diff --git a/internal/mapr/logformat/generickv.go b/internal/mapr/logformat/generickv.go index 9c3de92..bd9aad5 100644 --- a/internal/mapr/logformat/generickv.go +++ b/internal/mapr/logformat/generickv.go @@ -10,6 +10,8 @@ type genericKVParser struct { defaultParser } +var _ Parser = (*genericKVParser)(nil) + func newGenericKVParser(hostname, timeZoneName string, timeZoneOffset int) (*genericKVParser, error) { defaultParser, err := newDefaultParser(hostname, timeZoneName, timeZoneOffset) if err != nil { diff --git a/internal/mapr/logformat/mimecast.go b/internal/mapr/logformat/mimecast.go index cf6b333..84e1e93 100644 --- a/internal/mapr/logformat/mimecast.go +++ b/internal/mapr/logformat/mimecast.go @@ -10,6 +10,8 @@ var ErrMimecastNotAvailable error = errors.New("The mimecast logformat is not av type mimecastParser struct{} +var _ Parser = (*mimecastParser)(nil) + func newMimecastParser(hostname, timeZoneName string, timeZoneOffset int) (*mimecastParser, error) { return &mimecastParser{}, ErrMimecastNotAvailable } diff --git a/internal/mapr/query.go b/internal/mapr/query.go index 139f04c..06a8dc2 100644 --- a/internal/mapr/query.go +++ b/internal/mapr/query.go @@ -19,7 +19,8 @@ type Outfile struct { AppendMode bool } -func (o Outfile) String() string { +// String returns the string representation of Outfile. +func (o *Outfile) String() string { return fmt.Sprintf("Outfile(FilePath:%v,AppendMode:%v)", o.FilePath, o.AppendMode) } @@ -41,7 +42,8 @@ type Query struct { LogFormat string } -func (q Query) String() string { +// String returns the string representation of Query. +func (q *Query) String() string { return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%vGroupBy:%v,"+ "GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,"+ "RawQuery:%s,tokens:%v,LogFormat:%s)", @@ -95,7 +97,7 @@ func (q *Query) Has(what string) bool { func (q *Query) parse(tokens []token) error { if _, err := q.parseTokens(tokens); err != nil { - return err + return fmt.Errorf("failed to parse query tokens: %w", err) } if len(q.Select) < 1 { diff --git a/internal/mapr/token.go b/internal/mapr/token.go index 77362f7..b9b02f8 100644 --- a/internal/mapr/token.go +++ b/internal/mapr/token.go @@ -14,22 +14,7 @@ type token struct { quotesStripped bool } -func (t token) isKeyword() bool { - if !t.isBareword { - return false - } - for _, keyword := range keywords { - if strings.ToLower(t.str) == keyword { - return true - } - } - return false -} - -func (t token) String() string { - return t.str -} - +// tokenize parses a query string into tokens. func tokenize(queryStr string) []token { var tokens []token for i, part := range strings.Split(queryStr, "\"") { @@ -105,3 +90,19 @@ func tokensConsumeOptional(tokens []token, optional string) []token { } return tokens } + +func (t token) isKeyword() bool { + if !t.isBareword { + return false + } + for _, keyword := range keywords { + if strings.ToLower(t.str) == keyword { + return true + } + } + return false +} + +func (t token) String() string { + return t.str +} diff --git a/internal/regex/regex.go b/internal/regex/regex.go index b817bc4..a34f7c1 100644 --- a/internal/regex/regex.go +++ b/internal/regex/regex.go @@ -23,6 +23,7 @@ type Regex struct { literalBytes []byte // literal bytes for byte matching } +// String returns the string representation of Regex. func (r Regex) String() string { return fmt.Sprintf("Regex(regexStr:%s,flags:%s,initialized:%t,re==nil:%t,isLiteral:%t)", r.regexStr, r.flags, r.initialized, r.re == nil, r.isLiteral) diff --git a/internal/server/handlers/channelless_adapter.go b/internal/server/handlers/channelless_adapter.go index a950408..40c072f 100644 --- a/internal/server/handlers/channelless_adapter.go +++ b/internal/server/handlers/channelless_adapter.go @@ -13,6 +13,8 @@ type ChannellessLineProcessor struct { lineCount uint64 } +var _ line.Processor = (*ChannellessLineProcessor)(nil) + // NewChannellessLineProcessor creates a processor that sends lines to the existing channel func NewChannellessLineProcessor(lines chan<- *line.Line, globID string) *ChannellessLineProcessor { return &ChannellessLineProcessor{ diff --git a/internal/server/handlers/lineprocessor.go b/internal/server/handlers/lineprocessor.go index f75b85b..9bbf7e1 100644 --- a/internal/server/handlers/lineprocessor.go +++ b/internal/server/handlers/lineprocessor.go @@ -6,6 +6,7 @@ import ( "io" "sync" + "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/io/pool" "github.com/mimecast/dtail/internal/protocol" ) @@ -28,6 +29,8 @@ type GrepLineProcessor struct { bytesWritten uint64 } +var _ line.Processor = (*GrepLineProcessor)(nil) + // HandlerWriter adapts a ServerHandler to implement io.Writer type HandlerWriter struct { handler *ServerHandler diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go index 83c4c75..a4fda97 100644 --- a/internal/server/handlers/mapcommand.go +++ b/internal/server/handlers/mapcommand.go @@ -45,7 +45,7 @@ func newMapCommand(serverHandler *ServerHandler, argc int, return m, aggregate, nil, nil } -func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { +func (m *mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { if m.turboAggregate != nil { m.turboAggregate.Start(ctx, aggregatedMessages) } else { diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index df227ab..645e2e9 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -26,6 +26,8 @@ type ServerHandler struct { pendingFiles int32 } +var _ Handler = (*ServerHandler)(nil) + // NewServerHandler returns the server handler. func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler { diff --git a/internal/server/handlers/turbo_writer.go b/internal/server/handlers/turbo_writer.go index d8ee2ad..62225bd 100644 --- a/internal/server/handlers/turbo_writer.go +++ b/internal/server/handlers/turbo_writer.go @@ -40,6 +40,8 @@ type DirectTurboWriter struct { bytesWritten uint64 } +var _ TurboWriter = (*DirectTurboWriter)(nil) + // NewDirectTurboWriter creates a new turbo writer func NewDirectTurboWriter(writer io.Writer, hostname string, plain, serverless bool) *DirectTurboWriter { return &DirectTurboWriter{ @@ -252,6 +254,8 @@ type TurboChannelWriter struct { bytesWritten uint64 } +var _ TurboWriter = (*TurboChannelWriter)(nil) + // NewTurboChannelWriter creates a writer that sends to a turbo channel func NewTurboChannelWriter(channel chan<- []byte, hostname string, plain, serverless bool) *TurboChannelWriter { return &TurboChannelWriter{ diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go index fe3543c..9c73864 100644 --- a/internal/ssh/client/knownhostscallback.go +++ b/internal/ssh/client/knownhostscallback.go @@ -45,6 +45,8 @@ type KnownHostsCallback struct { mutex *sync.Mutex } +var _ HostKeyCallback = (*KnownHostsCallback)(nil) + // NewKnownHostsCallback returns a new wrapper. func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) { @@ -63,11 +65,11 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, if trustAllHosts { close(c.trustAllHostsCh) } - return c, nil + return &c, nil } // Wrap the host key callback. -func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { +func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback { return func(server string, remote net.Addr, key ssh.PublicKey) error { // Parse known_hosts file knownHostsCb, err := knownhosts.New(c.knownHostsPath) @@ -113,7 +115,7 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { // PromptAddHosts prompts a question to the user whether unknown hosts should // be added to the known hosts or not. -func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { +func (c *KnownHostsCallback) PromptAddHosts(ctx context.Context) { var hosts []unknownHost for { // Check whether there is a unknown host @@ -138,7 +140,7 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { } } -func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { +func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) { var servers []string for _, host := range hosts { servers = append(servers, host.server) @@ -212,7 +214,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { p.Ask() } -func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { +func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) { tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) @@ -265,14 +267,14 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { } } -func (c KnownHostsCallback) dontTrustHosts(hosts []unknownHost) { +func (c *KnownHostsCallback) dontTrustHosts(hosts []unknownHost) { for _, unknown := range hosts { unknown.responseCh <- dontTrustHost } } // Untrusted returns true if the host is not trusted. False otherwise. -func (c KnownHostsCallback) Untrusted(server string) bool { +func (c *KnownHostsCallback) Untrusted(server string) bool { c.mutex.Lock() defer c.mutex.Unlock() _, ok := c.untrustedHosts[server] diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 9c2dcb8..32e01b3 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -21,11 +21,10 @@ import ( func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) { privateKey, err := rsa.GenerateKey(rand.Reader, size) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to generate RSA key: %w", err) } - err = privateKey.Validate() - if err != nil { - return nil, err + if err = privateKey.Validate(); err != nil { + return nil, fmt.Errorf("failed to validate generated RSA key: %w", err) } return privateKey, nil } @@ -46,12 +45,12 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { func Agent() (gossh.AuthMethod, error) { sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to connect to SSH agent: %w", err) } agentClient := agent.NewClient(sshAgent) keys, err := agentClient.List() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to list SSH agent keys: %w", err) } for i, key := range keys { dlog.Common.Debug("Public key", i, key) |
