summaryrefslogtreecommitdiff
path: root/internal/server/handlers/sessioncommand.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/handlers/sessioncommand.go')
-rw-r--r--internal/server/handlers/sessioncommand.go128
1 files changed, 114 insertions, 14 deletions
diff --git a/internal/server/handlers/sessioncommand.go b/internal/server/handlers/sessioncommand.go
index bc5f83e..351f27e 100644
--- a/internal/server/handlers/sessioncommand.go
+++ b/internal/server/handlers/sessioncommand.go
@@ -9,9 +9,9 @@ import (
"strings"
"sync"
+ "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/lcontext"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/protocol"
"github.com/mimecast/dtail/internal/session"
)
@@ -26,6 +26,7 @@ type sessionCommandState struct {
active bool
generation uint64
spec session.Spec
+ cancel context.CancelFunc
}
func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) {
@@ -39,15 +40,23 @@ func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LCont
switch action {
case "START":
- h.sessionState.storeStart(spec)
- h.send(h.serverMessages, sessionAckStartOKPrefix)
+ generation, err = h.sessionState.start(h, spec)
+ if err != nil {
+ h.send(h.serverMessages, sessionAckErrorPrefix+err.Error())
+ return
+ }
+ h.send(h.serverMessages, fmt.Sprintf("%s %d", sessionAckStartOKPrefix, generation))
case "UPDATE":
if !h.sessionState.activeSession() {
h.send(h.serverMessages, sessionAckErrorPrefix+"session not started")
return
}
- h.sessionState.storeUpdate(spec, generation)
- h.send(h.serverMessages, sessionAckUpdateOKPrefix)
+ generation, err = h.sessionState.update(h, spec, generation)
+ if err != nil {
+ h.send(h.serverMessages, sessionAckErrorPrefix+err.Error())
+ return
+ }
+ h.send(h.serverMessages, fmt.Sprintf("%s %d", sessionAckUpdateOKPrefix, generation))
default:
h.send(h.serverMessages, sessionAckErrorPrefix+"unknown action")
}
@@ -97,6 +106,10 @@ func validateSessionSpec(spec session.Spec) error {
return fmt.Errorf("missing session query")
}
+ if err := validateSessionOptions(spec.Options); err != nil {
+ return err
+ }
+
if _, err := spec.Commands(); err != nil {
return fmt.Errorf("invalid session spec")
}
@@ -104,25 +117,91 @@ func validateSessionSpec(spec session.Spec) error {
return nil
}
-func (s *sessionCommandState) storeStart(spec session.Spec) {
- s.mu.Lock()
- defer s.mu.Unlock()
+func (s *sessionCommandState) start(handler *ServerHandler, spec session.Spec) (uint64, error) {
+ commands, err := prepareSessionCommands(spec)
+ if err != nil {
+ return 0, err
+ }
+ s.mu.Lock()
+ if s.active {
+ s.mu.Unlock()
+ return 0, fmt.Errorf("session already started")
+ }
+ ctx, cancel := handler.newCommandContext(context.Background())
s.active = true
s.generation = 1
s.spec = spec
+ s.cancel = cancel
+ s.mu.Unlock()
+
+ if err := handler.dispatchSessionCommands(ctx, commands); err != nil {
+ cancel()
+ s.reset()
+ return 0, err
+ }
+
+ return 1, nil
}
-func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) {
- s.mu.Lock()
- defer s.mu.Unlock()
+func (s *sessionCommandState) update(handler *ServerHandler, spec session.Spec, generation uint64) (uint64, error) {
+ commands, err := prepareSessionCommands(spec)
+ if err != nil {
+ return 0, err
+ }
- s.active = true
+ s.mu.Lock()
+ if !s.active {
+ s.mu.Unlock()
+ return 0, fmt.Errorf("session not started")
+ }
+ oldCancel := s.cancel
+ ctx, cancel := handler.newCommandContext(context.Background())
if generation == 0 {
generation = s.generation + 1
}
+ s.active = true
s.generation = generation
s.spec = spec
+ s.cancel = cancel
+ s.mu.Unlock()
+
+ if oldCancel != nil {
+ oldCancel()
+ }
+
+ if err := handler.dispatchSessionCommands(ctx, commands); err != nil {
+ cancel()
+ s.reset()
+ return 0, err
+ }
+
+ return generation, nil
+}
+
+func prepareSessionCommands(spec session.Spec) ([]string, error) {
+ if spec.Query != "" {
+ return nil, fmt.Errorf("query sessions not supported yet")
+ }
+
+ commands, err := spec.Commands()
+ if err != nil {
+ return nil, fmt.Errorf("invalid session spec")
+ }
+
+ return commands, nil
+}
+
+func validateSessionOptions(raw string) error {
+ if strings.TrimSpace(raw) == "" {
+ return nil
+ }
+
+ if _, _, err := config.DeserializeOptions(strings.Split(raw, ":")); err != nil {
+ return fmt.Errorf("invalid session spec")
+ }
+
+ return nil
}
func (s *sessionCommandState) activeSession() bool {
@@ -131,6 +210,27 @@ func (s *sessionCommandState) activeSession() bool {
return s.active
}
-func (s *sessionCommandState) advertisedCapabilities() string {
- return protocol.HiddenCapabilitiesPrefix + protocol.CapabilityQueryUpdateV1
+func (s *sessionCommandState) keepAlive() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.active
+}
+
+func (s *sessionCommandState) reset() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.active = false
+ s.generation = 0
+ s.spec = session.Spec{}
+ s.cancel = nil
+}
+
+func (h *ServerHandler) dispatchSessionCommands(ctx context.Context, commands []string) error {
+ for _, command := range commands {
+ if err := h.handleRawCommand(ctx, command); err != nil {
+ return err
+ }
+ }
+ return nil
}