diff options
Diffstat (limited to 'internal/clients/handlers/basehandler.go')
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 923b24a..8da4556 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "sort" + "strconv" "strings" "sync" "time" @@ -27,6 +28,15 @@ type baseHandler struct { capabilities map[string]struct{} capabilitiesCh chan struct{} capabilitiesOk sync.Once + + sessionAcks chan SessionAck +} + +// SessionAck is a parsed hidden acknowledgement for SESSION START/UPDATE requests. +type SessionAck struct { + Action string + Generation uint64 + Error string } func (h *baseHandler) String() string { @@ -177,6 +187,10 @@ func (h *baseHandler) handleHiddenMessage(message string) { switch { case strings.HasPrefix(message, protocol.HiddenCapabilitiesPrefix): h.handleCapabilitiesMessage(message) + case strings.HasPrefix(message, protocol.HiddenSessionStartOKPrefix), + strings.HasPrefix(message, protocol.HiddenSessionUpdateOKPrefix), + strings.HasPrefix(message, protocol.HiddenSessionErrorPrefix): + h.handleSessionAckMessage(message) case strings.HasPrefix(message, ".syn close connection"): go h.SendMessage(".ack close connection") h.Shutdown() @@ -237,6 +251,89 @@ func (h *baseHandler) WaitForCapabilities(timeout time.Duration) bool { } } +func (h *baseHandler) WaitForSessionAck(timeout time.Duration) (SessionAck, bool) { + if h.sessionAcks == nil { + return SessionAck{}, false + } + + if timeout <= 0 { + select { + case ack := <-h.sessionAcks: + return ack, true + default: + return SessionAck{}, false + } + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case ack := <-h.sessionAcks: + return ack, true + case <-h.Done(): + return SessionAck{}, false + case <-timer.C: + return SessionAck{}, false + } +} + func (h *baseHandler) Shutdown() { h.done.Shutdown() } + +func (h *baseHandler) handleSessionAckMessage(message string) { + ack, ok := parseSessionAckMessage(message) + if !ok { + dlog.Client.Warn(h.server, "Unable to parse session acknowledgement", message) + return + } + if h.sessionAcks == nil { + return + } + + select { + case h.sessionAcks <- ack: + case <-h.Done(): + default: + dlog.Client.Warn(h.server, "Dropping session acknowledgement because the queue is full", message) + } +} + +func parseSessionAckMessage(message string) (SessionAck, bool) { + payload := strings.TrimSpace(message) + if payload == "" { + return SessionAck{}, false + } + + switch { + case strings.HasPrefix(payload, protocol.HiddenSessionStartOKPrefix): + return parseSessionOKAck(strings.TrimPrefix(payload, protocol.HiddenSessionStartOKPrefix), "start") + case strings.HasPrefix(payload, protocol.HiddenSessionUpdateOKPrefix): + return parseSessionOKAck(strings.TrimPrefix(payload, protocol.HiddenSessionUpdateOKPrefix), "update") + case strings.HasPrefix(payload, protocol.HiddenSessionErrorPrefix): + return SessionAck{ + Action: "error", + Error: strings.TrimSpace(strings.TrimPrefix(payload, protocol.HiddenSessionErrorPrefix)), + }, true + default: + return SessionAck{}, false + } +} + +func parseSessionOKAck(payload string, action string) (SessionAck, bool) { + generationStr := strings.TrimSpace(payload) + if generationStr == "" { + return SessionAck{}, false + } + + generation, err := strconv.ParseUint(generationStr, 10, 64) + if err != nil { + return SessionAck{}, false + } + + return SessionAck{ + Action: action, + Generation: generation, + }, true +} |
