diff options
Diffstat (limited to 'internal/server/handlers/sessioncommand.go')
| -rw-r--r-- | internal/server/handlers/sessioncommand.go | 128 |
1 files changed, 114 insertions, 14 deletions
diff --git a/internal/server/handlers/sessioncommand.go b/internal/server/handlers/sessioncommand.go index bc5f83e..351f27e 100644 --- a/internal/server/handlers/sessioncommand.go +++ b/internal/server/handlers/sessioncommand.go @@ -9,9 +9,9 @@ import ( "strings" "sync" + "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/protocol" "github.com/mimecast/dtail/internal/session" ) @@ -26,6 +26,7 @@ type sessionCommandState struct { active bool generation uint64 spec session.Spec + cancel context.CancelFunc } func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { @@ -39,15 +40,23 @@ func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LCont switch action { case "START": - h.sessionState.storeStart(spec) - h.send(h.serverMessages, sessionAckStartOKPrefix) + generation, err = h.sessionState.start(h, spec) + if err != nil { + h.send(h.serverMessages, sessionAckErrorPrefix+err.Error()) + return + } + h.send(h.serverMessages, fmt.Sprintf("%s %d", sessionAckStartOKPrefix, generation)) case "UPDATE": if !h.sessionState.activeSession() { h.send(h.serverMessages, sessionAckErrorPrefix+"session not started") return } - h.sessionState.storeUpdate(spec, generation) - h.send(h.serverMessages, sessionAckUpdateOKPrefix) + generation, err = h.sessionState.update(h, spec, generation) + if err != nil { + h.send(h.serverMessages, sessionAckErrorPrefix+err.Error()) + return + } + h.send(h.serverMessages, fmt.Sprintf("%s %d", sessionAckUpdateOKPrefix, generation)) default: h.send(h.serverMessages, sessionAckErrorPrefix+"unknown action") } @@ -97,6 +106,10 @@ func validateSessionSpec(spec session.Spec) error { return fmt.Errorf("missing session query") } + if err := validateSessionOptions(spec.Options); err != nil { + return err + } + if _, err := spec.Commands(); err != nil { return fmt.Errorf("invalid session spec") } @@ -104,25 +117,91 @@ func validateSessionSpec(spec session.Spec) error { return nil } -func (s *sessionCommandState) storeStart(spec session.Spec) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *sessionCommandState) start(handler *ServerHandler, spec session.Spec) (uint64, error) { + commands, err := prepareSessionCommands(spec) + if err != nil { + return 0, err + } + s.mu.Lock() + if s.active { + s.mu.Unlock() + return 0, fmt.Errorf("session already started") + } + ctx, cancel := handler.newCommandContext(context.Background()) s.active = true s.generation = 1 s.spec = spec + s.cancel = cancel + s.mu.Unlock() + + if err := handler.dispatchSessionCommands(ctx, commands); err != nil { + cancel() + s.reset() + return 0, err + } + + return 1, nil } -func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *sessionCommandState) update(handler *ServerHandler, spec session.Spec, generation uint64) (uint64, error) { + commands, err := prepareSessionCommands(spec) + if err != nil { + return 0, err + } - s.active = true + s.mu.Lock() + if !s.active { + s.mu.Unlock() + return 0, fmt.Errorf("session not started") + } + oldCancel := s.cancel + ctx, cancel := handler.newCommandContext(context.Background()) if generation == 0 { generation = s.generation + 1 } + s.active = true s.generation = generation s.spec = spec + s.cancel = cancel + s.mu.Unlock() + + if oldCancel != nil { + oldCancel() + } + + if err := handler.dispatchSessionCommands(ctx, commands); err != nil { + cancel() + s.reset() + return 0, err + } + + return generation, nil +} + +func prepareSessionCommands(spec session.Spec) ([]string, error) { + if spec.Query != "" { + return nil, fmt.Errorf("query sessions not supported yet") + } + + commands, err := spec.Commands() + if err != nil { + return nil, fmt.Errorf("invalid session spec") + } + + return commands, nil +} + +func validateSessionOptions(raw string) error { + if strings.TrimSpace(raw) == "" { + return nil + } + + if _, _, err := config.DeserializeOptions(strings.Split(raw, ":")); err != nil { + return fmt.Errorf("invalid session spec") + } + + return nil } func (s *sessionCommandState) activeSession() bool { @@ -131,6 +210,27 @@ func (s *sessionCommandState) activeSession() bool { return s.active } -func (s *sessionCommandState) advertisedCapabilities() string { - return protocol.HiddenCapabilitiesPrefix + protocol.CapabilityQueryUpdateV1 +func (s *sessionCommandState) keepAlive() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.active +} + +func (s *sessionCommandState) reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.active = false + s.generation = 0 + s.spec = session.Spec{} + s.cancel = nil +} + +func (h *ServerHandler) dispatchSessionCommands(ctx context.Context, commands []string) error { + for _, command := range commands { + if err := h.handleRawCommand(ctx, command); err != nil { + return err + } + } + return nil } |
