diff options
Diffstat (limited to 'internal/server/handlers/sessioncommand_test.go')
| -rw-r--r-- | internal/server/handlers/sessioncommand_test.go | 177 |
1 files changed, 164 insertions, 13 deletions
diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go index 6af8c5b..f4e7e5b 100644 --- a/internal/server/handlers/sessioncommand_test.go +++ b/internal/server/handlers/sessioncommand_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/base64" "encoding/json" + "strings" + "sync" "testing" "time" @@ -47,11 +49,60 @@ func TestHandleSessionCommandStartStoresSpec(t *testing.T) { if !handler.sessionState.activeSession() { t.Fatalf("expected session state to become active") } - if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix { + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { t.Fatalf("unexpected session start message: %q", message) } } +func TestHandleSessionCommandUpdateCancelsPreviousGenerationImmediately(t *testing.T) { + handler, recorder := newSessionDispatchTestHandler("session-update-cancel-user") + readServerMessage(t, handler.serverMessages) + t.Cleanup(func() { + if handler.sessionState.cancel != nil { + handler.sessionState.cancel() + } + recorder.wg.Wait() + }) + + startPayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-a.log"}, + Regex: "ERROR", + }) + updatePayload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app-b.log"}, + Regex: "WARN", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected session start ack: %q", message) + } + + first := recorder.waitForStart(t) + if !strings.Contains(first.command, "/var/log/app-a.log") { + t.Fatalf("expected first command to target app-a.log, got %q", first.command) + } + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" { + t.Fatalf("unexpected session update ack: %q", message) + } + + waitForContextDone(t, first.ctx) + + second := recorder.waitForStart(t) + if !strings.Contains(second.command, "/var/log/app-b.log") { + t.Fatalf("expected second command to target app-b.log, got %q", second.command) + } + select { + case <-second.ctx.Done(): + t.Fatalf("expected replacement generation context to remain active") + default: + } +} + func TestHandleSessionCommandUpdateRequiresActiveSession(t *testing.T) { handler := newSessionTestHandler("session-update-user") readServerMessage(t, handler.serverMessages) @@ -81,6 +132,42 @@ func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) { } } +func TestHandleSessionCommandRejectsUnsupportedQuerySessions(t *testing.T) { + handler := newSessionTestHandler("session-query-user") + readServerMessage(t, handler.serverMessages) + + payload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Query: "from STATS select count(*)", + Regex: ".", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"query sessions not supported yet" { + t.Fatalf("unexpected query-session error: %q", message) + } +} + +func TestHandleSessionCommandRejectsInvalidSerializedOptions(t *testing.T) { + handler := newSessionTestHandler("session-options-user") + readServerMessage(t, handler.serverMessages) + + payload := mustSessionPayload(t, session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Options: "badoption", + Regex: "ERROR", + }) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session spec" { + t.Fatalf("unexpected invalid options error: %q", message) + } +} + func newSessionTestHandler(userName string) *ServerHandler { handler := &ServerHandler{ baseHandler: baseHandler{ @@ -96,10 +183,69 @@ func newSessionTestHandler(userName string) *ServerHandler { AuthKeyEnabled: true, }, } + handler.commands = map[string]commandHandler{ + "tail": immediateNoopCommandHandler, + "cat": immediateNoopCommandHandler, + "grep": immediateNoopCommandHandler, + "map": immediateNoopCommandHandler, + } + handler.handleCommandCb = func(ctx context.Context, ltx lcontext.LContext, argc int, args []string, commandName string) { + if command, found := handler.commands[commandName]; found { + command(ctx, ltx, argc, args, func() {}) + } + } handler.send(handler.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1) return handler } +type recordedCommand struct { + command string + ctx context.Context +} + +type sessionDispatchRecorder struct { + starts chan recordedCommand + wg sync.WaitGroup +} + +func newSessionDispatchTestHandler(userName string) (*ServerHandler, *sessionDispatchRecorder) { + handler := newSessionTestHandler(userName) + recorder := &sessionDispatchRecorder{ + starts: make(chan recordedCommand, 4), + } + handler.commands = map[string]commandHandler{ + "tail": func(ctx context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { + recorder.starts <- recordedCommand{ + command: strings.Join(args, " "), + ctx: ctx, + } + recorder.wg.Add(1) + go func() { + defer recorder.wg.Done() + <-ctx.Done() + commandFinished() + }() + }, + } + return handler, recorder +} + +func immediateNoopCommandHandler(_ context.Context, _ lcontext.LContext, _ int, _ []string, commandFinished func()) { + commandFinished() +} + +func (r *sessionDispatchRecorder) waitForStart(t *testing.T) recordedCommand { + t.Helper() + + select { + case started := <-r.starts: + return started + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for dispatched session command") + return recordedCommand{} + } +} + func mustSessionPayload(t *testing.T, spec session.Spec) string { t.Helper() @@ -133,24 +279,29 @@ func TestParseSessionCommandWithGeneration(t *testing.T) { } func TestSessionStateStoreUpdateAutoIncrementsGeneration(t *testing.T) { - var state sessionCommandState + handler := newSessionTestHandler("session-generation-user") + readServerMessage(t, handler.serverMessages) - state.storeStart(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/a"}, Regex: "ERROR"}) - state.storeUpdate(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/b"}, Regex: "WARN"}, 0) + startPayload := mustSessionPayload(t, session.Spec{Mode: omode.TailClient, Regex: "ERROR"}) + updatePayload := mustSessionPayload(t, session.Spec{Mode: omode.TailClient, Regex: "WARN"}) - state.mu.Lock() - defer state.mu.Unlock() - if state.generation != 2 { - t.Fatalf("unexpected generation: %d", state.generation) + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" { + t.Fatalf("unexpected session start ack: %q", message) + } + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {}) + if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" { + t.Fatalf("unexpected session update ack: %q", message) } } -func TestSessionCommandReadServerMessageTimeoutProtection(t *testing.T) { - messages := make(chan string) +func waitForContextDone(t *testing.T, ctx context.Context) { + t.Helper() select { - case <-messages: - t.Fatalf("unexpected message") - case <-time.After(5 * time.Millisecond): + case <-ctx.Done(): + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for context cancellation") } } |
