diff options
Diffstat (limited to 'internal/server/handlers/sessioncommand_test.go')
| -rw-r--r-- | internal/server/handlers/sessioncommand_test.go | 147 |
1 files changed, 143 insertions, 4 deletions
diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go index f4e7e5b..d7df000 100644 --- a/internal/server/handlers/sessioncommand_test.go +++ b/internal/server/handlers/sessioncommand_test.go @@ -11,8 +11,10 @@ import ( "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/lcontext" + maprserver "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/protocol" "github.com/mimecast/dtail/internal/session" @@ -132,8 +134,8 @@ func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) { } } -func TestHandleSessionCommandRejectsUnsupportedQuerySessions(t *testing.T) { - handler := newSessionTestHandler("session-query-user") +func TestHandleSessionCommandStartDispatchesQueryWorkload(t *testing.T) { + handler, recorder := newQuerySessionDispatchTestHandler("session-query-user") readServerMessage(t, handler.serverMessages) payload := mustSessionPayload(t, session.Spec{ @@ -145,8 +147,24 @@ func TestHandleSessionCommandRejectsUnsupportedQuerySessions(t *testing.T) { handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) - if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"query sessions not supported yet" { - t.Fatalf("unexpected query-session error: %q", message) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected query-session ack: %q", message) + } + + first := recorder.waitForStart(t) + if !strings.HasPrefix(first.command, "map:") { + t.Fatalf("expected map command first, got %q", first.command) + } + if !strings.Contains(first.command, "from STATS select count(*)") { + t.Fatalf("expected map command to contain query, got %q", first.command) + } + + second := recorder.waitForStart(t) + if !strings.HasPrefix(second.command, "tail:") { + t.Fatalf("expected tail command second, got %q", second.command) + } + if !strings.Contains(second.command, "/var/log/app.log") { + t.Fatalf("expected tail command to contain file, got %q", second.command) } } @@ -168,6 +186,86 @@ func TestHandleSessionCommandRejectsInvalidSerializedOptions(t *testing.T) { } } +func TestHandleSessionCommandRejectsInvalidQuerySession(t *testing.T) { + handler := newSessionTestHandler("session-invalid-query-user") + readServerMessage(t, handler.serverMessages) + + payload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Query: "select from", + Regex: ".", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session spec" { + t.Fatalf("unexpected invalid query-session error: %q", message) + } +} + +func TestHandleSessionCommandUpdateClearsAggregateStateBeforeDirectRead(t *testing.T) { + resetServerLogger(t) + + handler := newSessionTestHandler("session-query-reset-user") + readServerMessage(t, handler.serverMessages) + + sawResetState := make(chan bool, 1) + tailCalls := 0 + handler.commands = map[string]commandHandler{ + "map": func(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { + queryStr := strings.Join(args[1:], " ") + aggregate, err := maprserver.NewAggregate(queryStr, "") + if err != nil { + t.Fatalf("new aggregate: %v", err) + } + handler.aggregate = aggregate + commandFinished() + }, + "tail": func(_ context.Context, _ lcontext.LContext, _ int, _ []string, commandFinished func()) { + tailCalls++ + if tailCalls > 1 { + sawResetState <- handler.aggregate == nil && handler.turboAggregate == nil + } + commandFinished() + }, + } + + startPayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-a.log"}, + Query: "from STATS select count(*)", + Regex: ".", + }) + updatePayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-b.log"}, + Regex: "WARN", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected session start ack: %q", message) + } + if handler.aggregate == nil { + t.Fatalf("expected query session to install aggregate state") + } + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" { + t.Fatalf("unexpected session update ack: %q", message) + } + + select { + case ok := <-sawResetState: + if !ok { + t.Fatalf("expected aggregate state to be cleared before direct read dispatch") + } + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for direct read dispatch") + } +} + func newSessionTestHandler(userName string) *ServerHandler { handler := &ServerHandler{ baseHandler: baseHandler{ @@ -230,6 +328,37 @@ func newSessionDispatchTestHandler(userName string) (*ServerHandler, *sessionDis return handler, recorder } +func newQuerySessionDispatchTestHandler(userName string) (*ServerHandler, *sessionDispatchRecorder) { + handler := newSessionTestHandler(userName) + recorder := &sessionDispatchRecorder{ + starts: make(chan recordedCommand, 8), + } + handler.commands = map[string]commandHandler{ + "map": func(ctx context.Context, _ lcontext.LContext, _ int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + commandFinished() + }, + "tail": func(ctx context.Context, _ lcontext.LContext, _ int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + commandFinished() + }, + "cat": func(ctx context.Context, _ lcontext.LContext, _ int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + commandFinished() + }, + } + return handler, recorder +} + func immediateNoopCommandHandler(_ context.Context, _ lcontext.LContext, _ int, _ []string, commandFinished func()) { commandFinished() } @@ -305,3 +434,13 @@ func waitForContextDone(t *testing.T, ctx context.Context) { t.Fatal("timed out waiting for context cancellation") } } + +func resetServerLogger(t *testing.T) { + t.Helper() + + originalLogger := dlog.Server + dlog.Server = &dlog.DLog{} + t.Cleanup(func() { + dlog.Server = originalLogger + }) +} |
