From 9f6850fc202e048dcdbfa6ffb59589d4a851cd84 Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Fri, 13 Mar 2026 08:57:01 +0200 Subject: task 58076a44: enable query session workloads --- internal/server/handlers/sessioncommand.go | 23 +++- internal/server/handlers/sessioncommand_test.go | 147 +++++++++++++++++++++++- 2 files changed, 162 insertions(+), 8 deletions(-) diff --git a/internal/server/handlers/sessioncommand.go b/internal/server/handlers/sessioncommand.go index 351f27e..0d54963 100644 --- a/internal/server/handlers/sessioncommand.go +++ b/internal/server/handlers/sessioncommand.go @@ -11,6 +11,7 @@ import ( "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/session" ) @@ -105,6 +106,11 @@ func validateSessionSpec(spec session.Spec) error { if spec.Query == "" && spec.Mode == omode.MapClient { return fmt.Errorf("missing session query") } + if spec.Query != "" { + if _, err := mapr.NewQuery(spec.Query); err != nil { + return fmt.Errorf("invalid session spec") + } + } if err := validateSessionOptions(spec.Options); err != nil { return err @@ -135,6 +141,7 @@ func (s *sessionCommandState) start(handler *ServerHandler, spec session.Spec) ( s.cancel = cancel s.mu.Unlock() + handler.resetSessionAggregates() if err := handler.dispatchSessionCommands(ctx, commands); err != nil { cancel() s.reset() @@ -170,6 +177,7 @@ func (s *sessionCommandState) update(handler *ServerHandler, spec session.Spec, oldCancel() } + handler.resetSessionAggregates() if err := handler.dispatchSessionCommands(ctx, commands); err != nil { cancel() s.reset() @@ -180,10 +188,6 @@ func (s *sessionCommandState) update(handler *ServerHandler, spec session.Spec, } func prepareSessionCommands(spec session.Spec) ([]string, error) { - if spec.Query != "" { - return nil, fmt.Errorf("query sessions not supported yet") - } - commands, err := spec.Commands() if err != nil { return nil, fmt.Errorf("invalid session spec") @@ -234,3 +238,14 @@ func (h *ServerHandler) dispatchSessionCommands(ctx context.Context, commands [] } return nil } + +func (h *ServerHandler) resetSessionAggregates() { + if h.aggregate != nil { + h.aggregate.Shutdown() + h.aggregate = nil + } + if h.turboAggregate != nil { + h.turboAggregate.Shutdown() + h.turboAggregate = nil + } +} diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go index f4e7e5b..d7df000 100644 --- a/internal/server/handlers/sessioncommand_test.go +++ b/internal/server/handlers/sessioncommand_test.go @@ -11,8 +11,10 @@ import ( "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/lcontext" + maprserver "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/protocol" "github.com/mimecast/dtail/internal/session" @@ -132,8 +134,8 @@ func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) { } } -func TestHandleSessionCommandRejectsUnsupportedQuerySessions(t *testing.T) { - handler := newSessionTestHandler("session-query-user") +func TestHandleSessionCommandStartDispatchesQueryWorkload(t *testing.T) { + handler, recorder := newQuerySessionDispatchTestHandler("session-query-user") readServerMessage(t, handler.serverMessages) payload := mustSessionPayload(t, session.Spec{ @@ -145,8 +147,24 @@ func TestHandleSessionCommandRejectsUnsupportedQuerySessions(t *testing.T) { handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) - if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"query sessions not supported yet" { - t.Fatalf("unexpected query-session error: %q", message) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected query-session ack: %q", message) + } + + first := recorder.waitForStart(t) + if !strings.HasPrefix(first.command, "map:") { + t.Fatalf("expected map command first, got %q", first.command) + } + if !strings.Contains(first.command, "from STATS select count(*)") { + t.Fatalf("expected map command to contain query, got %q", first.command) + } + + second := recorder.waitForStart(t) + if !strings.HasPrefix(second.command, "tail:") { + t.Fatalf("expected tail command second, got %q", second.command) + } + if !strings.Contains(second.command, "/var/log/app.log") { + t.Fatalf("expected tail command to contain file, got %q", second.command) } } @@ -168,6 +186,86 @@ func TestHandleSessionCommandRejectsInvalidSerializedOptions(t *testing.T) { } } +func TestHandleSessionCommandRejectsInvalidQuerySession(t *testing.T) { + handler := newSessionTestHandler("session-invalid-query-user") + readServerMessage(t, handler.serverMessages) + + payload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Query: "select from", + Regex: ".", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session spec" { + t.Fatalf("unexpected invalid query-session error: %q", message) + } +} + +func TestHandleSessionCommandUpdateClearsAggregateStateBeforeDirectRead(t *testing.T) { + resetServerLogger(t) + + handler := newSessionTestHandler("session-query-reset-user") + readServerMessage(t, handler.serverMessages) + + sawResetState := make(chan bool, 1) + tailCalls := 0 + handler.commands = map[string]commandHandler{ + "map": func(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { + queryStr := strings.Join(args[1:], " ") + aggregate, err := maprserver.NewAggregate(queryStr, "") + if err != nil { + t.Fatalf("new aggregate: %v", err) + } + handler.aggregate = aggregate + commandFinished() + }, + "tail": func(_ context.Context, _ lcontext.LContext, _ int, _ []string, commandFinished func()) { + tailCalls++ + if tailCalls > 1 { + sawResetState <- handler.aggregate == nil && handler.turboAggregate == nil + } + commandFinished() + }, + } + + startPayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-a.log"}, + Query: "from STATS select count(*)", + Regex: ".", + }) + updatePayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-b.log"}, + Regex: "WARN", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected session start ack: %q", message) + } + if handler.aggregate == nil { + t.Fatalf("expected query session to install aggregate state") + } + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" { + t.Fatalf("unexpected session update ack: %q", message) + } + + select { + case ok := <-sawResetState: + if !ok { + t.Fatalf("expected aggregate state to be cleared before direct read dispatch") + } + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for direct read dispatch") + } +} + func newSessionTestHandler(userName string) *ServerHandler { handler := &ServerHandler{ baseHandler: baseHandler{ @@ -230,6 +328,37 @@ func newSessionDispatchTestHandler(userName string) (*ServerHandler, *sessionDis return handler, recorder } +func newQuerySessionDispatchTestHandler(userName string) (*ServerHandler, *sessionDispatchRecorder) { + handler := newSessionTestHandler(userName) + recorder := &sessionDispatchRecorder{ + starts: make(chan recordedCommand, 8), + } + handler.commands = map[string]commandHandler{ + "map": func(ctx context.Context, _ lcontext.LContext, _ int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + commandFinished() + }, + "tail": func(ctx context.Context, _ lcontext.LContext, _ int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + commandFinished() + }, + "cat": func(ctx context.Context, _ lcontext.LContext, _ int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + commandFinished() + }, + } + return handler, recorder +} + func immediateNoopCommandHandler(_ context.Context, _ lcontext.LContext, _ int, _ []string, commandFinished func()) { commandFinished() } @@ -305,3 +434,13 @@ func waitForContextDone(t *testing.T, ctx context.Context) { t.Fatal("timed out waiting for context cancellation") } } + +func resetServerLogger(t *testing.T) { + t.Helper() + + originalLogger := dlog.Server + dlog.Server = &dlog.DLog{} + t.Cleanup(func() { + dlog.Server = originalLogger + }) +} -- cgit v1.2.3