summaryrefslogtreecommitdiff
path: root/internal/server/handlers/sessioncommand_test.go
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-13 07:59:45 +0200
committerPaul Buetow <paul@buetow.org>2026-03-13 07:59:45 +0200
commit6dbc03d5c7b6068665e2d95bc66c4f3700323dc8 (patch)
treedc14218fd578caca6b0a7ada3ceb1a0f060a9a9e /internal/server/handlers/sessioncommand_test.go
parentc88dddee1953c938b47830ec13696f23770eb22d (diff)
task 398: implement session preemption
Diffstat (limited to 'internal/server/handlers/sessioncommand_test.go')
-rw-r--r--internal/server/handlers/sessioncommand_test.go177
1 files changed, 164 insertions, 13 deletions
diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go
index 6af8c5b..f4e7e5b 100644
--- a/internal/server/handlers/sessioncommand_test.go
+++ b/internal/server/handlers/sessioncommand_test.go
@@ -4,6 +4,8 @@ import (
"context"
"encoding/base64"
"encoding/json"
+ "strings"
+ "sync"
"testing"
"time"
@@ -47,11 +49,60 @@ func TestHandleSessionCommandStartStoresSpec(t *testing.T) {
if !handler.sessionState.activeSession() {
t.Fatalf("expected session state to become active")
}
- if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix {
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" {
t.Fatalf("unexpected session start message: %q", message)
}
}
+func TestHandleSessionCommandUpdateCancelsPreviousGenerationImmediately(t *testing.T) {
+ handler, recorder := newSessionDispatchTestHandler("session-update-cancel-user")
+ readServerMessage(t, handler.serverMessages)
+ t.Cleanup(func() {
+ if handler.sessionState.cancel != nil {
+ handler.sessionState.cancel()
+ }
+ recorder.wg.Wait()
+ })
+
+ startPayload := mustSessionPayload(t, session.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app-a.log"},
+ Regex: "ERROR",
+ })
+ updatePayload := mustSessionPayload(t, session.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app-b.log"},
+ Regex: "WARN",
+ })
+
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {})
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" {
+ t.Fatalf("unexpected session start ack: %q", message)
+ }
+
+ first := recorder.waitForStart(t)
+ if !strings.Contains(first.command, "/var/log/app-a.log") {
+ t.Fatalf("expected first command to target app-a.log, got %q", first.command)
+ }
+
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {})
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" {
+ t.Fatalf("unexpected session update ack: %q", message)
+ }
+
+ waitForContextDone(t, first.ctx)
+
+ second := recorder.waitForStart(t)
+ if !strings.Contains(second.command, "/var/log/app-b.log") {
+ t.Fatalf("expected second command to target app-b.log, got %q", second.command)
+ }
+ select {
+ case <-second.ctx.Done():
+ t.Fatalf("expected replacement generation context to remain active")
+ default:
+ }
+}
+
func TestHandleSessionCommandUpdateRequiresActiveSession(t *testing.T) {
handler := newSessionTestHandler("session-update-user")
readServerMessage(t, handler.serverMessages)
@@ -81,6 +132,42 @@ func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) {
}
}
+func TestHandleSessionCommandRejectsUnsupportedQuerySessions(t *testing.T) {
+ handler := newSessionTestHandler("session-query-user")
+ readServerMessage(t, handler.serverMessages)
+
+ payload := mustSessionPayload(t, session.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Query: "from STATS select count(*)",
+ Regex: ".",
+ })
+
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {})
+
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"query sessions not supported yet" {
+ t.Fatalf("unexpected query-session error: %q", message)
+ }
+}
+
+func TestHandleSessionCommandRejectsInvalidSerializedOptions(t *testing.T) {
+ handler := newSessionTestHandler("session-options-user")
+ readServerMessage(t, handler.serverMessages)
+
+ payload := mustSessionPayload(t, session.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Options: "badoption",
+ Regex: "ERROR",
+ })
+
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {})
+
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session spec" {
+ t.Fatalf("unexpected invalid options error: %q", message)
+ }
+}
+
func newSessionTestHandler(userName string) *ServerHandler {
handler := &ServerHandler{
baseHandler: baseHandler{
@@ -96,10 +183,69 @@ func newSessionTestHandler(userName string) *ServerHandler {
AuthKeyEnabled: true,
},
}
+ handler.commands = map[string]commandHandler{
+ "tail": immediateNoopCommandHandler,
+ "cat": immediateNoopCommandHandler,
+ "grep": immediateNoopCommandHandler,
+ "map": immediateNoopCommandHandler,
+ }
+ handler.handleCommandCb = func(ctx context.Context, ltx lcontext.LContext, argc int, args []string, commandName string) {
+ if command, found := handler.commands[commandName]; found {
+ command(ctx, ltx, argc, args, func() {})
+ }
+ }
handler.send(handler.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1)
return handler
}
+type recordedCommand struct {
+ command string
+ ctx context.Context
+}
+
+type sessionDispatchRecorder struct {
+ starts chan recordedCommand
+ wg sync.WaitGroup
+}
+
+func newSessionDispatchTestHandler(userName string) (*ServerHandler, *sessionDispatchRecorder) {
+ handler := newSessionTestHandler(userName)
+ recorder := &sessionDispatchRecorder{
+ starts: make(chan recordedCommand, 4),
+ }
+ handler.commands = map[string]commandHandler{
+ "tail": func(ctx context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) {
+ recorder.starts <- recordedCommand{
+ command: strings.Join(args, " "),
+ ctx: ctx,
+ }
+ recorder.wg.Add(1)
+ go func() {
+ defer recorder.wg.Done()
+ <-ctx.Done()
+ commandFinished()
+ }()
+ },
+ }
+ return handler, recorder
+}
+
+func immediateNoopCommandHandler(_ context.Context, _ lcontext.LContext, _ int, _ []string, commandFinished func()) {
+ commandFinished()
+}
+
+func (r *sessionDispatchRecorder) waitForStart(t *testing.T) recordedCommand {
+ t.Helper()
+
+ select {
+ case started := <-r.starts:
+ return started
+ case <-time.After(250 * time.Millisecond):
+ t.Fatal("timed out waiting for dispatched session command")
+ return recordedCommand{}
+ }
+}
+
func mustSessionPayload(t *testing.T, spec session.Spec) string {
t.Helper()
@@ -133,24 +279,29 @@ func TestParseSessionCommandWithGeneration(t *testing.T) {
}
func TestSessionStateStoreUpdateAutoIncrementsGeneration(t *testing.T) {
- var state sessionCommandState
+ handler := newSessionTestHandler("session-generation-user")
+ readServerMessage(t, handler.serverMessages)
- state.storeStart(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/a"}, Regex: "ERROR"})
- state.storeUpdate(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/b"}, Regex: "WARN"}, 0)
+ startPayload := mustSessionPayload(t, session.Spec{Mode: omode.TailClient, Regex: "ERROR"})
+ updatePayload := mustSessionPayload(t, session.Spec{Mode: omode.TailClient, Regex: "WARN"})
- state.mu.Lock()
- defer state.mu.Unlock()
- if state.generation != 2 {
- t.Fatalf("unexpected generation: %d", state.generation)
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", startPayload}, func() {})
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix+" 1" {
+ t.Fatalf("unexpected session start ack: %q", message)
+ }
+
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", updatePayload}, func() {})
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckUpdateOKPrefix+" 2" {
+ t.Fatalf("unexpected session update ack: %q", message)
}
}
-func TestSessionCommandReadServerMessageTimeoutProtection(t *testing.T) {
- messages := make(chan string)
+func waitForContextDone(t *testing.T, ctx context.Context) {
+ t.Helper()
select {
- case <-messages:
- t.Fatalf("unexpected message")
- case <-time.After(5 * time.Millisecond):
+ case <-ctx.Done():
+ case <-time.After(250 * time.Millisecond):
+ t.Fatal("timed out waiting for context cancellation")
}
}