summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-01-29 21:39:06 +0200
committerPaul Buetow <paul@buetow.org>2026-01-29 21:39:06 +0200
commit599bdb74efcc97e86ce6023c1ae8265b1c2ff33b (patch)
tree6642a7401de622efc90947586f4db7a3e6ccd75d
parenta32f028487c2e0b9e3144cf82d4153d1cd4a5243 (diff)
refactor: improve Go best practices compliance
- Add explicit interface satisfaction checks (var _ Interface = (*Type)(nil)) for compile-time verification: - TurboWriter implementations (DirectTurboWriter, TurboChannelWriter) - Processor implementations (GrepLineProcessor, ChannellessLineProcessor) - Parser implementations (genericParser, csvParser, genericKVParser, custom parsers, mimecastParser) - Logger implementations (file, stdout) - Handler implementations (ServerHandler, ClientHandler) - Connector implementations (Serverless, ServerConnection) - SSH callback implementations (KnownHostsCallback) - Improve error handling with context wrapping (%w): - SSH operations: GeneratePrivateRSAKey, Agent - Query parsing: Query.parse - SSH client connections: dial, session, handle methods - Fix receiver consistency: - Convert Query.String() from value to pointer receiver - Convert Outfile.String() from value to pointer receiver - Convert all KnownHostsCallback methods to pointer receivers - Convert mapCommand.Start() to pointer receiver - Reorganize file structure for better clarity: - internal/io/dlog/dlog.go: Move type definition before public functions - internal/mapr/token.go: Reorganize helper functions after public ones - Add documentation comments: - Query.String() method - Outfile.String() method - Regex.String() method - Improve config variable documentation All unit tests and integration tests pass. Amp-Thread-ID: https://ampcode.com/threads/T-019c0b08-0eeb-705d-a1f7-31bb764b659a Co-authored-by: Amp <amp@ampcode.com>
-rw-r--r--internal/clients/connectors/serverconnection.go12
-rw-r--r--internal/clients/connectors/serverless.go2
-rw-r--r--internal/clients/handlers/clienthandler.go2
-rw-r--r--internal/config/config.go6
-rw-r--r--internal/io/dlog/dlog.go156
-rw-r--r--internal/io/dlog/loggers/file.go2
-rw-r--r--internal/io/dlog/loggers/stdout.go2
-rw-r--r--internal/mapr/logformat/csv.go2
-rw-r--r--internal/mapr/logformat/custom1.go2
-rw-r--r--internal/mapr/logformat/custom2.go2
-rw-r--r--internal/mapr/logformat/generic.go2
-rw-r--r--internal/mapr/logformat/generickv.go2
-rw-r--r--internal/mapr/logformat/mimecast.go2
-rw-r--r--internal/mapr/query.go8
-rw-r--r--internal/mapr/token.go33
-rw-r--r--internal/regex/regex.go1
-rw-r--r--internal/server/handlers/channelless_adapter.go2
-rw-r--r--internal/server/handlers/lineprocessor.go3
-rw-r--r--internal/server/handlers/mapcommand.go2
-rw-r--r--internal/server/handlers/serverhandler.go2
-rw-r--r--internal/server/handlers/turbo_writer.go4
-rw-r--r--internal/ssh/client/knownhostscallback.go16
-rw-r--r--internal/ssh/ssh.go11
23 files changed, 157 insertions, 119 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 5c3d455..34d3997 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -32,6 +32,8 @@ type ServerConnection struct {
throttlingDone bool
}
+var _ Connector = (*ServerConnection)(nil)
+
// NewServerConnection returns a new DTail SSH server connection.
func NewServerConnection(server string, userName string,
authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback,
@@ -135,7 +137,7 @@ func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc,
client, err := ssh.Dial("tcp", address, c.config)
if err != nil {
- return err
+ return fmt.Errorf("failed to dial SSH connection to %s: %w", address, err)
}
defer client.Close()
@@ -149,7 +151,7 @@ func (c *ServerConnection) session(ctx context.Context, cancel context.CancelFun
dlog.Client.Debug(c.server, "Creating SSH session")
session, err := client.NewSession()
if err != nil {
- return err
+ return fmt.Errorf("failed to create SSH session for %s: %w", c.server, err)
}
defer session.Close()
return c.handle(ctx, cancel, session, throttleCh)
@@ -161,14 +163,14 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
dlog.Client.Debug(c.server, "Creating handler for SSH session")
stdinPipe, err := session.StdinPipe()
if err != nil {
- return err
+ return fmt.Errorf("failed to get SSH session stdin pipe for %s: %w", c.server, err)
}
stdoutPipe, err := session.StdoutPipe()
if err != nil {
- return err
+ return fmt.Errorf("failed to get SSH session stdout pipe for %s: %w", c.server, err)
}
if err := session.Shell(); err != nil {
- return err
+ return fmt.Errorf("failed to start SSH shell for %s: %w", c.server, err)
}
go func() {
diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go
index 7cebf8a..eaaa770 100644
--- a/internal/clients/connectors/serverless.go
+++ b/internal/clients/connectors/serverless.go
@@ -18,6 +18,8 @@ type Serverless struct {
userName string
}
+var _ Connector = (*Serverless)(nil)
+
// NewServerless starts a new serverless session.
func NewServerless(userName string, handler handlers.Handler,
commands []string) *Serverless {
diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go
index 27ac85e..4d29429 100644
--- a/internal/clients/handlers/clienthandler.go
+++ b/internal/clients/handlers/clienthandler.go
@@ -10,6 +10,8 @@ type ClientHandler struct {
baseHandler
}
+var _ Handler = (*ClientHandler)(nil)
+
// NewClientHandler creates a new client handler.
func NewClientHandler(server string) *ClientHandler {
dlog.Client.Debug(server, "Creating new client handler")
diff --git a/internal/config/config.go b/internal/config/config.go
index ee23829..48a825c 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -25,13 +25,13 @@ const (
DefaultHealthCheckLogger string = "none"
)
-// Client holds a DTail client configuration.
+// Client holds DTail client configuration.
var Client *ClientConfig
-// Server holds a DTail server configuration.
+// Server holds DTail server configuration.
var Server *ServerConfig
-// Common holds common configs of both both, client and server.
+// Common holds configuration common to both client and server.
var Common *CommonConfig
// Setup the DTail configuration.
diff --git a/internal/io/dlog/dlog.go b/internal/io/dlog/dlog.go
index 180d2e4..951bedc 100644
--- a/internal/io/dlog/dlog.go
+++ b/internal/io/dlog/dlog.go
@@ -31,36 +31,6 @@ var Common *DLog
var mutex sync.Mutex
var started bool
-// Start logger(s).
-func Start(ctx context.Context, wg *sync.WaitGroup, sourceProcess source.Source) {
- mutex.Lock()
- defer mutex.Unlock()
-
- if started {
- Common.FatalPanic("Logger already started")
- }
-
- Client = new(sourceProcess, source.Client)
- Server = new(sourceProcess, source.Server)
- Common = Client
- if sourceProcess == source.Server {
- Common = Server
- }
-
- var wg2 sync.WaitGroup
- wg2.Add(2)
- go Client.start(ctx, &wg2)
- go Server.start(ctx, &wg2)
-
- go rotation(ctx)
- go func() {
- wg2.Wait()
- wg.Done()
- }()
-
- started = true
-}
-
// DLog is the DTail logger.
type DLog struct {
logger loggers.Logger
@@ -94,62 +64,43 @@ func new(sourceProcess, sourcePackage source.Source) *DLog {
}
}
-func (d *DLog) start(ctx context.Context, wg *sync.WaitGroup) {
- defer wg.Done()
- var wg2 sync.WaitGroup
- wg2.Add(1)
- d.logger.Start(ctx, &wg2)
- <-ctx.Done()
- wg2.Wait()
-}
+// Start logger(s).
+func Start(ctx context.Context, wg *sync.WaitGroup, sourceProcess source.Source) {
+ mutex.Lock()
+ defer mutex.Unlock()
-func (d *DLog) log(level level, args []interface{}) string {
- if d.maxLevel < level {
- return ""
+ if started {
+ Common.FatalPanic("Logger already started")
}
- sb := pool.BuilderBuffer.Get().(*strings.Builder)
- defer pool.RecycleBuilderBuffer(sb)
- now := time.Now()
- switch d.sourceProcess {
- case source.Client:
- sb.WriteString(d.sourcePackage.String())
- sb.WriteString(protocol.FieldDelimiter)
- sb.WriteString(d.hostname)
- sb.WriteString(protocol.FieldDelimiter)
- sb.WriteString(level.String())
- default:
- sb.WriteString(level.String())
- sb.WriteString(protocol.FieldDelimiter)
- sb.WriteString(now.Format("0102-150405"))
+ Client = new(sourceProcess, source.Client)
+ Server = new(sourceProcess, source.Server)
+ Common = Client
+ if sourceProcess == source.Server {
+ Common = Server
}
- sb.WriteString(protocol.FieldDelimiter)
- d.writeArgStrings(sb, args)
- message := sb.String()
- if !config.Client.TermColorsEnable || !d.logger.SupportsColors() {
- d.logger.Log(now, message)
- return message
- }
+ var wg2 sync.WaitGroup
+ wg2.Add(2)
+ go Client.start(ctx, &wg2)
+ go Server.start(ctx, &wg2)
- d.logger.LogWithColors(now, message, brush.Colorfy(message))
- return message
+ go rotation(ctx)
+ go func() {
+ wg2.Wait()
+ wg.Done()
+ }()
+
+ started = true
}
-func (d *DLog) writeArgStrings(sb *strings.Builder, args []interface{}) {
- for i, arg := range args {
- if i > 0 {
- sb.WriteString(protocol.FieldDelimiter)
- }
- switch v := arg.(type) {
- case string:
- sb.WriteString(v)
- case error:
- sb.WriteString(v.Error())
- default:
- sb.WriteString(fmt.Sprintf("%v", v))
- }
- }
+func (d *DLog) start(ctx context.Context, wg *sync.WaitGroup) {
+ defer wg.Done()
+ var wg2 sync.WaitGroup
+ wg2.Add(1)
+ d.logger.Start(ctx, &wg2)
+ <-ctx.Done()
+ wg2.Wait()
}
// FatalPanic terminates the process with a fatal error.
@@ -278,3 +229,52 @@ func (d *DLog) Pause() { d.logger.Pause() }
// Resume the logging.
func (d *DLog) Resume() { d.logger.Resume() }
+
+func (d *DLog) log(level level, args []interface{}) string {
+ if d.maxLevel < level {
+ return ""
+ }
+ sb := pool.BuilderBuffer.Get().(*strings.Builder)
+ defer pool.RecycleBuilderBuffer(sb)
+ now := time.Now()
+
+ switch d.sourceProcess {
+ case source.Client:
+ sb.WriteString(d.sourcePackage.String())
+ sb.WriteString(protocol.FieldDelimiter)
+ sb.WriteString(d.hostname)
+ sb.WriteString(protocol.FieldDelimiter)
+ sb.WriteString(level.String())
+ default:
+ sb.WriteString(level.String())
+ sb.WriteString(protocol.FieldDelimiter)
+ sb.WriteString(now.Format("0102-150405"))
+ }
+ sb.WriteString(protocol.FieldDelimiter)
+ d.writeArgStrings(sb, args)
+
+ message := sb.String()
+ if !config.Client.TermColorsEnable || !d.logger.SupportsColors() {
+ d.logger.Log(now, message)
+ return message
+ }
+
+ d.logger.LogWithColors(now, message, brush.Colorfy(message))
+ return message
+}
+
+func (d *DLog) writeArgStrings(sb *strings.Builder, args []interface{}) {
+ for i, arg := range args {
+ if i > 0 {
+ sb.WriteString(protocol.FieldDelimiter)
+ }
+ switch v := arg.(type) {
+ case string:
+ sb.WriteString(v)
+ case error:
+ sb.WriteString(v.Error())
+ default:
+ sb.WriteString(fmt.Sprintf("%v", v))
+ }
+ }
+}
diff --git a/internal/io/dlog/loggers/file.go b/internal/io/dlog/loggers/file.go
index 8e567bc..5ac8d9e 100644
--- a/internal/io/dlog/loggers/file.go
+++ b/internal/io/dlog/loggers/file.go
@@ -32,6 +32,8 @@ type file struct {
strategy Strategy
}
+var _ Logger = (*file)(nil)
+
func newFile(strategy Strategy) *file {
return &file{
bufferCh: make(chan *fileMessageBuf, runtime.NumCPU()*100),
diff --git a/internal/io/dlog/loggers/stdout.go b/internal/io/dlog/loggers/stdout.go
index b024243..a2575c8 100644
--- a/internal/io/dlog/loggers/stdout.go
+++ b/internal/io/dlog/loggers/stdout.go
@@ -13,6 +13,8 @@ type stdout struct {
mutex sync.Mutex
}
+var _ Logger = (*stdout)(nil)
+
func newStdout() *stdout {
return &stdout{
pauseCh: make(chan struct{}),
diff --git a/internal/mapr/logformat/csv.go b/internal/mapr/logformat/csv.go
index ea85ca9..b8f565c 100644
--- a/internal/mapr/logformat/csv.go
+++ b/internal/mapr/logformat/csv.go
@@ -13,6 +13,8 @@ type csvParser struct {
hasHeader bool
}
+var _ Parser = (*csvParser)(nil)
+
func newCSVParser(hostname, timeZoneName string, timeZoneOffset int) (*csvParser, error) {
defaultParser, err := newDefaultParser(hostname, timeZoneName, timeZoneOffset)
if err != nil {
diff --git a/internal/mapr/logformat/custom1.go b/internal/mapr/logformat/custom1.go
index 7229f3e..05e0867 100644
--- a/internal/mapr/logformat/custom1.go
+++ b/internal/mapr/logformat/custom1.go
@@ -7,6 +7,8 @@ var ErrCustom1NotImplemented error = errors.New("custom1 log format is not imple
// Template for creating a custom log format.
type custom1Parser struct{}
+var _ Parser = (*custom1Parser)(nil)
+
func newCustom1Parser(hostname, timeZoneName string, timeZoneOffset int) (*custom1Parser, error) {
return &custom1Parser{}, ErrCustom1NotImplemented
}
diff --git a/internal/mapr/logformat/custom2.go b/internal/mapr/logformat/custom2.go
index 262c721..cc8d5b9 100644
--- a/internal/mapr/logformat/custom2.go
+++ b/internal/mapr/logformat/custom2.go
@@ -7,6 +7,8 @@ var ErrCustom2NotImplemented error = errors.New("custom2 log format is not imple
// Template for creating a custom log format.
type custom2Parser struct{}
+var _ Parser = (*custom2Parser)(nil)
+
func newCustom2Parser(hostname, timeZoneName string, timeZoneOffset int) (*custom2Parser, error) {
return &custom2Parser{}, ErrCustom2NotImplemented
}
diff --git a/internal/mapr/logformat/generic.go b/internal/mapr/logformat/generic.go
index 32d9b4a..1350eff 100644
--- a/internal/mapr/logformat/generic.go
+++ b/internal/mapr/logformat/generic.go
@@ -4,6 +4,8 @@ type genericParser struct {
defaultParser
}
+var _ Parser = (*genericParser)(nil)
+
func newGenericParser(hostname, timeZoneName string, timeZoneOffset int) (*genericParser, error) {
defaultParser, err := newDefaultParser(hostname, timeZoneName, timeZoneOffset)
if err != nil {
diff --git a/internal/mapr/logformat/generickv.go b/internal/mapr/logformat/generickv.go
index 9c3de92..bd9aad5 100644
--- a/internal/mapr/logformat/generickv.go
+++ b/internal/mapr/logformat/generickv.go
@@ -10,6 +10,8 @@ type genericKVParser struct {
defaultParser
}
+var _ Parser = (*genericKVParser)(nil)
+
func newGenericKVParser(hostname, timeZoneName string, timeZoneOffset int) (*genericKVParser, error) {
defaultParser, err := newDefaultParser(hostname, timeZoneName, timeZoneOffset)
if err != nil {
diff --git a/internal/mapr/logformat/mimecast.go b/internal/mapr/logformat/mimecast.go
index cf6b333..84e1e93 100644
--- a/internal/mapr/logformat/mimecast.go
+++ b/internal/mapr/logformat/mimecast.go
@@ -10,6 +10,8 @@ var ErrMimecastNotAvailable error = errors.New("The mimecast logformat is not av
type mimecastParser struct{}
+var _ Parser = (*mimecastParser)(nil)
+
func newMimecastParser(hostname, timeZoneName string, timeZoneOffset int) (*mimecastParser, error) {
return &mimecastParser{}, ErrMimecastNotAvailable
}
diff --git a/internal/mapr/query.go b/internal/mapr/query.go
index 139f04c..06a8dc2 100644
--- a/internal/mapr/query.go
+++ b/internal/mapr/query.go
@@ -19,7 +19,8 @@ type Outfile struct {
AppendMode bool
}
-func (o Outfile) String() string {
+// String returns the string representation of Outfile.
+func (o *Outfile) String() string {
return fmt.Sprintf("Outfile(FilePath:%v,AppendMode:%v)", o.FilePath, o.AppendMode)
}
@@ -41,7 +42,8 @@ type Query struct {
LogFormat string
}
-func (q Query) String() string {
+// String returns the string representation of Query.
+func (q *Query) String() string {
return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%vGroupBy:%v,"+
"GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,"+
"RawQuery:%s,tokens:%v,LogFormat:%s)",
@@ -95,7 +97,7 @@ func (q *Query) Has(what string) bool {
func (q *Query) parse(tokens []token) error {
if _, err := q.parseTokens(tokens); err != nil {
- return err
+ return fmt.Errorf("failed to parse query tokens: %w", err)
}
if len(q.Select) < 1 {
diff --git a/internal/mapr/token.go b/internal/mapr/token.go
index 77362f7..b9b02f8 100644
--- a/internal/mapr/token.go
+++ b/internal/mapr/token.go
@@ -14,22 +14,7 @@ type token struct {
quotesStripped bool
}
-func (t token) isKeyword() bool {
- if !t.isBareword {
- return false
- }
- for _, keyword := range keywords {
- if strings.ToLower(t.str) == keyword {
- return true
- }
- }
- return false
-}
-
-func (t token) String() string {
- return t.str
-}
-
+// tokenize parses a query string into tokens.
func tokenize(queryStr string) []token {
var tokens []token
for i, part := range strings.Split(queryStr, "\"") {
@@ -105,3 +90,19 @@ func tokensConsumeOptional(tokens []token, optional string) []token {
}
return tokens
}
+
+func (t token) isKeyword() bool {
+ if !t.isBareword {
+ return false
+ }
+ for _, keyword := range keywords {
+ if strings.ToLower(t.str) == keyword {
+ return true
+ }
+ }
+ return false
+}
+
+func (t token) String() string {
+ return t.str
+}
diff --git a/internal/regex/regex.go b/internal/regex/regex.go
index b817bc4..a34f7c1 100644
--- a/internal/regex/regex.go
+++ b/internal/regex/regex.go
@@ -23,6 +23,7 @@ type Regex struct {
literalBytes []byte // literal bytes for byte matching
}
+// String returns the string representation of Regex.
func (r Regex) String() string {
return fmt.Sprintf("Regex(regexStr:%s,flags:%s,initialized:%t,re==nil:%t,isLiteral:%t)",
r.regexStr, r.flags, r.initialized, r.re == nil, r.isLiteral)
diff --git a/internal/server/handlers/channelless_adapter.go b/internal/server/handlers/channelless_adapter.go
index a950408..40c072f 100644
--- a/internal/server/handlers/channelless_adapter.go
+++ b/internal/server/handlers/channelless_adapter.go
@@ -13,6 +13,8 @@ type ChannellessLineProcessor struct {
lineCount uint64
}
+var _ line.Processor = (*ChannellessLineProcessor)(nil)
+
// NewChannellessLineProcessor creates a processor that sends lines to the existing channel
func NewChannellessLineProcessor(lines chan<- *line.Line, globID string) *ChannellessLineProcessor {
return &ChannellessLineProcessor{
diff --git a/internal/server/handlers/lineprocessor.go b/internal/server/handlers/lineprocessor.go
index f75b85b..9bbf7e1 100644
--- a/internal/server/handlers/lineprocessor.go
+++ b/internal/server/handlers/lineprocessor.go
@@ -6,6 +6,7 @@ import (
"io"
"sync"
+ "github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/pool"
"github.com/mimecast/dtail/internal/protocol"
)
@@ -28,6 +29,8 @@ type GrepLineProcessor struct {
bytesWritten uint64
}
+var _ line.Processor = (*GrepLineProcessor)(nil)
+
// HandlerWriter adapts a ServerHandler to implement io.Writer
type HandlerWriter struct {
handler *ServerHandler
diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go
index 83c4c75..a4fda97 100644
--- a/internal/server/handlers/mapcommand.go
+++ b/internal/server/handlers/mapcommand.go
@@ -45,7 +45,7 @@ func newMapCommand(serverHandler *ServerHandler, argc int,
return m, aggregate, nil, nil
}
-func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) {
+func (m *mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) {
if m.turboAggregate != nil {
m.turboAggregate.Start(ctx, aggregatedMessages)
} else {
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index df227ab..645e2e9 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -26,6 +26,8 @@ type ServerHandler struct {
pendingFiles int32
}
+var _ Handler = (*ServerHandler)(nil)
+
// NewServerHandler returns the server handler.
func NewServerHandler(user *user.User, catLimiter,
tailLimiter chan struct{}) *ServerHandler {
diff --git a/internal/server/handlers/turbo_writer.go b/internal/server/handlers/turbo_writer.go
index d8ee2ad..62225bd 100644
--- a/internal/server/handlers/turbo_writer.go
+++ b/internal/server/handlers/turbo_writer.go
@@ -40,6 +40,8 @@ type DirectTurboWriter struct {
bytesWritten uint64
}
+var _ TurboWriter = (*DirectTurboWriter)(nil)
+
// NewDirectTurboWriter creates a new turbo writer
func NewDirectTurboWriter(writer io.Writer, hostname string, plain, serverless bool) *DirectTurboWriter {
return &DirectTurboWriter{
@@ -252,6 +254,8 @@ type TurboChannelWriter struct {
bytesWritten uint64
}
+var _ TurboWriter = (*TurboChannelWriter)(nil)
+
// NewTurboChannelWriter creates a writer that sends to a turbo channel
func NewTurboChannelWriter(channel chan<- []byte, hostname string, plain, serverless bool) *TurboChannelWriter {
return &TurboChannelWriter{
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
index fe3543c..9c73864 100644
--- a/internal/ssh/client/knownhostscallback.go
+++ b/internal/ssh/client/knownhostscallback.go
@@ -45,6 +45,8 @@ type KnownHostsCallback struct {
mutex *sync.Mutex
}
+var _ HostKeyCallback = (*KnownHostsCallback)(nil)
+
// NewKnownHostsCallback returns a new wrapper.
func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
throttleCh chan struct{}) (HostKeyCallback, error) {
@@ -63,11 +65,11 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
if trustAllHosts {
close(c.trustAllHostsCh)
}
- return c, nil
+ return &c, nil
}
// Wrap the host key callback.
-func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
+func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback {
return func(server string, remote net.Addr, key ssh.PublicKey) error {
// Parse known_hosts file
knownHostsCb, err := knownhosts.New(c.knownHostsPath)
@@ -113,7 +115,7 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
// PromptAddHosts prompts a question to the user whether unknown hosts should
// be added to the known hosts or not.
-func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
+func (c *KnownHostsCallback) PromptAddHosts(ctx context.Context) {
var hosts []unknownHost
for {
// Check whether there is a unknown host
@@ -138,7 +140,7 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
}
}
-func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
+func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
var servers []string
for _, host := range hosts {
servers = append(servers, host.server)
@@ -212,7 +214,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
p.Ask()
}
-func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
+func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) {
tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
@@ -265,14 +267,14 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
}
}
-func (c KnownHostsCallback) dontTrustHosts(hosts []unknownHost) {
+func (c *KnownHostsCallback) dontTrustHosts(hosts []unknownHost) {
for _, unknown := range hosts {
unknown.responseCh <- dontTrustHost
}
}
// Untrusted returns true if the host is not trusted. False otherwise.
-func (c KnownHostsCallback) Untrusted(server string) bool {
+func (c *KnownHostsCallback) Untrusted(server string) bool {
c.mutex.Lock()
defer c.mutex.Unlock()
_, ok := c.untrustedHosts[server]
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 9c2dcb8..32e01b3 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -21,11 +21,10 @@ import (
func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, size)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to generate RSA key: %w", err)
}
- err = privateKey.Validate()
- if err != nil {
- return nil, err
+ if err = privateKey.Validate(); err != nil {
+ return nil, fmt.Errorf("failed to validate generated RSA key: %w", err)
}
return privateKey, nil
}
@@ -46,12 +45,12 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
func Agent() (gossh.AuthMethod, error) {
sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to connect to SSH agent: %w", err)
}
agentClient := agent.NewClient(sshAgent)
keys, err := agentClient.List()
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to list SSH agent keys: %w", err)
}
for i, key := range keys {
dlog.Common.Debug("Public key", i, key)