diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-13 07:59:45 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-13 07:59:45 +0200 |
| commit | 6dbc03d5c7b6068665e2d95bc66c4f3700323dc8 (patch) | |
| tree | dc14218fd578caca6b0a7ada3ceb1a0f060a9a9e /internal | |
| parent | c88dddee1953c938b47830ec13696f23770eb22d (diff) | |
task 398: implement session preemption
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 2 | ||||
| -rw-r--r-- | internal/server/handlers/sessioncommand.go | 128 | ||||
| -rw-r--r-- | internal/server/handlers/sessioncommand_test.go | 177 |
3 files changed, 279 insertions, 28 deletions
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index e8c234b..79d03b8 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -98,7 +98,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon // Only shutdown if no active commands AND no pending files. // AUTHKEY is a session-side effect command and should not terminate the shell // because user commands may still follow in the same session. - if shutdownOnCompletion && activeCommands == 0 && pendingFiles == 0 { + if shutdownOnCompletion && activeCommands == 0 && pendingFiles == 0 && !h.sessionState.keepAlive() { h.shutdown() } } diff --git a/internal/server/handlers/sessioncommand.go b/internal/server/handlers/sessioncommand.go index bc5f83e..351f27e 100644 --- a/internal/server/handlers/sessioncommand.go +++ b/internal/server/handlers/sessioncommand.go @@ -9,9 +9,9 @@ import ( "strings" "sync" + "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/protocol" "github.com/mimecast/dtail/internal/session" ) @@ -26,6 +26,7 @@ type sessionCommandState struct { active bool generation uint64 spec session.Spec + cancel context.CancelFunc } func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { @@ -39,15 +40,23 @@ func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LCont switch action { case "START": - h.sessionState.storeStart(spec) - h.send(h.serverMessages, sessionAckStartOKPrefix) + generation, err = h.sessionState.start(h, spec) + if err != nil { + h.send(h.serverMessages, sessionAckErrorPrefix+err.Error()) + return + } + h.send(h.serverMessages, fmt.Sprintf("%s %d", sessionAckStartOKPrefix, generation)) case "UPDATE": if !h.sessionState.activeSession() { h.send(h.serverMessages, sessionAckErrorPrefix+"session not started") return } - h.sessionState.storeUpdate(spec, generation) - h.send(h.serverMessages, sessionAckUpdateOKPrefix) + generation, err = h.sessionState.update(h, spec, generation) + if err != nil { + h.send(h.serverMessages, sessionAckErrorPrefix+err.Error()) + return + } + h.send(h.serverMessages, fmt.Sprintf("%s %d", sessionAckUpdateOKPrefix, generation)) default: h.send(h.serverMessages, sessionAckErrorPrefix+"unknown action") } @@ -97,6 +106,10 @@ func validateSessionSpec(spec session.Spec) error { return fmt.Errorf("missing session query") } + if err := validateSessionOptions(spec.Options); err != nil { + return err + } + if _, err := spec.Commands(); err != nil { return fmt.Errorf("invalid session spec") } @@ -104,25 +117,91 @@ func validateSessionSpec(spec session.Spec) error { return nil } -func (s *sessionCommandState) storeStart(spec session.Spec) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *sessionCommandState) start(handler *ServerHandler, spec session.Spec) (uint64, error) { + commands, err := prepareSessionCommands(spec) + if err != nil { + return 0, err + } + s.mu.Lock() + if s.active { + s.mu.Unlock() + return 0, fmt.Errorf("session already started") + } + ctx, cancel := handler.newCommandContext(context.Background()) s.active = true s.generation = 1 s.spec = spec + s.cancel = cancel + s.mu.Unlock() + + if err := handler.dispatchSessionCommands(ctx, commands); err != nil { + cancel() + s.reset() + return 0, err + } + + return 1, nil } -func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *sessionCommandState) update(handler *ServerHandler, spec session.Spec, generation uint64) (uint64, error) { + commands, err := prepareSessionCommands(spec) + if err != nil { + return 0, err + } - s.active = true + s.mu.Lock() + if !s.active { + s.mu.Unlock() + return 0, fmt.Errorf("session not started") + } + oldCancel := s.cancel + ctx, cancel := handler.newCommandContext(context.Background()) if generation == 0 { generation = s.generation + 1 } + s.active = true s.generation = generation s.spec = spec + s.cancel = cancel + s.mu.Unlock() + + if oldCancel != nil { + oldCancel() + } + + if err := handler.dispatchSessionCommands(ctx, commands); err != nil { + cancel() + s.reset() + return 0, err + } + + return generation, nil +} + +func prepareSessionCommands(spec session.Spec) ([]string, error) { + if spec.Query != "" { + return nil, fmt.Errorf("query sessions not supported yet") + } + + commands, err := spec.Commands() + if err != nil { + return nil, fmt.Errorf("invalid session spec") + } + + return commands, nil +} + +func validateSessionOptions(raw string) error { + if strings.TrimSpace(raw) == "" { + return nil + } + + if _, _, err := config.DeserializeOptions(strings.Split(raw, ":")); err != nil { + return fmt.Errorf("invalid session spec") + } + + return nil } func (s *sessionCommandState) activeSession() bool { @@ -131,6 +210,27 @@ func (s *sessionCommandState) activeSession() bool { return s.active } -func (s *sessionCommandState) advertisedCapabilities() string { - return protocol.HiddenCapabilitiesPrefix + protocol.CapabilityQueryUpdateV1 +func (s *sessionCommandState) keepAlive() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.active +} + +func (s *sessionCommandState) reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.active = false + s.generation = 0 + s.spec = session.Spec{} + s.cancel = nil +} + +func (h *ServerHandler) dispatchSessionCommands(ctx context.Context, commands []string) error { + for _, command := range commands { + if err := h.handleRawCommand(ctx, command); err != nil { + return err + } + } + return nil } diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go index 6af8c5b..f4e7e5b 100644 --- a/internal/server/handlers/sessioncommand_test.go +++ b/internal/server/handlers/sessioncommand_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/base64" "encoding/json" + "strings" + "sync" "testing" "time" @@ -47,11 +49,60 @@ func TestHandleSessionCommandStartStoresSpec(t *testing.T) { if !handler.sessionState.activeSession() { t.Fatalf("expected session state to become active") } - if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix { + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { t.Fatalf("unexpected session start message: %q", message) } } +func TestHandleSessionCommandUpdateCancelsPreviousGenerationImmediately(t *testing.T) { + handler, recorder := newSessionDispatchTestHandler("session-update-cancel-user") + readServerMessage(t, handler.serverMessages) + t.Cleanup(func() { + if handler.sessionState.cancel != nil { + handler.sessionState.cancel() + } + recorder.wg.Wait() + }) + + startPayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-a.log"}, + Regex: "ERROR", + }) + updatePayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-b.log"}, + Regex: "WARN", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected session start ack: %q", message) + } + + first := recorder.waitForStart(t) + if !strings.Contains(first.command, "/var/log/app-a.log") { + t.Fatalf("expected first command to target app-a.log, got %q", first.command) + } + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" { + t.Fatalf("unexpected session update ack: %q", message) + } + + waitForContextDone(t, first.ctx) + + second := recorder.waitForStart(t) + if !strings.Contains(second.command, "/var/log/app-b.log") { + t.Fatalf("expected second command to target app-b.log, got %q", second.command) + } + select { + case <-second.ctx.Done(): + t.Fatalf("expected replacement generation context to remain active") + default: + } +} + func TestHandleSessionCommandUpdateRequiresActiveSession(t *testing.T) { handler := newSessionTestHandler("session-update-user") readServerMessage(t, handler.serverMessages) @@ -81,6 +132,42 @@ func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) { } } +func TestHandleSessionCommandRejectsUnsupportedQuerySessions(t *testing.T) { + handler := newSessionTestHandler("session-query-user") + readServerMessage(t, handler.serverMessages) + + payload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Query: "from STATS select count(*)", + Regex: ".", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"query sessions not supported yet" { + t.Fatalf("unexpected query-session error: %q", message) + } +} + +func TestHandleSessionCommandRejectsInvalidSerializedOptions(t *testing.T) { + handler := newSessionTestHandler("session-options-user") + readServerMessage(t, handler.serverMessages) + + payload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Options: "badoption", + Regex: "ERROR", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session spec" { + t.Fatalf("unexpected invalid options error: %q", message) + } +} + func newSessionTestHandler(userName string) *ServerHandler { handler := &ServerHandler{ baseHandler: baseHandler{ @@ -96,10 +183,69 @@ func newSessionTestHandler(userName string) *ServerHandler { AuthKeyEnabled: true, }, } + handler.commands = map[string]commandHandler{ + "tail": immediateNoopCommandHandler, + "cat": immediateNoopCommandHandler, + "grep": immediateNoopCommandHandler, + "map": immediateNoopCommandHandler, + } + handler.handleCommandCb = func(ctx context.Context, ltx lcontext.LContext, argc int, args []string, commandName string) { + if command, found := handler.commands[commandName]; found { + command(ctx, ltx, argc, args, func() {}) + } + } handler.send(handler.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1) return handler } +type recordedCommand struct { + command string + ctx context.Context +} + +type sessionDispatchRecorder struct { + starts chan recordedCommand + wg sync.WaitGroup +} + +func newSessionDispatchTestHandler(userName string) (*ServerHandler, *sessionDispatchRecorder) { + handler := newSessionTestHandler(userName) + recorder := &sessionDispatchRecorder{ + starts: make(chan recordedCommand, 4), + } + handler.commands = map[string]commandHandler{ + "tail": func(ctx context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + recorder.wg.Add(1) + go func() { + defer recorder.wg.Done() + <-ctx.Done() + commandFinished() + }() + }, + } + return handler, recorder +} + +func immediateNoopCommandHandler(_ context.Context, _ lcontext.LContext, _ int, _ []string, commandFinished func()) { + commandFinished() +} + +func (r *sessionDispatchRecorder) waitForStart(t *testing.T) recordedCommand { + t.Helper() + + select { + case started := <-r.starts: + return started + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for dispatched session command") + return recordedCommand{} + } +} + func mustSessionPayload(t *testing.T, spec session.Spec) string { t.Helper() @@ -133,24 +279,29 @@ func TestParseSessionCommandWithGeneration(t *testing.T) { } func TestSessionStateStoreUpdateAutoIncrementsGeneration(t *testing.T) { - var state sessionCommandState + handler := newSessionTestHandler("session-generation-user") + readServerMessage(t, handler.serverMessages) - state.storeStart(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/a"}, Regex: "ERROR"}) - state.storeUpdate(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/b"}, Regex: "WARN"}, 0) + startPayload := mustSessionPayload(t, session.Spec{Mode: omode.TailClient, Regex: "ERROR"}) + updatePayload := mustSessionPayload(t, session.Spec{Mode: omode.TailClient, Regex: "WARN"}) - state.mu.Lock() - defer state.mu.Unlock() - if state.generation != 2 { - t.Fatalf("unexpected generation: %d", state.generation) + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected session start ack: %q", message) + } + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" { + t.Fatalf("unexpected session update ack: %q", message) } } -func TestSessionCommandReadServerMessageTimeoutProtection(t *testing.T) { - messages := make(chan string) +func waitForContextDone(t *testing.T, ctx context.Context) { + t.Helper() select { - case <-messages: - t.Fatalf("unexpected message") - case <-time.After(5 * time.Millisecond): + case <-ctx.Done(): + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for context cancellation") } } |
