diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/baseclient.go | 5 | ||||
| -rw-r--r-- | internal/clients/connectors/connector.go | 7 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 29 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection_test.go | 174 | ||||
| -rw-r--r-- | internal/clients/connectors/serverless.go | 27 | ||||
| -rw-r--r-- | internal/clients/connectors/sessiontransport.go | 154 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 97 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler_test.go | 76 | ||||
| -rw-r--r-- | internal/clients/handlers/clienthandler.go | 1 | ||||
| -rw-r--r-- | internal/clients/handlers/handler.go | 1 | ||||
| -rw-r--r-- | internal/clients/handlers/healthhandler.go | 1 | ||||
| -rw-r--r-- | internal/clients/handlers/maprhandler.go | 1 | ||||
| -rw-r--r-- | internal/protocol/session.go | 10 | ||||
| -rw-r--r-- | internal/session/spec.go | 35 | ||||
| -rw-r--r-- | internal/session/spec_test.go | 73 |
15 files changed, 674 insertions, 17 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 999c0ed..de76bf1 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -213,9 +213,10 @@ func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMe hostKeyCallback client.HostKeyCallback) connectors.Connector { if c.Args.Serverless { return connectors.NewServerless(c.UserName, c.maker.makeHandler(server), - c.maker.makeCommands(), c.runtime) + c.maker.makeCommands(), c.sessionSpec, c.Args.InteractiveQuery, c.runtime) } return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands(), - c.Args.SSHPrivateKeyFilePath, c.Args.NoAuthKey, c.runtime) + c.sessionSpec, c.Args.InteractiveQuery, c.Args.SSHPrivateKeyFilePath, + c.Args.NoAuthKey, c.runtime) } diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go index a803c33..c1211ec 100644 --- a/internal/clients/connectors/connector.go +++ b/internal/clients/connectors/connector.go @@ -5,6 +5,7 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" + sessionspec "github.com/mimecast/dtail/internal/session" ) // Connector interface. @@ -18,4 +19,10 @@ type Connector interface { // SupportsQueryUpdates reports whether the connected server advertised // runtime query replacement support within the given timeout. SupportsQueryUpdates(timeout time.Duration) bool + // ApplySessionSpec starts or updates the interactive session workload on an + // already connected server when query updates are supported. + ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error + // CommittedSession returns the last session spec and generation that the + // server acknowledged for this connection. + CommittedSession() (sessionspec.Spec, uint64, bool) } diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 97c02eb..5b432d4 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -14,6 +14,7 @@ import ( "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/protocol" + sessionspec "github.com/mimecast/dtail/internal/session" "github.com/mimecast/dtail/internal/ssh/client" "golang.org/x/crypto/ssh" @@ -43,6 +44,9 @@ type ServerConnection struct { config *ssh.ClientConfig handler handlers.Handler commands []string + sessionSpec sessionspec.Spec + sessionState committedSessionState + interactive bool authKeyPath string authKeyDisabled bool hostKeyCallback client.HostKeyCallback @@ -54,8 +58,8 @@ var _ Connector = (*ServerConnection)(nil) // NewServerConnection returns a new DTail SSH server connection. func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, - handler handlers.Handler, commands []string, authKeyPath string, - authKeyDisabled bool, settings SSHSettings) *ServerConnection { + handler handlers.Handler, commands []string, sessionSpec sessionspec.Spec, + interactive bool, authKeyPath string, authKeyDisabled bool, settings SSHSettings) *ServerConnection { dlog.Client.Debug(server, "Creating new connection", server, handler, commands) sshConnectTimeout := defaultSSHConnectTimeout @@ -76,6 +80,8 @@ func NewServerConnection(server string, userName string, server: server, handler: handler, commands: commands, + sessionSpec: sessionSpec, + interactive: interactive, authKeyPath: resolveAuthKeyPath(authKeyPath), authKeyDisabled: authKeyDisabled, config: &ssh.ClientConfig{ @@ -103,6 +109,17 @@ func (c *ServerConnection) SupportsQueryUpdates(timeout time.Duration) bool { return supportsQueryUpdates(c.handler, timeout) } +// ApplySessionSpec starts or updates the interactive session state on the +// existing SSH connection when runtime query updates are supported. +func (c *ServerConnection) ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error { + return applySessionSpec(c.server, c.handler, &c.sessionState, spec, timeout) +} + +// CommittedSession returns the last server-acknowledged session state. +func (c *ServerConnection) CommittedSession() (sessionspec.Spec, uint64, bool) { + return c.sessionState.snapshot() +} + // Attempt to parse the server port address from the provided server FQDN. func (c *ServerConnection) initServerPort(defaultPort int) { parts := strings.Split(c.server, ":") @@ -257,12 +274,8 @@ func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc c.sendAuthKeyRegistrationCommand() } - // Send all requested commands to the server. - for _, command := range c.commands { - dlog.Client.Debug(command) - if err := c.handler.SendMessage(command); err != nil { - dlog.Client.Debug(err) - } + if err := dispatchInitialCommands(c.server, c.handler, c.commands, c.interactive, c.sessionSpec, &c.sessionState); err != nil { + return err } if !c.throttlingDone { diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go index 9307b24..76c4eb6 100644 --- a/internal/clients/connectors/serverconnection_test.go +++ b/internal/clients/connectors/serverconnection_test.go @@ -2,6 +2,7 @@ package connectors import ( "context" + "errors" "os" "path/filepath" "testing" @@ -9,7 +10,9 @@ import ( "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/protocol" + sessionspec "github.com/mimecast/dtail/internal/session" "golang.org/x/crypto/ssh" ) @@ -91,6 +94,8 @@ func TestNewServerConnectionUsesInjectedSettings(t *testing.T) { testHostKeyCallback{}, &mockHandler{}, nil, + sessionspec.Spec{}, + false, "", false, testSSHSettings{port: 3022, timeout: 5 * time.Second}, @@ -117,6 +122,8 @@ func TestNewServerConnectionFallsBackToDefaults(t *testing.T) { testHostKeyCallback{}, &mockHandler{}, nil, + sessionspec.Spec{}, + false, "", false, testSSHSettings{}, @@ -173,6 +180,159 @@ func TestServerConnectionSupportsQueryUpdatesRequiresCapabilityFlag(t *testing.T } } +func TestServerConnectionApplySessionSpecStart(t *testing.T) { + resetClientLogger(t) + + conn := &ServerConnection{ + server: "srv1", + handler: &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{{ + Action: "start", + Generation: 1, + }}, + }, + } + + spec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + if err := conn.ApplySessionSpec(spec, 10*time.Millisecond); err != nil { + t.Fatalf("ApplySessionSpec() error = %v", err) + } + + mock := conn.handler.(*mockHandler) + if len(mock.commands) != 1 { + t.Fatalf("expected one session command, got %d", len(mock.commands)) + } + if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 1 || committedSpec.Regex != "ERROR" { + t.Fatalf("unexpected committed session: spec=%#v generation=%d ok=%v", committedSpec, generation, ok) + } +} + +func TestServerConnectionApplySessionSpecUpdateUsesNextGeneration(t *testing.T) { + resetClientLogger(t) + + mock := &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{ + {Action: "start", Generation: 4}, + {Action: "update", Generation: 5}, + }, + } + conn := &ServerConnection{ + server: "srv1", + handler: mock, + } + + startSpec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + updateSpec := sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "WARN", + } + + if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil { + t.Fatalf("start ApplySessionSpec() error = %v", err) + } + if err := conn.ApplySessionSpec(updateSpec, 10*time.Millisecond); err != nil { + t.Fatalf("update ApplySessionSpec() error = %v", err) + } + if len(mock.commands) != 2 { + t.Fatalf("expected two session commands, got %d", len(mock.commands)) + } + if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 5 || committedSpec.Regex != "WARN" { + t.Fatalf("unexpected committed session after update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok) + } +} + +func TestServerConnectionApplySessionSpecFallsBackForUnsupportedServer(t *testing.T) { + resetClientLogger(t) + + conn := &ServerConnection{ + handler: &mockHandler{}, + } + + err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"}, 5*time.Millisecond) + if !errors.Is(err, ErrSessionUnsupported) { + t.Fatalf("expected ErrSessionUnsupported, got %v", err) + } +} + +func TestServerConnectionApplySessionSpecPreservesCommittedStateOnRejectedUpdate(t *testing.T) { + resetClientLogger(t) + + mock := &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{ + {Action: "start", Generation: 2}, + {Action: "error", Error: "bad reload"}, + }, + } + conn := &ServerConnection{ + server: "srv1", + handler: mock, + } + + startSpec := sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"} + if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil { + t.Fatalf("start ApplySessionSpec() error = %v", err) + } + + err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "WARN"}, 10*time.Millisecond) + if !errors.Is(err, ErrSessionRejected) { + t.Fatalf("expected ErrSessionRejected, got %v", err) + } + if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 2 || committedSpec.Regex != "ERROR" { + t.Fatalf("unexpected committed session after rejected update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok) + } +} + +func TestServerConnectionApplySessionSpecRejectsUnexpectedAck(t *testing.T) { + resetClientLogger(t) + + mock := &mockHandler{ + waitForCapabilities: true, + capabilities: map[string]bool{ + protocol.CapabilityQueryUpdateV1: true, + }, + sessionAcks: []handlers.SessionAck{ + {Action: "update", Generation: 1}, + }, + } + conn := &ServerConnection{ + server: "srv1", + handler: mock, + } + + err := conn.ApplySessionSpec(sessionspec.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + }, 10*time.Millisecond) + if !errors.Is(err, ErrUnexpectedSessionAck) { + t.Fatalf("expected ErrUnexpectedSessionAck, got %v", err) + } + if _, _, ok := conn.CommittedSession(); ok { + t.Fatalf("unexpected committed session after mismatched ack") + } +} + type testSSHSettings struct { port int timeout time.Duration @@ -212,6 +372,7 @@ type mockHandler struct { commands []string capabilities map[string]bool waitForCapabilities bool + sessionAcks []handlers.SessionAck } var _ handlers.Handler = (*mockHandler)(nil) @@ -253,6 +414,19 @@ func (m *mockHandler) WaitForCapabilities(timeout time.Duration) bool { return m.waitForCapabilities } +func (m *mockHandler) WaitForSessionAck(timeout time.Duration) (handlers.SessionAck, bool) { + if timeout <= 0 { + return handlers.SessionAck{}, false + } + if len(m.sessionAcks) == 0 { + return handlers.SessionAck{}, false + } + + ack := m.sessionAcks[0] + m.sessionAcks = m.sessionAcks[1:] + return ack, true +} + func (m *mockHandler) Read(_ []byte) (int, error) { return 0, nil } diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go index 0ebe069..72e3fda 100644 --- a/internal/clients/connectors/serverless.go +++ b/internal/clients/connectors/serverless.go @@ -8,6 +8,7 @@ import ( "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" serverHandlers "github.com/mimecast/dtail/internal/server/handlers" + sessionspec "github.com/mimecast/dtail/internal/session" ) // ServerlessHandlerFactory creates the in-process server-side handler used by serverless mode. @@ -19,6 +20,9 @@ type ServerlessHandlerFactory interface { type Serverless struct { handler handlers.Handler commands []string + sessionSpec sessionspec.Spec + sessionState committedSessionState + interactive bool userName string handlerFactory ServerlessHandlerFactory } @@ -27,13 +31,16 @@ var _ Connector = (*Serverless)(nil) // NewServerless starts a new serverless session. func NewServerless(userName string, handler handlers.Handler, - commands []string, handlerFactory ServerlessHandlerFactory) *Serverless { + commands []string, sessionSpec sessionspec.Spec, interactive bool, + handlerFactory ServerlessHandlerFactory) *Serverless { dlog.Client.Debug("Creating new serverless connector", handler, commands) return &Serverless{ userName: userName, handler: handler, commands: commands, + sessionSpec: sessionSpec, + interactive: interactive, handlerFactory: handlerFactory, } } @@ -54,6 +61,16 @@ func (s *Serverless) SupportsQueryUpdates(timeout time.Duration) bool { return supportsQueryUpdates(s.handler, timeout) } +// ApplySessionSpec starts or updates the in-process interactive session state. +func (s *Serverless) ApplySessionSpec(spec sessionspec.Spec, timeout time.Duration) error { + return applySessionSpec(s.Server(), s.handler, &s.sessionState, spec, timeout) +} + +// CommittedSession returns the last server-acknowledged session state. +func (s *Serverless) CommittedSession() (sessionspec.Spec, uint64, bool) { + return s.sessionState.snapshot() +} + // Start the serverless connection. func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { @@ -165,12 +182,8 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro } }() - // Send commands after setting up the data flow - for _, command := range s.commands { - dlog.Client.Debug("Sending command to serverless server", command) - if err := s.handler.SendMessage(command); err != nil { - dlog.Client.Debug(err) - } + if err := dispatchInitialCommands(s.Server(), s.handler, s.commands, s.interactive, s.sessionSpec, &s.sessionState); err != nil { + return err } // Monitor for completion diff --git a/internal/clients/connectors/sessiontransport.go b/internal/clients/connectors/sessiontransport.go new file mode 100644 index 0000000..84aeb78 --- /dev/null +++ b/internal/clients/connectors/sessiontransport.go @@ -0,0 +1,154 @@ +package connectors + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/omode" + sessionspec "github.com/mimecast/dtail/internal/session" +) + +var ( + // ErrSessionUnsupported indicates that the remote side did not advertise + // runtime query update support. + ErrSessionUnsupported = errors.New("runtime query updates unsupported by server") + // ErrSessionAckTimeout indicates that no hidden SESSION acknowledgement arrived in time. + ErrSessionAckTimeout = errors.New("timed out waiting for session acknowledgement") + // ErrSessionRejected indicates that the server explicitly rejected a SESSION request. + ErrSessionRejected = errors.New("session request rejected") + // ErrUnexpectedSessionAck indicates that the client received a malformed or mismatched acknowledgement. + ErrUnexpectedSessionAck = errors.New("unexpected session acknowledgement") +) + +const defaultSessionAckTimeout = 2 * time.Second + +type committedSessionState struct { + mu sync.RWMutex + committed bool + generation uint64 + spec sessionspec.Spec +} + +func (s *committedSessionState) commit(spec sessionspec.Spec, generation uint64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.committed = true + s.generation = generation + s.spec = spec +} + +func (s *committedSessionState) clear() { + s.mu.Lock() + defer s.mu.Unlock() + + s.committed = false + s.generation = 0 + s.spec = sessionspec.Spec{} +} + +func (s *committedSessionState) snapshot() (sessionspec.Spec, uint64, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.spec, s.generation, s.committed +} + +func dispatchInitialCommands(server string, handler handlers.Handler, commands []string, + interactiveQuery bool, initialSpec sessionspec.Spec, state *committedSessionState) error { + + if !interactiveQuery || initialSpec.Mode == omode.Unknown { + return sendLegacyCommands(handler, commands) + } + + if err := applySessionSpec(server, handler, state, initialSpec, defaultSessionAckTimeout); err != nil { + if !errors.Is(err, ErrSessionUnsupported) { + dlog.Client.Warn(server, "Interactive session bootstrap failed, falling back to legacy commands", err) + } + state.clear() + return sendLegacyCommands(handler, commands) + } + + return nil +} + +func applySessionSpec(server string, handler handlers.Handler, + state *committedSessionState, spec sessionspec.Spec, timeout time.Duration) error { + + if handler == nil { + return ErrSessionUnsupported + } + if !supportsQueryUpdates(handler, defaultCapabilityWait) { + return ErrSessionUnsupported + } + + action := "start" + nextGeneration := uint64(0) + command, err := spec.StartCommand() + if err != nil { + return err + } + + if _, generation, ok := state.snapshot(); ok { + action = "update" + nextGeneration = generation + 1 + command, err = spec.UpdateCommand(nextGeneration) + if err != nil { + return err + } + } + + drainSessionAcks(handler) + if err := handler.SendMessage(command); err != nil { + return err + } + + ack, ok := handler.WaitForSessionAck(resolveSessionAckTimeout(timeout)) + if !ok { + return ErrSessionAckTimeout + } + if ack.Error != "" { + return fmt.Errorf("%w: %s", ErrSessionRejected, ack.Error) + } + if ack.Action != action { + return fmt.Errorf("%w: got action %q want %q", ErrUnexpectedSessionAck, ack.Action, action) + } + if ack.Generation == 0 { + return fmt.Errorf("%w: missing generation", ErrUnexpectedSessionAck) + } + if action == "update" && ack.Generation != nextGeneration { + return fmt.Errorf("%w: got generation %d want %d", ErrUnexpectedSessionAck, ack.Generation, nextGeneration) + } + + state.commit(spec, ack.Generation) + dlog.Client.Debug(server, "Committed session spec", "action", action, "generation", ack.Generation) + return nil +} + +func sendLegacyCommands(handler handlers.Handler, commands []string) error { + for _, command := range commands { + if err := handler.SendMessage(command); err != nil { + return err + } + } + return nil +} + +func drainSessionAcks(handler handlers.Handler) { + for { + if _, ok := handler.WaitForSessionAck(0); !ok { + return + } + } +} + +func resolveSessionAckTimeout(timeout time.Duration) time.Duration { + if timeout <= 0 { + return defaultSessionAckTimeout + } + return timeout +} diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 923b24a..8da4556 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "sort" + "strconv" "strings" "sync" "time" @@ -27,6 +28,15 @@ type baseHandler struct { capabilities map[string]struct{} capabilitiesCh chan struct{} capabilitiesOk sync.Once + + sessionAcks chan SessionAck +} + +// SessionAck is a parsed hidden acknowledgement for SESSION START/UPDATE requests. +type SessionAck struct { + Action string + Generation uint64 + Error string } func (h *baseHandler) String() string { @@ -177,6 +187,10 @@ func (h *baseHandler) handleHiddenMessage(message string) { switch { case strings.HasPrefix(message, protocol.HiddenCapabilitiesPrefix): h.handleCapabilitiesMessage(message) + case strings.HasPrefix(message, protocol.HiddenSessionStartOKPrefix), + strings.HasPrefix(message, protocol.HiddenSessionUpdateOKPrefix), + strings.HasPrefix(message, protocol.HiddenSessionErrorPrefix): + h.handleSessionAckMessage(message) case strings.HasPrefix(message, ".syn close connection"): go h.SendMessage(".ack close connection") h.Shutdown() @@ -237,6 +251,89 @@ func (h *baseHandler) WaitForCapabilities(timeout time.Duration) bool { } } +func (h *baseHandler) WaitForSessionAck(timeout time.Duration) (SessionAck, bool) { + if h.sessionAcks == nil { + return SessionAck{}, false + } + + if timeout <= 0 { + select { + case ack := <-h.sessionAcks: + return ack, true + default: + return SessionAck{}, false + } + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case ack := <-h.sessionAcks: + return ack, true + case <-h.Done(): + return SessionAck{}, false + case <-timer.C: + return SessionAck{}, false + } +} + func (h *baseHandler) Shutdown() { h.done.Shutdown() } + +func (h *baseHandler) handleSessionAckMessage(message string) { + ack, ok := parseSessionAckMessage(message) + if !ok { + dlog.Client.Warn(h.server, "Unable to parse session acknowledgement", message) + return + } + if h.sessionAcks == nil { + return + } + + select { + case h.sessionAcks <- ack: + case <-h.Done(): + default: + dlog.Client.Warn(h.server, "Dropping session acknowledgement because the queue is full", message) + } +} + +func parseSessionAckMessage(message string) (SessionAck, bool) { + payload := strings.TrimSpace(message) + if payload == "" { + return SessionAck{}, false + } + + switch { + case strings.HasPrefix(payload, protocol.HiddenSessionStartOKPrefix): + return parseSessionOKAck(strings.TrimPrefix(payload, protocol.HiddenSessionStartOKPrefix), "start") + case strings.HasPrefix(payload, protocol.HiddenSessionUpdateOKPrefix): + return parseSessionOKAck(strings.TrimPrefix(payload, protocol.HiddenSessionUpdateOKPrefix), "update") + case strings.HasPrefix(payload, protocol.HiddenSessionErrorPrefix): + return SessionAck{ + Action: "error", + Error: strings.TrimSpace(strings.TrimPrefix(payload, protocol.HiddenSessionErrorPrefix)), + }, true + default: + return SessionAck{}, false + } +} + +func parseSessionOKAck(payload string, action string) (SessionAck, bool) { + generationStr := strings.TrimSpace(payload) + if generationStr == "" { + return SessionAck{}, false + } + + generation, err := strconv.ParseUint(generationStr, 10, 64) + if err != nil { + return SessionAck{}, false + } + + return SessionAck{ + Action: action, + Generation: generation, + }, true +} diff --git a/internal/clients/handlers/basehandler_test.go b/internal/clients/handlers/basehandler_test.go index fddc890..7db2bb8 100644 --- a/internal/clients/handlers/basehandler_test.go +++ b/internal/clients/handlers/basehandler_test.go @@ -65,6 +65,7 @@ func TestHandleCapabilitiesMessage(t *testing.T) { done: internal.NewDone(), capabilities: make(map[string]struct{}), capabilitiesCh: make(chan struct{}), + sessionAcks: make(chan SessionAck, 1), } handler.handleHiddenMessage(".syn capabilities query-update-v1 feature-two") @@ -90,9 +91,84 @@ func TestWaitForCapabilitiesTimeout(t *testing.T) { done: internal.NewDone(), capabilities: make(map[string]struct{}), capabilitiesCh: make(chan struct{}), + sessionAcks: make(chan SessionAck, 1), } if handler.WaitForCapabilities(5 * time.Millisecond) { t.Fatalf("expected capabilities wait to time out") } } + +func TestParseSessionAckMessage(t *testing.T) { + tests := []struct { + name string + message string + want SessionAck + wantOK bool + }{ + { + name: "start ok", + message: ".syn session start ok 7", + want: SessionAck{ + Action: "start", + Generation: 7, + }, + wantOK: true, + }, + { + name: "update ok", + message: ".syn session update ok 8", + want: SessionAck{ + Action: "update", + Generation: 8, + }, + wantOK: true, + }, + { + name: "error", + message: ".syn session err query sessions not supported yet", + want: SessionAck{ + Action: "error", + Error: "query sessions not supported yet", + }, + wantOK: true, + }, + { + name: "invalid", + message: ".syn session start ok nope", + wantOK: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, ok := parseSessionAckMessage(tc.message) + if ok != tc.wantOK { + t.Fatalf("unexpected ok flag: got %v want %v", ok, tc.wantOK) + } + if !tc.wantOK { + return + } + if got != tc.want { + t.Fatalf("unexpected ack: got %#v want %#v", got, tc.want) + } + }) + } +} + +func TestHandleSessionAckMessage(t *testing.T) { + handler := baseHandler{ + done: internal.NewDone(), + sessionAcks: make(chan SessionAck, 1), + } + + handler.handleHiddenMessage(".syn session update ok 4") + + ack, ok := handler.WaitForSessionAck(10 * time.Millisecond) + if !ok { + t.Fatalf("expected session ack") + } + if ack.Action != "update" || ack.Generation != 4 { + t.Fatalf("unexpected session ack: %#v", ack) + } +} diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index 61bbc50..3998e9f 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -25,6 +25,7 @@ func NewClientHandler(server string) *ClientHandler { done: internal.NewDone(), capabilities: make(map[string]struct{}), capabilitiesCh: make(chan struct{}), + sessionAcks: make(chan SessionAck, 4), }, } } diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go index aebebaa..cac78ad 100644 --- a/internal/clients/handlers/handler.go +++ b/internal/clients/handlers/handler.go @@ -16,4 +16,5 @@ type Handler interface { Shutdown() Done() <-chan struct{} WaitForCapabilities(timeout time.Duration) bool + WaitForSessionAck(timeout time.Duration) (SessionAck, bool) } diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go index cd5605a..763ba88 100644 --- a/internal/clients/handlers/healthhandler.go +++ b/internal/clients/handlers/healthhandler.go @@ -26,6 +26,7 @@ func NewHealthHandler(server string) *HealthHandler { done: internal.NewDone(), capabilities: make(map[string]struct{}), capabilitiesCh: make(chan struct{}), + sessionAcks: make(chan SessionAck, 4), }, } } diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go index 9e9a0d1..5a16d13 100644 --- a/internal/clients/handlers/maprhandler.go +++ b/internal/clients/handlers/maprhandler.go @@ -32,6 +32,7 @@ func NewMaprHandler(server string, query *mapr.Query, done: internal.NewDone(), capabilities: make(map[string]struct{}), capabilitiesCh: make(chan struct{}), + sessionAcks: make(chan SessionAck, 4), }, query: query, aggregate: client.NewAggregate(server, query, globalGroup), diff --git a/internal/protocol/session.go b/internal/protocol/session.go new file mode 100644 index 0000000..ec6433c --- /dev/null +++ b/internal/protocol/session.go @@ -0,0 +1,10 @@ +package protocol + +const ( + // HiddenSessionStartOKPrefix acknowledges a successful SESSION START request. + HiddenSessionStartOKPrefix = ".syn session start ok" + // HiddenSessionUpdateOKPrefix acknowledges a successful SESSION UPDATE request. + HiddenSessionUpdateOKPrefix = ".syn session update ok" + // HiddenSessionErrorPrefix reports a rejected SESSION request. + HiddenSessionErrorPrefix = ".syn session err " +) diff --git a/internal/session/spec.go b/internal/session/spec.go index 2d1b77d..0a6ad4e 100644 --- a/internal/session/spec.go +++ b/internal/session/spec.go @@ -1,6 +1,8 @@ package session import ( + "encoding/base64" + "encoding/json" "fmt" "strings" @@ -45,6 +47,30 @@ func (s Spec) Commands() ([]string, error) { } } +// StartCommand returns the SESSION START command for this specification. +func (s Spec) StartCommand() (string, error) { + payload, err := s.encodedPayload() + if err != nil { + return "", err + } + + return fmt.Sprintf("SESSION START %s", payload), nil +} + +// UpdateCommand returns the SESSION UPDATE command for this specification. +func (s Spec) UpdateCommand(generation uint64) (string, error) { + payload, err := s.encodedPayload() + if err != nil { + return "", err + } + + if generation == 0 { + return fmt.Sprintf("SESSION UPDATE %s", payload), nil + } + + return fmt.Sprintf("SESSION UPDATE %d %s", generation, payload), nil +} + func (s Spec) queryCommands() ([]string, error) { if s.Mode != omode.MapClient && s.Mode != omode.TailClient { return nil, fmt.Errorf("session spec query mode requires map or tail mode, got %s", s.Mode) @@ -122,3 +148,12 @@ func splitFiles(what string) []string { } return files } + +func (s Spec) encodedPayload() (string, error) { + payload, err := json.Marshal(s) + if err != nil { + return "", fmt.Errorf("marshal session spec: %w", err) + } + + return base64.StdEncoding.EncodeToString(payload), nil +} diff --git a/internal/session/spec_test.go b/internal/session/spec_test.go new file mode 100644 index 0000000..182c517 --- /dev/null +++ b/internal/session/spec_test.go @@ -0,0 +1,73 @@ +package session + +import ( + "encoding/base64" + "encoding/json" + "reflect" + "strings" + "testing" + + "github.com/mimecast/dtail/internal/omode" +) + +func TestSpecStartCommandEncodesPayload(t *testing.T) { + t.Parallel() + + spec := Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Options: "plain=true", + Regex: "ERROR", + Timeout: 15, + } + + command, err := spec.StartCommand() + if err != nil { + t.Fatalf("StartCommand() error = %v", err) + } + if !strings.HasPrefix(command, "SESSION START ") { + t.Fatalf("unexpected start command prefix: %q", command) + } + + var decoded Spec + if err := decodeSpecPayload(strings.TrimPrefix(command, "SESSION START "), &decoded); err != nil { + t.Fatalf("decode start payload: %v", err) + } + if !reflect.DeepEqual(decoded, spec) { + t.Fatalf("unexpected decoded spec: got %#v want %#v", decoded, spec) + } +} + +func TestSpecUpdateCommandIncludesGeneration(t *testing.T) { + t.Parallel() + + spec := Spec{ + Mode: omode.MapClient, + Files: []string{"/var/log/app.log"}, + Query: "from STATS select count(*)", + } + + command, err := spec.UpdateCommand(7) + if err != nil { + t.Fatalf("UpdateCommand() error = %v", err) + } + if !strings.HasPrefix(command, "SESSION UPDATE 7 ") { + t.Fatalf("unexpected update command prefix: %q", command) + } + + var decoded Spec + if err := decodeSpecPayload(strings.TrimPrefix(command, "SESSION UPDATE 7 "), &decoded); err != nil { + t.Fatalf("decode update payload: %v", err) + } + if !reflect.DeepEqual(decoded, spec) { + t.Fatalf("unexpected decoded spec: got %#v want %#v", decoded, spec) + } +} + +func decodeSpecPayload(payload string, out *Spec) error { + raw, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return err + } + return json.Unmarshal(raw, out) +} |
