summaryrefslogtreecommitdiff
path: root/internal/clients/connectors
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/connectors')
-rw-r--r--internal/clients/connectors/connector.go7
-rw-r--r--internal/clients/connectors/serverconnection.go29
-rw-r--r--internal/clients/connectors/serverconnection_test.go174
-rw-r--r--internal/clients/connectors/serverless.go27
-rw-r--r--internal/clients/connectors/sessiontransport.go154
5 files changed, 376 insertions, 15 deletions
diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go
index a803c33..c1211ec 100644
--- a/internal/clients/connectors/connector.go
+++ b/internal/clients/connectors/connector.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/mimecast/dtail/internal/clients/handlers"
+ sessionspec "github.com/mimecast/dtail/internal/session"
)
// Connector interface.
@@ -18,4 +19,10 @@ type Connector interface {
// SupportsQueryUpdates reports whether the connected server advertised
// runtime query replacement support within the given timeout.
SupportsQueryUpdates(timeout time.Duration) bool
+ // ApplySessionSpec starts or updates the interactive session workload on an
+ // already connected server when query updates are supported.
+ ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error
+ // CommittedSession returns the last session spec and generation that the
+ // server acknowledged for this connection.
+ CommittedSession() (sessionspec.Spec, uint64, bool)
}
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 97c02eb..5b432d4 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -14,6 +14,7 @@ import (
"github.com/mimecast/dtail/internal/clients/handlers"
"github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/protocol"
+ sessionspec "github.com/mimecast/dtail/internal/session"
"github.com/mimecast/dtail/internal/ssh/client"
"golang.org/x/crypto/ssh"
@@ -43,6 +44,9 @@ type ServerConnection struct {
config *ssh.ClientConfig
handler handlers.Handler
commands []string
+ sessionSpec sessionspec.Spec
+ sessionState committedSessionState
+ interactive bool
authKeyPath string
authKeyDisabled bool
hostKeyCallback client.HostKeyCallback
@@ -54,8 +58,8 @@ var _ Connector = (*ServerConnection)(nil)
// NewServerConnection returns a new DTail SSH server connection.
func NewServerConnection(server string, userName string,
authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback,
- handler handlers.Handler, commands []string, authKeyPath string,
- authKeyDisabled bool, settings SSHSettings) *ServerConnection {
+ handler handlers.Handler, commands []string, sessionSpec sessionspec.Spec,
+ interactive bool, authKeyPath string, authKeyDisabled bool, settings SSHSettings) *ServerConnection {
dlog.Client.Debug(server, "Creating new connection", server, handler, commands)
sshConnectTimeout := defaultSSHConnectTimeout
@@ -76,6 +80,8 @@ func NewServerConnection(server string, userName string,
server: server,
handler: handler,
commands: commands,
+ sessionSpec: sessionSpec,
+ interactive: interactive,
authKeyPath: resolveAuthKeyPath(authKeyPath),
authKeyDisabled: authKeyDisabled,
config: &ssh.ClientConfig{
@@ -103,6 +109,17 @@ func (c *ServerConnection) SupportsQueryUpdates(timeout time.Duration) bool {
return supportsQueryUpdates(c.handler, timeout)
}
+// ApplySessionSpec starts or updates the interactive session state on the
+// existing SSH connection when runtime query updates are supported.
+func (c *ServerConnection) ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error {
+ return applySessionSpec(c.server, c.handler, &c.sessionState, spec, timeout)
+}
+
+// CommittedSession returns the last server-acknowledged session state.
+func (c *ServerConnection) CommittedSession() (sessionspec.Spec, uint64, bool) {
+ return c.sessionState.snapshot()
+}
+
// Attempt to parse the server port address from the provided server FQDN.
func (c *ServerConnection) initServerPort(defaultPort int) {
parts := strings.Split(c.server, ":")
@@ -257,12 +274,8 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc
c.sendAuthKeyRegistrationCommand()
}
- // Send all requested commands to the server.
- for _, command := range c.commands {
- dlog.Client.Debug(command)
- if err := c.handler.SendMessage(command); err != nil {
- dlog.Client.Debug(err)
- }
+ if err := dispatchInitialCommands(c.server, c.handler, c.commands, c.interactive, c.sessionSpec, &c.sessionState); err != nil {
+ return err
}
if !c.throttlingDone {
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go
index 9307b24..76c4eb6 100644
--- a/internal/clients/connectors/serverconnection_test.go
+++ b/internal/clients/connectors/serverconnection_test.go
@@ -2,6 +2,7 @@ package connectors
import (
"context"
+ "errors"
"os"
"path/filepath"
"testing"
@@ -9,7 +10,9 @@ import (
"github.com/mimecast/dtail/internal/clients/handlers"
"github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/protocol"
+ sessionspec "github.com/mimecast/dtail/internal/session"
"golang.org/x/crypto/ssh"
)
@@ -91,6 +94,8 @@ func TestNewServerConnectionUsesInjectedSettings(t *testing.T) {
testHostKeyCallback{},
&mockHandler{},
nil,
+ sessionspec.Spec{},
+ false,
"",
false,
testSSHSettings{port: 3022, timeout: 5 * time.Second},
@@ -117,6 +122,8 @@ func TestNewServerConnectionFallsBackToDefaults(t *testing.T) {
testHostKeyCallback{},
&mockHandler{},
nil,
+ sessionspec.Spec{},
+ false,
"",
false,
testSSHSettings{},
@@ -173,6 +180,159 @@ func TestServerConnectionSupportsQueryUpdatesRequiresCapabilityFlag(t *testing.T
}
}
+func TestServerConnectionApplySessionSpecStart(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{{
+ Action: "start",
+ Generation: 1,
+ }},
+ },
+ }
+
+ spec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ if err := conn.ApplySessionSpec(spec, 10*time.Millisecond); err != nil {
+ t.Fatalf("ApplySessionSpec() error = %v", err)
+ }
+
+ mock := conn.handler.(*mockHandler)
+ if len(mock.commands) != 1 {
+ t.Fatalf("expected one session command, got %d", len(mock.commands))
+ }
+ if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 1 || committedSpec.Regex != "ERROR" {
+ t.Fatalf("unexpected committed session: spec=%#v generation=%d ok=%v", committedSpec, generation, ok)
+ }
+}
+
+func TestServerConnectionApplySessionSpecUpdateUsesNextGeneration(t *testing.T) {
+ resetClientLogger(t)
+
+ mock := &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{
+ {Action: "start", Generation: 4},
+ {Action: "update", Generation: 5},
+ },
+ }
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: mock,
+ }
+
+ startSpec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ updateSpec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "WARN",
+ }
+
+ if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil {
+ t.Fatalf("start ApplySessionSpec() error = %v", err)
+ }
+ if err := conn.ApplySessionSpec(updateSpec, 10*time.Millisecond); err != nil {
+ t.Fatalf("update ApplySessionSpec() error = %v", err)
+ }
+ if len(mock.commands) != 2 {
+ t.Fatalf("expected two session commands, got %d", len(mock.commands))
+ }
+ if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 5 || committedSpec.Regex != "WARN" {
+ t.Fatalf("unexpected committed session after update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok)
+ }
+}
+
+func TestServerConnectionApplySessionSpecFallsBackForUnsupportedServer(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := &ServerConnection{
+ handler: &mockHandler{},
+ }
+
+ err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"}, 5*time.Millisecond)
+ if !errors.Is(err, ErrSessionUnsupported) {
+ t.Fatalf("expected ErrSessionUnsupported, got %v", err)
+ }
+}
+
+func TestServerConnectionApplySessionSpecPreservesCommittedStateOnRejectedUpdate(t *testing.T) {
+ resetClientLogger(t)
+
+ mock := &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{
+ {Action: "start", Generation: 2},
+ {Action: "error", Error: "bad reload"},
+ },
+ }
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: mock,
+ }
+
+ startSpec := sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"}
+ if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil {
+ t.Fatalf("start ApplySessionSpec() error = %v", err)
+ }
+
+ err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "WARN"}, 10*time.Millisecond)
+ if !errors.Is(err, ErrSessionRejected) {
+ t.Fatalf("expected ErrSessionRejected, got %v", err)
+ }
+ if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 2 || committedSpec.Regex != "ERROR" {
+ t.Fatalf("unexpected committed session after rejected update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok)
+ }
+}
+
+func TestServerConnectionApplySessionSpecRejectsUnexpectedAck(t *testing.T) {
+ resetClientLogger(t)
+
+ mock := &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{
+ {Action: "update", Generation: 1},
+ },
+ }
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: mock,
+ }
+
+ err := conn.ApplySessionSpec(sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }, 10*time.Millisecond)
+ if !errors.Is(err, ErrUnexpectedSessionAck) {
+ t.Fatalf("expected ErrUnexpectedSessionAck, got %v", err)
+ }
+ if _, _, ok := conn.CommittedSession(); ok {
+ t.Fatalf("unexpected committed session after mismatched ack")
+ }
+}
+
type testSSHSettings struct {
port int
timeout time.Duration
@@ -212,6 +372,7 @@ type mockHandler struct {
commands []string
capabilities map[string]bool
waitForCapabilities bool
+ sessionAcks []handlers.SessionAck
}
var _ handlers.Handler = (*mockHandler)(nil)
@@ -253,6 +414,19 @@ func (m *mockHandler) WaitForCapabilities(timeout time.Duration) bool {
return m.waitForCapabilities
}
+func (m *mockHandler) WaitForSessionAck(timeout time.Duration) (handlers.SessionAck, bool) {
+ if timeout <= 0 {
+ return handlers.SessionAck{}, false
+ }
+ if len(m.sessionAcks) == 0 {
+ return handlers.SessionAck{}, false
+ }
+
+ ack := m.sessionAcks[0]
+ m.sessionAcks = m.sessionAcks[1:]
+ return ack, true
+}
+
func (m *mockHandler) Read(_ []byte) (int, error) {
return 0, nil
}
diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go
index 0ebe069..72e3fda 100644
--- a/internal/clients/connectors/serverless.go
+++ b/internal/clients/connectors/serverless.go
@@ -8,6 +8,7 @@ import (
"github.com/mimecast/dtail/internal/clients/handlers"
"github.com/mimecast/dtail/internal/io/dlog"
serverHandlers "github.com/mimecast/dtail/internal/server/handlers"
+ sessionspec "github.com/mimecast/dtail/internal/session"
)
// ServerlessHandlerFactory creates the in-process server-side handler used by serverless mode.
@@ -19,6 +20,9 @@ type ServerlessHandlerFactory interface {
type Serverless struct {
handler handlers.Handler
commands []string
+ sessionSpec sessionspec.Spec
+ sessionState committedSessionState
+ interactive bool
userName string
handlerFactory ServerlessHandlerFactory
}
@@ -27,13 +31,16 @@ var _ Connector = (*Serverless)(nil)
// NewServerless starts a new serverless session.
func NewServerless(userName string, handler handlers.Handler,
- commands []string, handlerFactory ServerlessHandlerFactory) *Serverless {
+ commands []string, sessionSpec sessionspec.Spec, interactive bool,
+ handlerFactory ServerlessHandlerFactory) *Serverless {
dlog.Client.Debug("Creating new serverless connector", handler, commands)
return &Serverless{
userName: userName,
handler: handler,
commands: commands,
+ sessionSpec: sessionSpec,
+ interactive: interactive,
handlerFactory: handlerFactory,
}
}
@@ -54,6 +61,16 @@ func (s *Serverless) SupportsQueryUpdates(timeout time.Duration) bool {
return supportsQueryUpdates(s.handler, timeout)
}
+// ApplySessionSpec starts or updates the in-process interactive session state.
+func (s *Serverless) ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error {
+ return applySessionSpec(s.Server(), s.handler, &s.sessionState, spec, timeout)
+}
+
+// CommittedSession returns the last server-acknowledged session state.
+func (s *Serverless) CommittedSession() (sessionspec.Spec, uint64, bool) {
+ return s.sessionState.snapshot()
+}
+
// Start the serverless connection.
func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc,
throttleCh, statsCh chan struct{}) {
@@ -165,12 +182,8 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro
}
}()
- // Send commands after setting up the data flow
- for _, command := range s.commands {
- dlog.Client.Debug("Sending command to serverless server", command)
- if err := s.handler.SendMessage(command); err != nil {
- dlog.Client.Debug(err)
- }
+ if err := dispatchInitialCommands(s.Server(), s.handler, s.commands, s.interactive, s.sessionSpec, &s.sessionState); err != nil {
+ return err
}
// Monitor for completion
diff --git a/internal/clients/connectors/sessiontransport.go b/internal/clients/connectors/sessiontransport.go
new file mode 100644
index 0000000..84aeb78
--- /dev/null
+++ b/internal/clients/connectors/sessiontransport.go
@@ -0,0 +1,154 @@
+package connectors
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/omode"
+ sessionspec "github.com/mimecast/dtail/internal/session"
+)
+
+var (
+ // ErrSessionUnsupported indicates that the remote side did not advertise
+ // runtime query update support.
+ ErrSessionUnsupported = errors.New("runtime query updates unsupported by server")
+ // ErrSessionAckTimeout indicates that no hidden SESSION acknowledgement arrived in time.
+ ErrSessionAckTimeout = errors.New("timed out waiting for session acknowledgement")
+ // ErrSessionRejected indicates that the server explicitly rejected a SESSION request.
+ ErrSessionRejected = errors.New("session request rejected")
+ // ErrUnexpectedSessionAck indicates that the client received a malformed or mismatched acknowledgement.
+ ErrUnexpectedSessionAck = errors.New("unexpected session acknowledgement")
+)
+
+const defaultSessionAckTimeout = 2 * time.Second
+
+type committedSessionState struct {
+ mu sync.RWMutex
+ committed bool
+ generation uint64
+ spec sessionspec.Spec
+}
+
+func (s *committedSessionState) commit(spec sessionspec.Spec, generation uint64) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.committed = true
+ s.generation = generation
+ s.spec = spec
+}
+
+func (s *committedSessionState) clear() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.committed = false
+ s.generation = 0
+ s.spec = sessionspec.Spec{}
+}
+
+func (s *committedSessionState) snapshot() (sessionspec.Spec, uint64, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.spec, s.generation, s.committed
+}
+
+func dispatchInitialCommands(server string, handler handlers.Handler, commands []string,
+ interactiveQuery bool, initialSpec sessionspec.Spec, state *committedSessionState) error {
+
+ if !interactiveQuery || initialSpec.Mode == omode.Unknown {
+ return sendLegacyCommands(handler, commands)
+ }
+
+ if err := applySessionSpec(server, handler, state, initialSpec, defaultSessionAckTimeout); err != nil {
+ if !errors.Is(err, ErrSessionUnsupported) {
+ dlog.Client.Warn(server, "Interactive session bootstrap failed, falling back to legacy commands", err)
+ }
+ state.clear()
+ return sendLegacyCommands(handler, commands)
+ }
+
+ return nil
+}
+
+func applySessionSpec(server string, handler handlers.Handler,
+ state *committedSessionState, spec sessionspec.Spec, timeout time.Duration) error {
+
+ if handler == nil {
+ return ErrSessionUnsupported
+ }
+ if !supportsQueryUpdates(handler, defaultCapabilityWait) {
+ return ErrSessionUnsupported
+ }
+
+ action := "start"
+ nextGeneration := uint64(0)
+ command, err := spec.StartCommand()
+ if err != nil {
+ return err
+ }
+
+ if _, generation, ok := state.snapshot(); ok {
+ action = "update"
+ nextGeneration = generation + 1
+ command, err = spec.UpdateCommand(nextGeneration)
+ if err != nil {
+ return err
+ }
+ }
+
+ drainSessionAcks(handler)
+ if err := handler.SendMessage(command); err != nil {
+ return err
+ }
+
+ ack, ok := handler.WaitForSessionAck(resolveSessionAckTimeout(timeout))
+ if !ok {
+ return ErrSessionAckTimeout
+ }
+ if ack.Error != "" {
+ return fmt.Errorf("%w: %s", ErrSessionRejected, ack.Error)
+ }
+ if ack.Action != action {
+ return fmt.Errorf("%w: got action %q want %q", ErrUnexpectedSessionAck, ack.Action, action)
+ }
+ if ack.Generation == 0 {
+ return fmt.Errorf("%w: missing generation", ErrUnexpectedSessionAck)
+ }
+ if action == "update" && ack.Generation != nextGeneration {
+ return fmt.Errorf("%w: got generation %d want %d", ErrUnexpectedSessionAck, ack.Generation, nextGeneration)
+ }
+
+ state.commit(spec, ack.Generation)
+ dlog.Client.Debug(server, "Committed session spec", "action", action, "generation", ack.Generation)
+ return nil
+}
+
+func sendLegacyCommands(handler handlers.Handler, commands []string) error {
+ for _, command := range commands {
+ if err := handler.SendMessage(command); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func drainSessionAcks(handler handlers.Handler) {
+ for {
+ if _, ok := handler.WaitForSessionAck(0); !ok {
+ return
+ }
+ }
+}
+
+func resolveSessionAckTimeout(timeout time.Duration) time.Duration {
+ if timeout <= 0 {
+ return defaultSessionAckTimeout
+ }
+ return timeout
+}