summaryrefslogtreecommitdiff
path: root/internal/clients
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients')
-rw-r--r--internal/clients/connectors/serverconnection_test.go148
-rw-r--r--internal/clients/connectors/sessiontransport.go6
2 files changed, 154 insertions, 0 deletions
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go
index 01fe4af..0cb15c2 100644
--- a/internal/clients/connectors/serverconnection_test.go
+++ b/internal/clients/connectors/serverconnection_test.go
@@ -5,6 +5,8 @@ import (
"errors"
"os"
"path/filepath"
+ "strings"
+ "sync"
"testing"
"time"
@@ -363,6 +365,65 @@ func TestServerConnectionApplySessionSpecTimesOutWaitingForAck(t *testing.T) {
}
}
+func TestApplySessionSpecSerializesConcurrentBootstrapAndReload(t *testing.T) {
+ resetClientLogger(t)
+
+ handler := newBlockingSessionHandler()
+ state := &committedSessionState{}
+
+ initialSpec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ reloadSpec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "WARN",
+ }
+
+ initialErrCh := make(chan error, 1)
+ go func() {
+ initialErrCh <- dispatchInitialCommands("srv1", handler, nil, true, initialSpec, state)
+ }()
+
+ firstCommand := <-handler.commandsCh
+ if !strings.HasPrefix(firstCommand, "SESSION START ") {
+ t.Fatalf("expected initial SESSION START command, got %q", firstCommand)
+ }
+
+ reloadErrCh := make(chan error, 1)
+ go func() {
+ reloadErrCh <- applySessionSpec("srv1", handler, state, reloadSpec, 50*time.Millisecond)
+ }()
+
+ select {
+ case command := <-handler.commandsCh:
+ t.Fatalf("unexpected concurrent session command before bootstrap ack: %q", command)
+ case <-time.After(10 * time.Millisecond):
+ }
+
+ handler.ackCh <- handlers.SessionAck{Action: "start", Generation: 1}
+ if err := <-initialErrCh; err != nil {
+ t.Fatalf("dispatchInitialCommands() error = %v", err)
+ }
+
+ secondCommand := <-handler.commandsCh
+ if !strings.HasPrefix(secondCommand, "SESSION UPDATE 2 ") {
+ t.Fatalf("expected reload to send SESSION UPDATE after bootstrap, got %q", secondCommand)
+ }
+
+ handler.ackCh <- handlers.SessionAck{Action: "update", Generation: 2}
+ if err := <-reloadErrCh; err != nil {
+ t.Fatalf("applySessionSpec() error = %v", err)
+ }
+
+ committedSpec, generation, ok := state.snapshot()
+ if !ok || generation != 2 || committedSpec.Regex != "WARN" {
+ t.Fatalf("unexpected committed session after reload: spec=%#v generation=%d ok=%v", committedSpec, generation, ok)
+ }
+}
+
type testSSHSettings struct {
port int
timeout time.Duration
@@ -464,3 +525,90 @@ func (m *mockHandler) Read(_ []byte) (int, error) {
func (m *mockHandler) Write(p []byte) (int, error) {
return len(p), nil
}
+
+type blockingSessionHandler struct {
+ mu sync.Mutex
+ commands []string
+ commandsCh chan string
+ ackCh chan handlers.SessionAck
+ capabilities map[string]bool
+}
+
+func newBlockingSessionHandler() *blockingSessionHandler {
+ return &blockingSessionHandler{
+ commandsCh: make(chan string, 8),
+ ackCh: make(chan handlers.SessionAck, 8),
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ }
+}
+
+var _ handlers.Handler = (*blockingSessionHandler)(nil)
+
+func (h *blockingSessionHandler) SendMessage(command string) error {
+ h.mu.Lock()
+ h.commands = append(h.commands, command)
+ h.mu.Unlock()
+ h.commandsCh <- command
+ return nil
+}
+
+func (h *blockingSessionHandler) Capabilities() []string {
+ capabilities := make([]string, 0, len(h.capabilities))
+ for capability := range h.capabilities {
+ capabilities = append(capabilities, capability)
+ }
+ return capabilities
+}
+
+func (h *blockingSessionHandler) HasCapability(name string) bool {
+ return h.capabilities[name]
+}
+
+func (*blockingSessionHandler) Server() string {
+ return "mock"
+}
+
+func (*blockingSessionHandler) Status() int {
+ return 0
+}
+
+func (*blockingSessionHandler) Shutdown() {}
+
+func (*blockingSessionHandler) Done() <-chan struct{} {
+ return make(chan struct{})
+}
+
+func (*blockingSessionHandler) WaitForCapabilities(time.Duration) bool {
+ return true
+}
+
+func (h *blockingSessionHandler) WaitForSessionAck(timeout time.Duration) (handlers.SessionAck, bool) {
+ if timeout <= 0 {
+ select {
+ case ack := <-h.ackCh:
+ return ack, true
+ default:
+ return handlers.SessionAck{}, false
+ }
+ }
+
+ timer := time.NewTimer(timeout)
+ defer timer.Stop()
+
+ select {
+ case ack := <-h.ackCh:
+ return ack, true
+ case <-timer.C:
+ return handlers.SessionAck{}, false
+ }
+}
+
+func (*blockingSessionHandler) Read(_ []byte) (int, error) {
+ return 0, nil
+}
+
+func (*blockingSessionHandler) Write(p []byte) (int, error) {
+ return len(p), nil
+}
diff --git a/internal/clients/connectors/sessiontransport.go b/internal/clients/connectors/sessiontransport.go
index 84aeb78..428752a 100644
--- a/internal/clients/connectors/sessiontransport.go
+++ b/internal/clients/connectors/sessiontransport.go
@@ -27,6 +27,7 @@ var (
const defaultSessionAckTimeout = 2 * time.Second
type committedSessionState struct {
+ applyMu sync.Mutex
mu sync.RWMutex
committed bool
generation uint64
@@ -79,6 +80,11 @@ func dispatchInitialCommands(server string, handler handlers.Handler, commands [
func applySessionSpec(server string, handler handlers.Handler,
state *committedSessionState, spec sessionspec.Spec, timeout time.Duration) error {
+ // Serialize session transitions so an interactive reload cannot race the
+ // initial SESSION START bootstrap on the same connection.
+ state.applyMu.Lock()
+ defer state.applyMu.Unlock()
+
if handler == nil {
return ErrSessionUnsupported
}