diff options
Diffstat (limited to 'internal/clients')
| -rw-r--r-- | internal/clients/connectors/serverconnection_test.go | 148 | ||||
| -rw-r--r-- | internal/clients/connectors/sessiontransport.go | 6 |
2 files changed, 154 insertions, 0 deletions
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go index 01fe4af..0cb15c2 100644 --- a/internal/clients/connectors/serverconnection_test.go +++ b/internal/clients/connectors/serverconnection_test.go @@ -5,6 +5,8 @@ import ( "errors" "os" "path/filepath" + "strings" + "sync" "testing" "time" @@ -363,6 +365,65 @@ func TestServerConnectionApplySessionSpecTimesOutWaitingForAck(t *testing.T) { } } +func TestApplySessionSpecSerializesConcurrentBootstrapAndReload(t *testing.T) { + resetClientLogger(t) + + handler := newBlockingSessionHandler() + state := &committedSessionState{} + + initialSpec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + reloadSpec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "WARN", + } + + initialErrCh := make(chan error, 1) + go func() { + initialErrCh <- dispatchInitialCommands("srv1", handler, nil, true, initialSpec, state) + }() + + firstCommand := <-handler.commandsCh + if !strings.HasPrefix(firstCommand, "SESSION START ") { + t.Fatalf("expected initial SESSION START command, got %q", firstCommand) + } + + reloadErrCh := make(chan error, 1) + go func() { + reloadErrCh <- applySessionSpec("srv1", handler, state, reloadSpec, 50*time.Millisecond) + }() + + select { + case command := <-handler.commandsCh: + t.Fatalf("unexpected concurrent session command before bootstrap ack: %q", command) + case <-time.After(10 * time.Millisecond): + } + + handler.ackCh <- handlers.SessionAck{Action: "start", Generation: 1} + if err := <-initialErrCh; err != nil { + t.Fatalf("dispatchInitialCommands() error = %v", err) + } + + secondCommand := <-handler.commandsCh + if !strings.HasPrefix(secondCommand, "SESSION UPDATE 2 ") { + t.Fatalf("expected reload to send SESSION UPDATE after bootstrap, got %q", secondCommand) + } + + handler.ackCh <- handlers.SessionAck{Action: "update", Generation: 2} + if err := <-reloadErrCh; err != nil { + t.Fatalf("applySessionSpec() error = %v", err) + } + + committedSpec, generation, ok := state.snapshot() + if !ok || generation != 2 || committedSpec.Regex != "WARN" { + t.Fatalf("unexpected committed session after reload: spec=%#v generation=%d ok=%v", committedSpec, generation, ok) + } +} + type testSSHSettings struct { port int timeout time.Duration @@ -464,3 +525,90 @@ func (m *mockHandler) Read(_ []byte) (int, error) { func (m *mockHandler) Write(p []byte) (int, error) { return len(p), nil } + +type blockingSessionHandler struct { + mu sync.Mutex + commands []string + commandsCh chan string + ackCh chan handlers.SessionAck + capabilities map[string]bool +} + +func newBlockingSessionHandler() *blockingSessionHandler { + return &blockingSessionHandler{ + commandsCh: make(chan string, 8), + ackCh: make(chan handlers.SessionAck, 8), + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + } +} + +var _ handlers.Handler = (*blockingSessionHandler)(nil) + +func (h *blockingSessionHandler) SendMessage(command string) error { + h.mu.Lock() + h.commands = append(h.commands, command) + h.mu.Unlock() + h.commandsCh <- command + return nil +} + +func (h *blockingSessionHandler) Capabilities() []string { + capabilities := make([]string, 0, len(h.capabilities)) + for capability := range h.capabilities { + capabilities = append(capabilities, capability) + } + return capabilities +} + +func (h *blockingSessionHandler) HasCapability(name string) bool { + return h.capabilities[name] +} + +func (*blockingSessionHandler) Server() string { + return "mock" +} + +func (*blockingSessionHandler) Status() int { + return 0 +} + +func (*blockingSessionHandler) Shutdown() {} + +func (*blockingSessionHandler) Done() <-chan struct{} { + return make(chan struct{}) +} + +func (*blockingSessionHandler) WaitForCapabilities(time.Duration) bool { + return true +} + +func (h *blockingSessionHandler) WaitForSessionAck(timeout time.Duration) (handlers.SessionAck, bool) { + if timeout <= 0 { + select { + case ack := <-h.ackCh: + return ack, true + default: + return handlers.SessionAck{}, false + } + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case ack := <-h.ackCh: + return ack, true + case <-timer.C: + return handlers.SessionAck{}, false + } +} + +func (*blockingSessionHandler) Read(_ []byte) (int, error) { + return 0, nil +} + +func (*blockingSessionHandler) Write(p []byte) (int, error) { + return len(p), nil +} diff --git a/internal/clients/connectors/sessiontransport.go b/internal/clients/connectors/sessiontransport.go index 84aeb78..428752a 100644 --- a/internal/clients/connectors/sessiontransport.go +++ b/internal/clients/connectors/sessiontransport.go @@ -27,6 +27,7 @@ var ( const defaultSessionAckTimeout = 2 * time.Second type committedSessionState struct { + applyMu sync.Mutex mu sync.RWMutex committed bool generation uint64 @@ -79,6 +80,11 @@ func dispatchInitialCommands(server string, handler handlers.Handler, commands [ func applySessionSpec(server string, handler handlers.Handler, state *committedSessionState, spec sessionspec.Spec, timeout time.Duration) error { + // Serialize session transitions so an interactive reload cannot race the + // initial SESSION START bootstrap on the same connection. + state.applyMu.Lock() + defer state.applyMu.Unlock() + if handler == nil { return ErrSessionUnsupported } |
