summary refs log tree commit diff
path: root/internal
diff options
context:
space:
mode:
author Paul Buetow <paul@buetow.org> 2026-03-13 09:54:31 +0200
committer Paul Buetow <paul@buetow.org> 2026-03-13 09:54:31 +0200
commit a5a405d79fe3d9e0c6ea081b425d36bd67d8671d (patch)
tree ff460d868a75c901bd9ea3bfbc3b0af738723a59 /internal
parent 7ce4aee16c0223cc15c9dd8d9024120069500c65 (diff)
task be5429a7: cover reconnect session restore
Diffstat (limited to 'internal')
-rw-r--r-- internal/clients/baseclient.go 19
-rw-r--r-- internal/clients/baseclient_retry_test.go 124
2 files changed, 142 insertions, 1 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 0e67ba4..71b8d02 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -42,6 +42,12 @@ type baseClient struct {
sessionSpec SessionSpec
// Connection maker helper.
maker maker
+ // Optional factory override for retry/reconnect tests.
+ connectionFactory func(server string, authMethods []gossh.AuthMethod,
+ hostKeyCallback client.HostKeyCallback, sessionSpec SessionSpec,
+ interactive bool) connectors.Connector
+ // Optional sleep override for retry tests.
+ sleepFn func(context.Context, time.Duration) bool
// Regex is the regular expression object for line filtering
Regex regex.Regex
}
@@ -155,7 +161,7 @@ func (c *baseClient) startConnection(ctx context.Context, i int,
// Yes, we want to retry with exponential backoff and jitter.
sleepDuration := jitterRetryDelay(retryDelay, retryRandom)
dlog.Client.Debug(conn.Server(), "Reconnecting", "backoff", sleepDuration)
- if !sleepWithContext(ctx, sleepDuration) {
+ if !c.sleepRetry(ctx, sleepDuration) {
return
}
@@ -218,6 +224,10 @@ func newRetryRandom(seedOffset int) *rand.Rand {
func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod,
hostKeyCallback client.HostKeyCallback) connectors.Connector {
+ if c.connectionFactory != nil {
+ return c.connectionFactory(server, sshAuthMethods, hostKeyCallback,
+ c.sessionSpec, c.Args.InteractiveQuery)
+ }
if c.Args.Serverless {
return connectors.NewServerless(c.UserName, c.maker.makeHandler(server),
c.maker.makeCommands(), c.sessionSpec, c.Args.InteractiveQuery, c.runtime)
@@ -227,3 +237,10 @@ func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMe
c.sessionSpec, c.Args.InteractiveQuery, c.Args.SSHPrivateKeyFilePath,
c.Args.NoAuthKey, c.runtime)
}
+
+func (c *baseClient) sleepRetry(ctx context.Context, delay time.Duration) bool {
+ if c.sleepFn != nil {
+ return c.sleepFn(ctx, delay)
+ }
+ return sleepWithContext(ctx, delay)
+}
diff --git a/internal/clients/baseclient_retry_test.go b/internal/clients/baseclient_retry_test.go
index 323ceae..7a3c571 100644
--- a/internal/clients/baseclient_retry_test.go
+++ b/internal/clients/baseclient_retry_test.go
@@ -5,6 +5,14 @@ import (
"math/rand"
"testing"
"time"
+
+ "github.com/mimecast/dtail/internal/clients/connectors"
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/omode"
+ sshclient "github.com/mimecast/dtail/internal/ssh/client"
+
+ gossh "golang.org/x/crypto/ssh"
)
func TestNextRetryDelay(t *testing.T) {
@@ -56,3 +64,119 @@ func TestSleepWithContextCancellation(t *testing.T) {
t.Fatalf("sleepWithContext took too long to exit on canceled context")
}
}
+
+func TestStartConnectionReconnectsWithLatestSessionSpec(t *testing.T) {
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ dlog.Client = originalLogger
+ })
+
+ first := &retryTestConnector{
+ server: "srv1",
+ handler: &retryTestHandler{},
+ }
+ second := &retryTestConnector{
+ server: "srv1",
+ handler: &retryTestHandler{},
+ }
+
+ originalSpec := SessionSpec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ updatedSpec := SessionSpec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/next.log"},
+ Regex: "WARN",
+ }
+
+ sleepCalls := 0
+ var capturedSpec SessionSpec
+ client := &baseClient{
+ retry: true,
+ sessionSpec: originalSpec,
+ stats: &stats{
+ connectionsEstCh: make(chan struct{}, 1),
+ },
+ connections: []connectors.Connector{first},
+ connectionFactory: func(server string, _ []gossh.AuthMethod,
+ _ sshclient.HostKeyCallback, sessionSpec SessionSpec, _ bool) connectors.Connector {
+ if server != "srv1" {
+ t.Fatalf("unexpected reconnect server %q", server)
+ }
+ capturedSpec = sessionSpec
+ return second
+ },
+ }
+ client.sleepFn = func(context.Context, time.Duration) bool {
+ if sleepCalls == 0 {
+ sleepCalls++
+ client.sessionSpec = updatedSpec
+ return true
+ }
+ return false
+ }
+
+ status := client.startConnection(context.Background(), 0, first)
+ if status != 0 {
+ t.Fatalf("startConnection() status = %d, want 0", status)
+ }
+ if capturedSpec.Regex != updatedSpec.Regex || len(capturedSpec.Files) != 1 || capturedSpec.Files[0] != updatedSpec.Files[0] {
+ t.Fatalf("reconnect used stale session spec: got %#v want %#v", capturedSpec, updatedSpec)
+ }
+ if client.connections[0] != second {
+ t.Fatalf("expected retried connector to replace the original connection")
+ }
+}
+
+type retryTestConnector struct {
+ handler handlers.Handler
+ server string
+}
+
+func (c *retryTestConnector) Start(context.Context, context.CancelFunc, chan struct{}, chan struct{}) {
+}
+
+func (c *retryTestConnector) Server() string { return c.server }
+
+func (c *retryTestConnector) Handler() handlers.Handler { return c.handler }
+
+func (*retryTestConnector) SupportsQueryUpdates(time.Duration) bool { return false }
+
+func (*retryTestConnector) ApplySessionSpec(SessionSpec, time.Duration) error { return nil }
+
+func (*retryTestConnector) CommittedSession() (SessionSpec, uint64, bool) {
+ return SessionSpec{}, 0, false
+}
+
+type retryTestHandler struct{}
+
+func (*retryTestHandler) Read([]byte) (int, error) { return 0, nil }
+
+func (*retryTestHandler) Write(p []byte) (int, error) { return len(p), nil }
+
+func (*retryTestHandler) Capabilities() []string { return nil }
+
+func (*retryTestHandler) HasCapability(string) bool { return false }
+
+func (*retryTestHandler) SendMessage(string) error { return nil }
+
+func (*retryTestHandler) Server() string { return "srv1" }
+
+func (*retryTestHandler) Status() int { return 0 }
+
+func (*retryTestHandler) Shutdown() {}
+
+func (*retryTestHandler) Done() <-chan struct{} {
+ done := make(chan struct{})
+ close(done)
+ return done
+}
+
+func (*retryTestHandler) WaitForCapabilities(time.Duration) bool { return false }
+
+func (*retryTestHandler) WaitForSessionAck(time.Duration) (handlers.SessionAck, bool) {
+ return handlers.SessionAck{}, false
+}