diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-13 20:52:54 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-13 20:52:54 +0200 |
| commit | 1b34e1f2501b8def0a0fb4eae28bf6c19a8adde2 (patch) | |
| tree | 4898ab4ff4a7dd4ea102726a845e3935c39ee320 | |
| parent | 07d654f76e1002b6ac18a43aab3c64797dcd2a32 (diff) | |
Fix serverless output draining regressions
| -rw-r--r-- | integrationtests/commandutils.go | 5 | ||||
| -rw-r--r-- | internal/clients/connectors/serverless.go | 26 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 4 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler_test.go | 32 | ||||
| -rw-r--r-- | internal/clients/session_spec_test.go | 14 | ||||
| -rw-r--r-- | internal/server/handlers/basehandler.go | 25 | ||||
| -rw-r--r-- | internal/server/handlers/readcommand.go | 25 | ||||
| -rw-r--r-- | internal/server/handlers/turbo_writer.go | 15 | ||||
| -rw-r--r-- | internal/server/handlers/turbo_writer_test.go | 55 | ||||
| -rw-r--r-- | internal/session/spec.go | 16 |
10 files changed, 193 insertions, 24 deletions
diff --git a/integrationtests/commandutils.go b/integrationtests/commandutils.go index 6dfe069..86feee2 100644 --- a/integrationtests/commandutils.go +++ b/integrationtests/commandutils.go @@ -2,6 +2,7 @@ package integrationtests import ( "bufio" + "bytes" "context" "fmt" "io" @@ -36,7 +37,9 @@ func runCommand(ctx context.Context, t *testing.T, stdoutFile, cmdStr string, cmd := exec.CommandContext(ctx, cmdStr, args...) out, err := cmd.CombinedOutput() t.Log("Done running command!", err) - _, _ = fd.Write(out) + if _, copyErr := io.Copy(fd, bytes.NewReader(out)); copyErr != nil { + return exitCodeFromError(err), copyErr + } return exitCodeFromError(err), err } diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go index 72e3fda..4e4d57e 100644 --- a/internal/clients/connectors/serverless.go +++ b/internal/clients/connectors/serverless.go @@ -3,6 +3,7 @@ package connectors import ( "context" "io" + "sync" "time" "github.com/mimecast/dtail/internal/clients/handlers" @@ -76,13 +77,16 @@ func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { dlog.Client.Debug("Starting serverless connector") + done := make(chan struct{}) go func() { + defer close(done) defer cancel() if err := s.handle(ctx, cancel); err != nil { dlog.Client.Warn(err) } }() <-ctx.Done() + <-done } func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) error { @@ -111,9 +115,12 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro // Error tracking errChan := make(chan error, 4) + var ioWg sync.WaitGroup // Read from client handler + ioWg.Add(1) go func() { + defer ioWg.Done() defer close(toServer) buf := make([]byte, 32*1024) for { @@ -137,7 +144,9 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro }() // Write to server handler + ioWg.Add(1) go func() { + defer ioWg.Done() for data := range toServer { if _, err := serverHandler.Write(data); err != nil { errChan <- err @@ -147,7 +156,9 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro }() // Read from server handler + ioWg.Add(1) go func() { + defer ioWg.Done() defer close(fromServer) buf := make([]byte, 64*1024) // Larger buffer for server responses for { @@ -172,7 +183,9 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro // Write to client handler serverDone := make(chan struct{}) + ioWg.Add(1) go func() { + defer ioWg.Done() defer close(serverDone) for data := range fromServer { if _, err := s.handler.Write(data); err != nil { @@ -192,6 +205,18 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro select { case <-s.handler.Done(): dlog.Client.Trace("<-s.handler.Done()") + // The client handler marks itself done as soon as it receives the + // hidden close message. Keep the in-process server alive long enough + // for the remaining output and close ACK to drain instead of canceling + // the whole session immediately. + select { + case <-serverDone: + dlog.Client.Trace("Server transfer done after client close") + case <-ctx.Done(): + dlog.Client.Trace("<-ctx.Done() while waiting for server transfer") + case <-time.After(6 * time.Second): + dlog.Client.Debug("Timed out waiting for server transfer after client close") + } case <-serverDone: dlog.Client.Trace("Server transfer done") case <-ctx.Done(): @@ -201,6 +226,7 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro // Wait for completion <-ctx.Done() + ioWg.Wait() // Check for errors select { diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 8da4556..2979091 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -192,7 +192,9 @@ func (h *baseHandler) handleHiddenMessage(message string) { strings.HasPrefix(message, protocol.HiddenSessionErrorPrefix): h.handleSessionAckMessage(message) case strings.HasPrefix(message, ".syn close connection"): - go h.SendMessage(".ack close connection") + if err := h.SendMessage(".ack close connection"); err != nil { + dlog.Client.Debug(h.server, "Unable to acknowledge close connection", err) + } h.Shutdown() } } diff --git a/internal/clients/handlers/basehandler_test.go b/internal/clients/handlers/basehandler_test.go index 7db2bb8..3e8aaa1 100644 --- a/internal/clients/handlers/basehandler_test.go +++ b/internal/clients/handlers/basehandler_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/protocol" ) @@ -172,3 +173,34 @@ func TestHandleSessionAckMessage(t *testing.T) { t.Fatalf("unexpected session ack: %#v", ack) } } + +func TestHandleCloseConnectionAcknowledgesBeforeShutdown(t *testing.T) { + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + dlog.Client = originalLogger + }) + + handler := baseHandler{ + done: internal.NewDone(), + server: "server-under-test", + commands: make(chan string, 1), + } + + handler.handleHiddenMessage(".syn close connection") + + select { + case command := <-handler.commands: + if command == "" { + t.Fatal("expected close acknowledgement command") + } + case <-time.After(10 * time.Millisecond): + t.Fatal("expected close acknowledgement command to be queued") + } + + select { + case <-handler.Done(): + default: + t.Fatal("expected handler to be shut down after close acknowledgement") + } +} diff --git a/internal/clients/session_spec_test.go b/internal/clients/session_spec_test.go index aa3c45d..8133bc9 100644 --- a/internal/clients/session_spec_test.go +++ b/internal/clients/session_spec_test.go @@ -131,3 +131,17 @@ func TestNewSessionSpecSplitsFiles(t *testing.T) { t.Fatalf("unexpected timeout: %d", spec.Timeout) } } + +func TestNewSessionSpecUsesPipeSentinelForServerlessStdin(t *testing.T) { + t.Parallel() + + spec := NewSessionSpec(config.Args{ + Mode: omode.GrepClient, + Serverless: true, + RegexStr: "ERROR", + }) + + if len(spec.Files) != 1 || spec.Files[0] != "-" { + t.Fatalf("unexpected files for serverless stdin: %#v", spec.Files) + } +} diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go index 66c2cb7..06943b3 100644 --- a/internal/server/handlers/basehandler.go +++ b/internal/server/handlers/basehandler.go @@ -339,25 +339,28 @@ func (h *baseHandler) flush() { return lineCount + serverCount + maprCount + turboCount } - // Increase iterations for turbo mode to handle large file batches - maxIterations := 100 - if h.turbo.enabled() { - maxIterations = 300 // Give more time for turbo mode to drain + maxWait := time.Second + if h.turbo.enabled() || h.turboAggregate != nil || h.aggregate != nil { + maxWait = 3 * time.Second } - // Also increase iterations if we have MapReduce messages - if h.turboAggregate != nil || h.aggregate != nil { - maxIterations = 300 // Give more time for MapReduce results + if h.serverless && maxWait < 5*time.Second { + maxWait = 5 * time.Second } - for i := 0; i < maxIterations; i++ { - if numUnsentMessages() == 0 { + deadline := time.Now().Add(maxWait) + for i := 0; ; i++ { + unsent := numUnsentMessages() + if unsent == 0 { dlog.Server.Debug(h.user, "ALL lines sent", fmt.Sprintf("%p", h)) return } - dlog.Server.Debug(h.user, "Still lines to be sent", "iteration", i, "unsent", numUnsentMessages()) + if time.Now().After(deadline) { + dlog.Server.Warn(h.user, "Some lines remain unsent", unsent) + return + } + dlog.Server.Debug(h.user, "Still lines to be sent", "iteration", i, "unsent", unsent, "deadline", deadline.Sub(time.Now())) time.Sleep(time.Millisecond * 10) } - dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages()) } func (h *baseHandler) shutdown() { diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go index 9677718..d4c9c30 100644 --- a/internal/server/handlers/readcommand.go +++ b/internal/server/handlers/readcommand.go @@ -328,23 +328,24 @@ func (r *readCommand) executeReadLoop(ctx context.Context, ltx lcontext.LContext func (r *readCommand) readViaChannels() readStrategy { return func(ctx context.Context, ltx lcontext.LContext, reader fs.FileReader, re regex.Regex) error { var linesCh chan *line.Line - closeLines := false + var closeLines func() if r.server.HasRegularAggregate() { // For MapReduce operations, create a new channel that goes only to the aggregate. linesCh = make(chan *line.Line, r.server.AggregateLinesChannelBufferSize()) r.server.RegisterAggregateLines(linesCh) - closeLines = true + closeLines = func() { + close(linesCh) + } } else { // For non-MapReduce operations, forward lines through a generation-aware channel. - linesCh = r.newGeneratedLinesChannel(ctx) - closeLines = true + linesCh, closeLines = r.newGeneratedLinesChannel(ctx) } err := reader.Start(ctx, ltx, linesCh, re) - if closeLines { + if closeLines != nil { // Closing the aggregate line channel triggers flush. - close(linesCh) + closeLines() } return err @@ -463,7 +464,9 @@ func (r *readCommand) sendServerMessage(message string) { func (r *readCommand) newGeneratedServerMessagesChannel(ctx context.Context) (chan string, func()) { serverMessages := make(chan string, 16) + done := make(chan struct{}) go func() { + defer close(done) for { select { case message, ok := <-serverMessages: @@ -482,12 +485,15 @@ func (r *readCommand) newGeneratedServerMessagesChannel(ctx context.Context) (ch }() return serverMessages, func() { close(serverMessages) + <-done } } -func (r *readCommand) newGeneratedLinesChannel(ctx context.Context) chan *line.Line { +func (r *readCommand) newGeneratedLinesChannel(ctx context.Context) (chan *line.Line, func()) { linesCh := make(chan *line.Line, r.server.AggregateLinesChannelBufferSize()) + done := make(chan struct{}) go func() { + defer close(done) for { select { case generatedLine, ok := <-linesCh: @@ -512,7 +518,10 @@ func (r *readCommand) newGeneratedLinesChannel(ctx context.Context) chan *line.L } } }() - return linesCh + return linesCh, func() { + close(linesCh) + <-done + } } func (r *readCommand) isInputFromPipe() bool { diff --git a/internal/server/handlers/turbo_writer.go b/internal/server/handlers/turbo_writer.go index f09a2af..3cd347b 100644 --- a/internal/server/handlers/turbo_writer.go +++ b/internal/server/handlers/turbo_writer.go @@ -205,10 +205,21 @@ func (w *DirectTurboWriter) flushBuffer() error { // In serverless mode with colors, data is already processed line by line // so we don't need to do any additional formatting here - _, err := w.writer.Write(data) + for len(data) > 0 { + n, err := w.writer.Write(data) + if err != nil { + w.writeBuf.Reset() + return err + } + if n <= 0 { + w.writeBuf.Reset() + return io.ErrShortWrite + } + data = data[n:] + } w.writeBuf.Reset() - return err + return nil } // Stats returns writing statistics diff --git a/internal/server/handlers/turbo_writer_test.go b/internal/server/handlers/turbo_writer_test.go index 23a07d4..13460a5 100644 --- a/internal/server/handlers/turbo_writer_test.go +++ b/internal/server/handlers/turbo_writer_test.go @@ -2,6 +2,7 @@ package handlers import ( "bytes" + "io" "strings" "sync/atomic" "testing" @@ -221,6 +222,60 @@ func TestDirectTurboWriter_MultipleLines(t *testing.T) { } } +type shortWriter struct { + maxChunk int + buf bytes.Buffer +} + +func (w *shortWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + n := len(p) + if w.maxChunk > 0 && n > w.maxChunk { + n = w.maxChunk + } + w.buf.Write(p[:n]) + return n, nil +} + +func TestDirectTurboWriter_FlushHandlesShortWrites(t *testing.T) { + writer := &shortWriter{maxChunk: 5} + w := NewDirectTurboWriter(writer, "testhost", true, true) + + if err := w.WriteLineData([]byte("abcdefghij"), 1, "source.log"); err != nil { + t.Fatalf("WriteLineData failed: %v", err) + } + + if err := w.Flush(); err != nil { + t.Fatalf("Flush failed: %v", err) + } + + if got, want := writer.buf.String(), "abcdefghij\n"; got != want { + t.Fatalf("expected full output %q, got %q", want, got) + } +} + +type zeroWriter struct{} + +func (zeroWriter) Write(p []byte) (int, error) { + return 0, nil +} + +func TestDirectTurboWriter_FlushFailsOnZeroProgress(t *testing.T) { + w := NewDirectTurboWriter(zeroWriter{}, "testhost", true, true) + + if err := w.WriteLineData([]byte("data"), 1, "source.log"); err != nil { + t.Fatalf("WriteLineData failed: %v", err) + } + + if err := w.Flush(); err == nil { + t.Fatal("expected Flush to fail on zero-progress writes") + } else if err != io.ErrShortWrite { + t.Fatalf("expected io.ErrShortWrite, got %v", err) + } +} + // TestTurboChannelWriter_WriteLineData tests channel writer line data func TestTurboChannelWriter_WriteLineData(t *testing.T) { ch := make(chan []byte, 10) diff --git a/internal/session/spec.go b/internal/session/spec.go index 0a6ad4e..6df11fd 100644 --- a/internal/session/spec.go +++ b/internal/session/spec.go @@ -24,9 +24,14 @@ type Spec struct { // NewSpec returns a session specification from client args. func NewSpec(args config.Args) Spec { + files := splitFiles(args.What) + if args.Serverless && len(files) == 0 && supportsServerlessPipe(args.Mode) { + files = []string{"-"} + } + return Spec{ Mode: args.Mode, - Files: splitFiles(args.What), + Files: files, Options: args.SerializeOptions(), Query: strings.TrimSpace(args.QueryStr), Regex: args.RegexStr, @@ -149,6 +154,15 @@ func splitFiles(what string) []string { return files } +func supportsServerlessPipe(mode omode.Mode) bool { + switch mode { + case omode.TailClient, omode.CatClient, omode.GrepClient, omode.MapClient: + return true + default: + return false + } +} + func (s Spec) encodedPayload() (string, error) { payload, err := json.Marshal(s) if err != nil { |
