diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-13 09:54:31 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-13 09:54:31 +0200 |
| commit | a5a405d79fe3d9e0c6ea081b425d36bd67d8671d (patch) | |
| tree | ff460d868a75c901bd9ea3bfbc3b0af738723a59 /internal | |
| parent | 7ce4aee16c0223cc15c9dd8d9024120069500c65 (diff) | |
task be5429a7: cover reconnect session restore
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/baseclient.go | 19 | ||||
| -rw-r--r-- | internal/clients/baseclient_retry_test.go | 124 |
2 files changed, 142 insertions, 1 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 0e67ba4..71b8d02 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -42,6 +42,12 @@ type baseClient struct { sessionSpec SessionSpec // Connection maker helper. maker maker + // Optional factory override for retry/reconnect tests. + connectionFactory func(server string, authMethods []gossh.AuthMethod, + hostKeyCallback client.HostKeyCallback, sessionSpec SessionSpec, + interactive bool) connectors.Connector + // Optional sleep override for retry tests. + sleepFn func(context.Context, time.Duration) bool // Regex is the regular expresion object for line filtering Regex regex.Regex } @@ -155,7 +161,7 @@ func (c *baseClient) startConnection(ctx context.Context, i int, // Yes, we want to retry with exponential backoff and jitter. sleepDuration := jitterRetryDelay(retryDelay, retryRandom) dlog.Client.Debug(conn.Server(), "Reconnecting", "backoff", sleepDuration) - if !sleepWithContext(ctx, sleepDuration) { + if !c.sleepRetry(ctx, sleepDuration) { return } @@ -218,6 +224,10 @@ func newRetryRandom(seedOffset int) *rand.Rand { func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) connectors.Connector { + if c.connectionFactory != nil { + return c.connectionFactory(server, sshAuthMethods, hostKeyCallback, + c.sessionSpec, c.Args.InteractiveQuery) + } if c.Args.Serverless { return connectors.NewServerless(c.UserName, c.maker.makeHandler(server), c.maker.makeCommands(), c.sessionSpec, c.Args.InteractiveQuery, c.runtime) @@ -227,3 +237,10 @@ func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMe c.sessionSpec, c.Args.InteractiveQuery, c.Args.SSHPrivateKeyFilePath, c.Args.NoAuthKey, c.runtime) } + +func (c *baseClient) sleepRetry(ctx context.Context, delay time.Duration) bool { + if c.sleepFn != nil { + return c.sleepFn(ctx, delay) + } + return sleepWithContext(ctx, delay) +} diff --git a/internal/clients/baseclient_retry_test.go b/internal/clients/baseclient_retry_test.go index 323ceae..7a3c571 100644 --- a/internal/clients/baseclient_retry_test.go +++ b/internal/clients/baseclient_retry_test.go @@ -5,6 +5,14 @@ import ( "math/rand" "testing" "time" + + "github.com/mimecast/dtail/internal/clients/connectors" + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/omode" + sshclient "github.com/mimecast/dtail/internal/ssh/client" + + gossh "golang.org/x/crypto/ssh" ) func TestNextRetryDelay(t *testing.T) { @@ -56,3 +64,119 @@ func TestSleepWithContextCancellation(t *testing.T) { t.Fatalf("sleepWithContext took too long to exit on canceled context") } } + +func TestStartConnectionReconnectsWithLatestSessionSpec(t *testing.T) { + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + dlog.Client = originalLogger + }) + + first := &retryTestConnector{ + server: "srv1", + handler: &retryTestHandler{}, + } + second := &retryTestConnector{ + server: "srv1", + handler: &retryTestHandler{}, + } + + originalSpec := SessionSpec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + updatedSpec := SessionSpec{ + Mode: omode.TailClient, + Files: []string{"/var/log/next.log"}, + Regex: "WARN", + } + + sleepCalls := 0 + var capturedSpec SessionSpec + client := &baseClient{ + retry: true, + sessionSpec: originalSpec, + stats: &stats{ + connectionsEstCh: make(chan struct{}, 1), + }, + connections: []connectors.Connector{first}, + connectionFactory: func(server string, _ []gossh.AuthMethod, + _ sshclient.HostKeyCallback, sessionSpec SessionSpec, _ bool) connectors.Connector { + if server != "srv1" { + t.Fatalf("unexpected reconnect server %q", server) + } + capturedSpec = sessionSpec + return second + }, + } + client.sleepFn = func(context.Context, time.Duration) bool { + if sleepCalls == 0 { + sleepCalls++ + client.sessionSpec = updatedSpec + return true + } + return false + } + + status := client.startConnection(context.Background(), 0, first) + if status != 0 { + t.Fatalf("startConnection() status = %d, want 0", status) + } + if capturedSpec.Regex != updatedSpec.Regex || len(capturedSpec.Files) != 1 || capturedSpec.Files[0] != updatedSpec.Files[0] { + t.Fatalf("reconnect used stale session spec: got %#v want %#v", capturedSpec, updatedSpec) + } + if client.connections[0] != second { + t.Fatalf("expected retried connector to replace the original connection") + } +} + +type retryTestConnector struct { + handler handlers.Handler + server string +} + +func (c *retryTestConnector) Start(context.Context, context.CancelFunc, chan struct{}, chan struct{}) { +} + +func (c *retryTestConnector) Server() string { return c.server } + +func (c *retryTestConnector) Handler() handlers.Handler { return c.handler } + +func (*retryTestConnector) SupportsQueryUpdates(time.Duration) bool { return false } + +func (*retryTestConnector) ApplySessionSpec(SessionSpec, time.Duration) error { return nil } + +func (*retryTestConnector) CommittedSession() (SessionSpec, uint64, bool) { + return SessionSpec{}, 0, false +} + +type retryTestHandler struct{} + +func (*retryTestHandler) Read([]byte) (int, error) { return 0, nil } + +func (*retryTestHandler) Write(p []byte) (int, error) { return len(p), nil } + +func (*retryTestHandler) Capabilities() []string { return nil } + +func (*retryTestHandler) HasCapability(string) bool { return false } + +func (*retryTestHandler) SendMessage(string) error { return nil } + +func (*retryTestHandler) Server() string { return "srv1" } + +func (*retryTestHandler) Status() int { return 0 } + +func (*retryTestHandler) Shutdown() {} + +func (*retryTestHandler) Done() <-chan struct{} { + done := make(chan struct{}) + close(done) + return done +} + +func (*retryTestHandler) WaitForCapabilities(time.Duration) bool { return false } + +func (*retryTestHandler) WaitForSessionAck(time.Duration) (handlers.SessionAck, bool) { + return handlers.SessionAck{}, false +} |
