1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
|
package handlers
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"strconv"
"strings"
"sync"
"github.com/mimecast/dtail/internal/lcontext"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/protocol"
"github.com/mimecast/dtail/internal/session"
)
// Acknowledgement lines that handleSessionCommand sends on h.serverMessages
// in reply to a SESSION command. sessionAckErrorPrefix ends with a space
// because a human-readable reason is appended to it.
const (
	sessionAckStartOKPrefix  = ".syn session start ok"
	sessionAckUpdateOKPrefix = ".syn session update ok"
	sessionAckErrorPrefix    = ".syn session err "
)
// sessionCommandState records the session established via SESSION commands.
// All accessor methods in this file take mu before touching the other fields.
type sessionCommandState struct {
	// mu guards active, generation and spec.
	mu sync.Mutex
	// active is set by storeStart and storeUpdate; nothing in this file
	// clears it.
	active bool
	// generation is reset to 1 by storeStart and set or advanced by
	// storeUpdate.
	generation uint64
	// spec is the most recently stored session spec.
	spec session.Spec
}
// handleSessionCommand executes a SESSION command and replies with exactly
// one acknowledgement line on h.serverMessages.
//
// START stores the spec and resets the generation. UPDATE is only accepted
// once a session is active and stores the spec with the parsed (or
// auto-incremented) generation. Parse failures, an UPDATE before START, and
// unknown actions are reported with sessionAckErrorPrefix. commandFinished
// is always invoked when the handler returns.
func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) {
	defer commandFinished()

	reply := func(line string) {
		h.send(h.serverMessages, line)
	}

	action, generation, spec, err := parseSessionCommand(args, argc)
	if err != nil {
		reply(sessionAckErrorPrefix + err.Error())
		return
	}

	if action == "START" {
		h.sessionState.storeStart(spec)
		reply(sessionAckStartOKPrefix)
		return
	}

	if action != "UPDATE" {
		reply(sessionAckErrorPrefix + "unknown action")
		return
	}

	// An update is only meaningful on top of an already started session.
	if !h.sessionState.activeSession() {
		reply(sessionAckErrorPrefix + "session not started")
		return
	}
	h.sessionState.storeUpdate(spec, generation)
	reply(sessionAckUpdateOKPrefix)
}
// parseSessionCommand decodes the arguments of a SESSION command.
//
// Argument layout (args[0] is not inspected; presumably it is the SESSION
// keyword itself):
//
//	args[1]  action, trimmed and upper-cased (e.g. START, UPDATE)
//	args[2]  base64 (standard encoding) JSON session spec, or, for UPDATE
//	         with four or more arguments, a decimal generation number
//	args[3]  base64 JSON session spec when args[2] is a generation
//
// The decoded spec must pass validateSessionSpec. The action itself is not
// validated here; the caller rejects unknown actions. generation is 0 unless
// an explicit UPDATE generation was supplied. On error every other result is
// its zero value, so callers never see a partially decoded spec.
func parseSessionCommand(args []string, argc int) (action string, generation uint64, spec session.Spec, err error) {
	// argc is supplied by the caller; never index beyond len(args) so that a
	// mismatched count on client-supplied input cannot panic.
	if argc > len(args) {
		argc = len(args)
	}
	if argc < 3 {
		return "", 0, spec, fmt.Errorf("invalid SESSION command")
	}

	action = strings.ToUpper(strings.TrimSpace(args[1]))
	payloadIndex := 2
	if action == "UPDATE" && argc >= 4 {
		generation, err = strconv.ParseUint(strings.TrimSpace(args[2]), 10, 64)
		if err != nil {
			return "", 0, spec, fmt.Errorf("invalid session generation")
		}
		payloadIndex = 3
	}

	payload, err := base64.StdEncoding.DecodeString(args[payloadIndex])
	if err != nil {
		return "", 0, spec, fmt.Errorf("invalid session payload")
	}

	// Decode into a local: the named result spec stays zero until the decoded
	// value has been fully validated.
	var decoded session.Spec
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return "", 0, spec, fmt.Errorf("invalid session spec")
	}
	if err := validateSessionSpec(decoded); err != nil {
		return "", 0, spec, err
	}
	return action, generation, decoded, nil
}
// validateSessionSpec rejects session specs this handler cannot serve.
//
// The checks run in this order, and the first failure is returned:
//  1. the mode must be tail, cat, grep, map or health;
//  2. a non-empty query is only allowed in map or tail mode;
//  3. map mode requires a non-empty query;
//  4. spec.Commands() must succeed.
func validateSessionSpec(spec session.Spec) error {
	mode := spec.Mode

	switch mode {
	case omode.TailClient, omode.CatClient, omode.GrepClient, omode.MapClient, omode.HealthClient:
		// Supported mode; continue with the remaining checks.
	default:
		return fmt.Errorf("unsupported session mode")
	}

	hasQuery := spec.Query != ""
	queryCapable := mode == omode.MapClient || mode == omode.TailClient

	switch {
	case hasQuery && !queryCapable:
		return fmt.Errorf("query sessions require map or tail mode")
	case !hasQuery && mode == omode.MapClient:
		return fmt.Errorf("missing session query")
	}

	// The underlying error is deliberately not surfaced to the client.
	if _, cmdErr := spec.Commands(); cmdErr != nil {
		return fmt.Errorf("invalid session spec")
	}
	return nil
}
// storeStart (re)starts the session: it records spec, resets the generation
// counter to 1 and marks the session active, all under s.mu.
func (s *sessionCommandState) storeStart(spec session.Spec) {
	s.mu.Lock()
	s.spec = spec
	s.generation = 1
	s.active = true
	s.mu.Unlock()
}
// storeUpdate replaces the stored spec and marks the session active.
//
// A non-zero generation is stored verbatim, even if it is not larger than
// the current one. A zero generation means "none supplied", and the stored
// generation is advanced by one instead.
func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) {
	s.mu.Lock()
	defer s.mu.Unlock()

	next := generation
	if next == 0 {
		next = s.generation + 1
	}
	s.active, s.generation, s.spec = true, next, spec
}
// activeSession reports, under s.mu, whether a session has been started.
func (s *sessionCommandState) activeSession() bool {
	s.mu.Lock()
	started := s.active
	s.mu.Unlock()
	return started
}
// advertisedCapabilities returns the capability string this handler
// advertises for sessions. It is built from protocol constants only and does
// not read any session state, so s.mu is not taken.
func (s *sessionCommandState) advertisedCapabilities() string {
	return protocol.HiddenCapabilitiesPrefix + protocol.CapabilityQueryUpdateV1
}
|