diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-13 08:48:13 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-13 08:48:13 +0200 |
| commit | 4b15cf31f069bb47a47f3fced9322091997edf15 (patch) | |
| tree | 2a4af63f9a9b76c86c7a9f4595c58c877953a31e /internal/clients/connectors/sessiontransport.go | |
| parent | 8ac1d12f8261bce508c99be454ce27df9c1af0a9 (diff) | |
task 407: add client session transport
Diffstat (limited to 'internal/clients/connectors/sessiontransport.go')
| -rw-r--r-- | internal/clients/connectors/sessiontransport.go | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/internal/clients/connectors/sessiontransport.go b/internal/clients/connectors/sessiontransport.go new file mode 100644 index 0000000..84aeb78 --- /dev/null +++ b/internal/clients/connectors/sessiontransport.go @@ -0,0 +1,154 @@ +package connectors + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/omode" + sessionspec "github.com/mimecast/dtail/internal/session" +) + +var ( + // ErrSessionUnsupported indicates that the remote side did not advertise + // runtime query update support. + ErrSessionUnsupported = errors.New("runtime query updates unsupported by server") + // ErrSessionAckTimeout indicates that no hidden SESSION acknowledgement arrived in time. + ErrSessionAckTimeout = errors.New("timed out waiting for session acknowledgement") + // ErrSessionRejected indicates that the server explicitly rejected a SESSION request. + ErrSessionRejected = errors.New("session request rejected") + // ErrUnexpectedSessionAck indicates that the client received a malformed or mismatched acknowledgement. + ErrUnexpectedSessionAck = errors.New("unexpected session acknowledgement") +) + +const defaultSessionAckTimeout = 2 * time.Second + +type committedSessionState struct { + mu sync.RWMutex + committed bool + generation uint64 + spec sessionspec.Spec +} + +func (s *committedSessionState) commit(spec sessionspec.Spec, generation uint64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.committed = true + s.generation = generation + s.spec = spec +} + +func (s *committedSessionState) clear() { + s.mu.Lock() + defer s.mu.Unlock() + + s.committed = false + s.generation = 0 + s.spec = sessionspec.Spec{} +} + +func (s *committedSessionState) snapshot() (sessionspec.Spec, uint64, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.spec, s.generation, s.committed +} + +func dispatchInitialCommands(server string, handler handlers.Handler, commands []string, + interactiveQuery bool, initialSpec sessionspec.Spec, state *committedSessionState) error { + + if !interactiveQuery || initialSpec.Mode == omode.Unknown { + return sendLegacyCommands(handler, commands) + } + + if err := applySessionSpec(server, handler, state, initialSpec, defaultSessionAckTimeout); err != nil { + if !errors.Is(err, ErrSessionUnsupported) { + dlog.Client.Warn(server, "Interactive session bootstrap failed, falling back to legacy commands", err) + } + state.clear() + return sendLegacyCommands(handler, commands) + } + + return nil +} + +func applySessionSpec(server string, handler handlers.Handler, + state *committedSessionState, spec sessionspec.Spec, timeout time.Duration) error { + + if handler == nil { + return ErrSessionUnsupported + } + if !supportsQueryUpdates(handler, defaultCapabilityWait) { + return ErrSessionUnsupported + } + + action := "start" + nextGeneration := uint64(0) + command, err := spec.StartCommand() + if err != nil { + return err + } + + if _, generation, ok := state.snapshot(); ok { + action = "update" + nextGeneration = generation + 1 + command, err = spec.UpdateCommand(nextGeneration) + if err != nil { + return err + } + } + + drainSessionAcks(handler) + if err := handler.SendMessage(command); err != nil { + return err + } + + ack, ok := handler.WaitForSessionAck(resolveSessionAckTimeout(timeout)) + if !ok { + return ErrSessionAckTimeout + } + if ack.Error != "" { + return fmt.Errorf("%w: %s", ErrSessionRejected, ack.Error) + } + if ack.Action != action { + return fmt.Errorf("%w: got action %q want %q", ErrUnexpectedSessionAck, ack.Action, action) + } + if ack.Generation == 0 { + return fmt.Errorf("%w: missing generation", ErrUnexpectedSessionAck) + } + if action == "update" && ack.Generation != nextGeneration { + return fmt.Errorf("%w: got generation %d want %d", ErrUnexpectedSessionAck, ack.Generation, nextGeneration) + } + + state.commit(spec, ack.Generation) + dlog.Client.Debug(server, "Committed session spec", "action", action, "generation", ack.Generation) + return nil +} + +func sendLegacyCommands(handler handlers.Handler, commands []string) error { + for _, command := range commands { + if err := handler.SendMessage(command); err != nil { + return err + } + } + return nil +} + +func drainSessionAcks(handler handlers.Handler) { + for { + if _, ok := handler.WaitForSessionAck(0); !ok { + return + } + } +} + +func resolveSessionAckTimeout(timeout time.Duration) time.Duration { + if timeout <= 0 { + return defaultSessionAckTimeout + } + return timeout +} |
