diff options
Diffstat (limited to 'internal/clients/connectors')
| -rw-r--r-- | internal/clients/connectors/connector.go | 7 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 29 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection_test.go | 174 | ||||
| -rw-r--r-- | internal/clients/connectors/serverless.go | 27 | ||||
| -rw-r--r-- | internal/clients/connectors/sessiontransport.go | 154 |
5 files changed, 376 insertions, 15 deletions
diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go index a803c33..c1211ec 100644 --- a/internal/clients/connectors/connector.go +++ b/internal/clients/connectors/connector.go @@ -5,6 +5,7 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" + sessionspec "github.com/mimecast/dtail/internal/session" ) // Connector interface. @@ -18,4 +19,10 @@ type Connector interface { // SupportsQueryUpdates reports whether the connected server advertised // runtime query replacement support within the given timeout. SupportsQueryUpdates(timeout time.Duration) bool + // ApplySessionSpec starts or updates the interactive session workload on an + // already connected server when query updates are supported. + ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error + // CommittedSession returns the last session spec and generation that the + // server acknowledged for this connection. + CommittedSession() (sessionspec.Spec, uint64, bool) } diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 97c02eb..5b432d4 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -14,6 +14,7 @@ import ( "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/protocol" + sessionspec "github.com/mimecast/dtail/internal/session" "github.com/mimecast/dtail/internal/ssh/client" "golang.org/x/crypto/ssh" @@ -43,6 +44,9 @@ type ServerConnection struct { config *ssh.ClientConfig handler handlers.Handler commands []string + sessionSpec sessionspec.Spec + sessionState committedSessionState + interactive bool authKeyPath string authKeyDisabled bool hostKeyCallback client.HostKeyCallback @@ -54,8 +58,8 @@ var _ Connector = (*ServerConnection)(nil) // NewServerConnection returns a new DTail SSH server connection. func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, - handler handlers.Handler, commands []string, authKeyPath string, - authKeyDisabled bool, settings SSHSettings) *ServerConnection { + handler handlers.Handler, commands []string, sessionSpec sessionspec.Spec, + interactive bool, authKeyPath string, authKeyDisabled bool, settings SSHSettings) *ServerConnection { dlog.Client.Debug(server, "Creating new connection", server, handler, commands) sshConnectTimeout := defaultSSHConnectTimeout @@ -76,6 +80,8 @@ func NewServerConnection(server string, userName string, server: server, handler: handler, commands: commands, + sessionSpec: sessionSpec, + interactive: interactive, authKeyPath: resolveAuthKeyPath(authKeyPath), authKeyDisabled: authKeyDisabled, config: &ssh.ClientConfig{ @@ -103,6 +109,17 @@ func (c *ServerConnection) SupportsQueryUpdates(timeout time.Duration) bool { return supportsQueryUpdates(c.handler, timeout) } +// ApplySessionSpec starts or updates the interactive session state on the +// existing SSH connection when runtime query updates are supported. +func (c *ServerConnection) ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error { + return applySessionSpec(c.server, c.handler, &c.sessionState, spec, timeout) +} + +// CommittedSession returns the last server-acknowledged session state. +func (c *ServerConnection) CommittedSession() (sessionspec.Spec, uint64, bool) { + return c.sessionState.snapshot() +} + // Attempt to parse the server port address from the provided server FQDN. func (c *ServerConnection) initServerPort(defaultPort int) { parts := strings.Split(c.server, ":") @@ -257,12 +274,8 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc c.sendAuthKeyRegistrationCommand() } - // Send all requested commands to the server. - for _, command := range c.commands { - dlog.Client.Debug(command) - if err := c.handler.SendMessage(command); err != nil { - dlog.Client.Debug(err) - } + if err := dispatchInitialCommands(c.server, c.handler, c.commands, c.interactive, c.sessionSpec, &c.sessionState); err != nil { + return err } if !c.throttlingDone { diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go index 9307b24..76c4eb6 100644 --- a/internal/clients/connectors/serverconnection_test.go +++ b/internal/clients/connectors/serverconnection_test.go @@ -2,6 +2,7 @@ package connectors import ( "context" + "errors" "os" "path/filepath" "testing" @@ -9,7 +10,9 @@ import ( "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/protocol" + sessionspec "github.com/mimecast/dtail/internal/session" "golang.org/x/crypto/ssh" ) @@ -91,6 +94,8 @@ func TestNewServerConnectionUsesInjectedSettings(t *testing.T) { testHostKeyCallback{}, &mockHandler{}, nil, + sessionspec.Spec{}, + false, "", false, testSSHSettings{port: 3022, timeout: 5 * time.Second}, @@ -117,6 +122,8 @@ func TestNewServerConnectionFallsBackToDefaults(t *testing.T) { testHostKeyCallback{}, &mockHandler{}, nil, + sessionspec.Spec{}, + false, "", false, testSSHSettings{}, @@ -173,6 +180,159 @@ func TestServerConnectionSupportsQueryUpdatesRequiresCapabilityFlag(t *testing.T } } +func TestServerConnectionApplySessionSpecStart(t *testing.T) { + resetClientLogger(t) + + conn := &ServerConnection{ + server: "srv1", + handler: &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{{ + Action: "start", + Generation: 1, + }}, + }, + } + + spec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + if err := conn.ApplySessionSpec(spec, 10*time.Millisecond); err != nil { + t.Fatalf("ApplySessionSpec() error = %v", err) + } + + mock := conn.handler.(*mockHandler) + if len(mock.commands) != 1 { + t.Fatalf("expected one session command, got %d", len(mock.commands)) + } + if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 1 || committedSpec.Regex != "ERROR" { + t.Fatalf("unexpected committed session: spec=%#v generation=%d ok=%v", committedSpec, generation, ok) + } +} + +func TestServerConnectionApplySessionSpecUpdateUsesNextGeneration(t *testing.T) { + resetClientLogger(t) + + mock := &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{ + {Action: "start", Generation: 4}, + {Action: "update", Generation: 5}, + }, + } + conn := &ServerConnection{ + server: "srv1", + handler: mock, + } + + startSpec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + updateSpec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "WARN", + } + + if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil { + t.Fatalf("start ApplySessionSpec() error = %v", err) + } + if err := conn.ApplySessionSpec(updateSpec, 10*time.Millisecond); err != nil { + t.Fatalf("update ApplySessionSpec() error = %v", err) + } + if len(mock.commands) != 2 { + t.Fatalf("expected two session commands, got %d", len(mock.commands)) + } + if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 5 || committedSpec.Regex != "WARN" { + t.Fatalf("unexpected committed session after update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok) + } +} + +func TestServerConnectionApplySessionSpecFallsBackForUnsupportedServer(t *testing.T) { + resetClientLogger(t) + + conn := &ServerConnection{ + handler: &mockHandler{}, + } + + err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"}, 5*time.Millisecond) + if !errors.Is(err, ErrSessionUnsupported) { + t.Fatalf("expected ErrSessionUnsupported, got %v", err) + } +} + +func TestServerConnectionApplySessionSpecPreservesCommittedStateOnRejectedUpdate(t *testing.T) { + resetClientLogger(t) + + mock := &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{ + {Action: "start", Generation: 2}, + {Action: "error", Error: "bad reload"}, + }, + } + conn := &ServerConnection{ + server: "srv1", + handler: mock, + } + + startSpec := sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"} + if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil { + t.Fatalf("start ApplySessionSpec() error = %v", err) + } + + err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "WARN"}, 10*time.Millisecond) + if !errors.Is(err, ErrSessionRejected) { + t.Fatalf("expected ErrSessionRejected, got %v", err) + } + if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 2 || committedSpec.Regex != "ERROR" { + t.Fatalf("unexpected committed session after rejected update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok) + } +} + +func TestServerConnectionApplySessionSpecRejectsUnexpectedAck(t *testing.T) { + resetClientLogger(t) + + mock := &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{ + {Action: "update", Generation: 1}, + }, + } + conn := &ServerConnection{ + server: "srv1", + handler: mock, + } + + err := conn.ApplySessionSpec(sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + }, 10*time.Millisecond) + if !errors.Is(err, ErrUnexpectedSessionAck) { + t.Fatalf("expected ErrUnexpectedSessionAck, got %v", err) + } + if _, _, ok := conn.CommittedSession(); ok { + t.Fatalf("unexpected committed session after mismatched ack") + } +} + type testSSHSettings struct { port int timeout time.Duration @@ -212,6 +372,7 @@ type mockHandler struct { commands []string capabilities map[string]bool waitForCapabilities bool + sessionAcks []handlers.SessionAck } var _ handlers.Handler = (*mockHandler)(nil) @@ -253,6 +414,19 @@ func (m *mockHandler) WaitForCapabilities(timeout time.Duration) bool { return m.waitForCapabilities } +func (m *mockHandler) WaitForSessionAck(timeout time.Duration) (handlers.SessionAck, bool) { + if timeout <= 0 { + return handlers.SessionAck{}, false + } + if len(m.sessionAcks) == 0 { + return handlers.SessionAck{}, false + } + + ack := m.sessionAcks[0] + m.sessionAcks = m.sessionAcks[1:] + return ack, true +} + func (m *mockHandler) Read(_ []byte) (int, error) { return 0, nil } diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go index 0ebe069..72e3fda 100644 --- a/internal/clients/connectors/serverless.go +++ b/internal/clients/connectors/serverless.go @@ -8,6 +8,7 @@ import ( "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" serverHandlers "github.com/mimecast/dtail/internal/server/handlers" + sessionspec "github.com/mimecast/dtail/internal/session" ) // ServerlessHandlerFactory creates the in-process server-side handler used by serverless mode. @@ -19,6 +20,9 @@ type ServerlessHandlerFactory interface { type Serverless struct { handler handlers.Handler commands []string + sessionSpec sessionspec.Spec + sessionState committedSessionState + interactive bool userName string handlerFactory ServerlessHandlerFactory } @@ -27,13 +31,16 @@ var _ Connector = (*Serverless)(nil) // NewServerless starts a new serverless session. func NewServerless(userName string, handler handlers.Handler, - commands []string, handlerFactory ServerlessHandlerFactory) *Serverless { + commands []string, sessionSpec sessionspec.Spec, interactive bool, + handlerFactory ServerlessHandlerFactory) *Serverless { dlog.Client.Debug("Creating new serverless connector", handler, commands) return &Serverless{ userName: userName, handler: handler, commands: commands, + sessionSpec: sessionSpec, + interactive: interactive, handlerFactory: handlerFactory, } } @@ -54,6 +61,16 @@ func (s *Serverless) SupportsQueryUpdates(timeout time.Duration) bool { return supportsQueryUpdates(s.handler, timeout) } +// ApplySessionSpec starts or updates the in-process interactive session state. +func (s *Serverless) ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error { + return applySessionSpec(s.Server(), s.handler, &s.sessionState, spec, timeout) +} + +// CommittedSession returns the last server-acknowledged session state. +func (s *Serverless) CommittedSession() (sessionspec.Spec, uint64, bool) { + return s.sessionState.snapshot() +} + // Start the serverless connection. func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { @@ -165,12 +182,8 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro } }() - // Send commands after setting up the data flow - for _, command := range s.commands { - dlog.Client.Debug("Sending command to serverless server", command) - if err := s.handler.SendMessage(command); err != nil { - dlog.Client.Debug(err) - } + if err := dispatchInitialCommands(s.Server(), s.handler, s.commands, s.interactive, s.sessionSpec, &s.sessionState); err != nil { + return err } // Monitor for completion diff --git a/internal/clients/connectors/sessiontransport.go b/internal/clients/connectors/sessiontransport.go new file mode 100644 index 0000000..84aeb78 --- /dev/null +++ b/internal/clients/connectors/sessiontransport.go @@ -0,0 +1,154 @@ +package connectors + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/omode" + sessionspec "github.com/mimecast/dtail/internal/session" +) + +var ( + // ErrSessionUnsupported indicates that the remote side did not advertise + // runtime query update support. + ErrSessionUnsupported = errors.New("runtime query updates unsupported by server") + // ErrSessionAckTimeout indicates that no hidden SESSION acknowledgement arrived in time. + ErrSessionAckTimeout = errors.New("timed out waiting for session acknowledgement") + // ErrSessionRejected indicates that the server explicitly rejected a SESSION request. + ErrSessionRejected = errors.New("session request rejected") + // ErrUnexpectedSessionAck indicates that the client received a malformed or mismatched acknowledgement. + ErrUnexpectedSessionAck = errors.New("unexpected session acknowledgement") +) + +const defaultSessionAckTimeout = 2 * time.Second + +type committedSessionState struct { + mu sync.RWMutex + committed bool + generation uint64 + spec sessionspec.Spec +} + +func (s *committedSessionState) commit(spec sessionspec.Spec, generation uint64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.committed = true + s.generation = generation + s.spec = spec +} + +func (s *committedSessionState) clear() { + s.mu.Lock() + defer s.mu.Unlock() + + s.committed = false + s.generation = 0 + s.spec = sessionspec.Spec{} +} + +func (s *committedSessionState) snapshot() (sessionspec.Spec, uint64, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.spec, s.generation, s.committed +} + +func dispatchInitialCommands(server string, handler handlers.Handler, commands []string, + interactiveQuery bool, initialSpec sessionspec.Spec, state *committedSessionState) error { + + if !interactiveQuery || initialSpec.Mode == omode.Unknown { + return sendLegacyCommands(handler, commands) + } + + if err := applySessionSpec(server, handler, state, initialSpec, defaultSessionAckTimeout); err != nil { + if !errors.Is(err, ErrSessionUnsupported) { + dlog.Client.Warn(server, "Interactive session bootstrap failed, falling back to legacy commands", err) + } + state.clear() + return sendLegacyCommands(handler, commands) + } + + return nil +} + +func applySessionSpec(server string, handler handlers.Handler, + state *committedSessionState, spec sessionspec.Spec, timeout time.Duration) error { + + if handler == nil { + return ErrSessionUnsupported + } + if !supportsQueryUpdates(handler, defaultCapabilityWait) { + return ErrSessionUnsupported + } + + action := "start" + nextGeneration := uint64(0) + command, err := spec.StartCommand() + if err != nil { + return err + } + + if _, generation, ok := state.snapshot(); ok { + action = "update" + nextGeneration = generation + 1 + command, err = spec.UpdateCommand(nextGeneration) + if err != nil { + return err + } + } + + drainSessionAcks(handler) + if err := handler.SendMessage(command); err != nil { + return err + } + + ack, ok := handler.WaitForSessionAck(resolveSessionAckTimeout(timeout)) + if !ok { + return ErrSessionAckTimeout + } + if ack.Error != "" { + return fmt.Errorf("%w: %s", ErrSessionRejected, ack.Error) + } + if ack.Action != action { + return fmt.Errorf("%w: got action %q want %q", ErrUnexpectedSessionAck, ack.Action, action) + } + if ack.Generation == 0 { + return fmt.Errorf("%w: missing generation", ErrUnexpectedSessionAck) + } + if action == "update" && ack.Generation != nextGeneration { + return fmt.Errorf("%w: got generation %d want %d", ErrUnexpectedSessionAck, ack.Generation, nextGeneration) + } + + state.commit(spec, ack.Generation) + dlog.Client.Debug(server, "Committed session spec", "action", action, "generation", ack.Generation) + return nil +} + +func sendLegacyCommands(handler handlers.Handler, commands []string) error { + for _, command := range commands { + if err := handler.SendMessage(command); err != nil { + return err + } + } + return nil +} + +func drainSessionAcks(handler handlers.Handler) { + for { + if _, ok := handler.WaitForSessionAck(0); !ok { + return + } + } +} + +func resolveSessionAckTimeout(timeout time.Duration) time.Duration { + if timeout <= 0 { + return defaultSessionAckTimeout + } + return timeout +} |
