summaryrefslogtreecommitdiff
path: root/internal/server/handlers/sessioncommand.go
blob: bc5f83e4f77d1b620afe89d1ff8f9696ba949041 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package handlers

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
	"sync"

	"github.com/mimecast/dtail/internal/lcontext"
	"github.com/mimecast/dtail/internal/omode"
	"github.com/mimecast/dtail/internal/protocol"
	"github.com/mimecast/dtail/internal/session"
)

// Reply lines written to the client for SESSION commands. The two "ok"
// values are sent as complete lines; sessionAckErrorPrefix is followed by a
// short reason text (see handleSessionCommand).
const (
	sessionAckStartOKPrefix  = ".syn session start ok"
	sessionAckUpdateOKPrefix = ".syn session update ok"
	sessionAckErrorPrefix    = ".syn session err "
)

// sessionCommandState holds the session spec most recently accepted via a
// SESSION command. Every field after mu is read and written only while mu
// is held.
type sessionCommandState struct {
	mu         sync.Mutex
	active     bool         // true once a START or UPDATE has been stored
	generation uint64       // 1 after START; incremented or overwritten by UPDATE
	spec       session.Spec // spec from the latest accepted START/UPDATE
}

// handleSessionCommand processes one SESSION command and answers it with
// exactly one line on h.serverMessages:
//   - an error line (sessionAckErrorPrefix + reason) if parsing fails, an
//     UPDATE arrives before any session was started, or the action is unknown;
//   - sessionAckStartOKPrefix after a START has been stored;
//   - sessionAckUpdateOKPrefix after an UPDATE has been stored.
//
// commandFinished is always invoked when the handler returns.
func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) {
	defer commandFinished()

	reply := func(line string) {
		h.send(h.serverMessages, line)
	}

	action, generation, spec, err := parseSessionCommand(args, argc)
	switch {
	case err != nil:
		reply(sessionAckErrorPrefix + err.Error())
	case action == "START":
		h.sessionState.storeStart(spec)
		reply(sessionAckStartOKPrefix)
	case action == "UPDATE" && !h.sessionState.activeSession():
		// An UPDATE is only meaningful once a session exists.
		reply(sessionAckErrorPrefix + "session not started")
	case action == "UPDATE":
		h.sessionState.storeUpdate(spec, generation)
		reply(sessionAckUpdateOKPrefix)
	default:
		reply(sessionAckErrorPrefix + "unknown action")
	}
}

// parseSessionCommand decodes the arguments of a SESSION command:
//
//	SESSION START <base64 JSON spec>
//	SESSION UPDATE [<generation>] <base64 JSON spec>
//
// args[0] is not inspected. args[1] is the action, matched after trimming
// whitespace and upper-casing; the upper-cased form is returned. For UPDATE,
// a decimal generation is read from args[2] only when at least four arguments
// are present, otherwise generation is 0. The payload is base64-decoded
// (standard alphabet), unmarshalled into a session.Spec and checked with
// validateSessionSpec.
//
// On error all other results are zero values. The error text is sent
// verbatim to the client by handleSessionCommand, so messages stay short.
func parseSessionCommand(args []string, argc int) (action string, generation uint64, spec session.Spec, err error) {
	// argc is supplied by the caller; never index beyond what args holds, or
	// a mismatched count would panic here. (args presumably come from the
	// remote client — confirm against the command dispatcher.)
	if argc > len(args) {
		argc = len(args)
	}
	if argc < 3 {
		return "", 0, session.Spec{}, fmt.Errorf("invalid SESSION command")
	}

	action = strings.ToUpper(strings.TrimSpace(args[1]))
	payloadIndex := 2
	switch action {
	case "START":
		// START always carries its payload in args[2]; extra args are ignored.
	case "UPDATE":
		if argc >= 4 {
			generation, err = strconv.ParseUint(args[2], 10, 64)
			if err != nil {
				return "", 0, session.Spec{}, fmt.Errorf("invalid session generation")
			}
			payloadIndex = 3
		}
	default:
		// Reject unknown actions before decoding the payload so the client
		// is told the real reason instead of a payload/spec error.
		return "", 0, session.Spec{}, fmt.Errorf("unknown action")
	}

	payload, err := base64.StdEncoding.DecodeString(args[payloadIndex])
	if err != nil {
		return "", 0, session.Spec{}, fmt.Errorf("invalid session payload")
	}
	if err = json.Unmarshal(payload, &spec); err != nil {
		return "", 0, session.Spec{}, fmt.Errorf("invalid session spec")
	}
	if err = validateSessionSpec(spec); err != nil {
		return "", 0, session.Spec{}, err
	}

	return action, generation, spec, nil
}

// validateSessionSpec reports whether spec describes a session this server
// accepts. Checks run in this order:
//   - the mode must be tail, cat, grep, map or health;
//   - a non-empty query is only allowed in map or tail mode;
//   - map mode requires a non-empty query;
//   - spec.Commands() must succeed (its error detail is not forwarded).
//
// The returned error's text ends up in the reply sent to the client.
func validateSessionSpec(spec session.Spec) error {
	mode := spec.Mode
	isMap := mode == omode.MapClient
	isTail := mode == omode.TailClient
	knownMode := isTail || isMap ||
		mode == omode.CatClient ||
		mode == omode.GrepClient ||
		mode == omode.HealthClient
	if !knownMode {
		return fmt.Errorf("unsupported session mode")
	}

	hasQuery := spec.Query != ""
	switch {
	case hasQuery && !isMap && !isTail:
		return fmt.Errorf("query sessions require map or tail mode")
	case !hasQuery && isMap:
		return fmt.Errorf("missing session query")
	}

	if _, cmdErr := spec.Commands(); cmdErr != nil {
		return fmt.Errorf("invalid session spec")
	}
	return nil
}

// storeStart begins a fresh session under s.mu: the spec is replaced, the
// generation counter restarts at 1 and the state is marked active.
func (s *sessionCommandState) storeStart(spec session.Spec) {
	s.mu.Lock()
	s.spec, s.generation, s.active = spec, 1, true
	s.mu.Unlock()
}

// storeUpdate replaces the stored spec under s.mu and marks the state
// active. A generation of 0 means "next": the stored counter is incremented.
// Any other value is stored exactly as given.
//
// NOTE(review): an explicit generation is not compared with the current one,
// so a lower (stale) generation overwrites a newer spec — confirm intended.
func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) {
	s.mu.Lock()
	defer s.mu.Unlock()

	next := generation
	if next == 0 {
		next = s.generation + 1
	}
	s.active, s.generation, s.spec = true, next, spec
}

// activeSession reads the active flag under s.mu (storeStart and storeUpdate
// set it).
func (s *sessionCommandState) activeSession() bool {
	s.mu.Lock()
	active := s.active
	s.mu.Unlock()
	return active
}

// advertisedCapabilities returns protocol.CapabilityQueryUpdateV1 prefixed
// with protocol.HiddenCapabilitiesPrefix. It does not read session state,
// so the receiver is unnamed.
func (*sessionCommandState) advertisedCapabilities() string {
	caps := protocol.HiddenCapabilitiesPrefix
	caps += protocol.CapabilityQueryUpdateV1
	return caps
}