summaryrefslogtreecommitdiff
path: root/internal/server/handlers/authkeycommand_test.go
blob: bb9488b0eccf9dc5c1ec5a4977d836ea5190a7b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package handlers

import (
	"context"
	"crypto/ed25519"
	"encoding/base64"
	"testing"
	"time"

	"github.com/mimecast/dtail/internal"
	"github.com/mimecast/dtail/internal/config"
	"github.com/mimecast/dtail/internal/lcontext"
	sshserver "github.com/mimecast/dtail/internal/ssh/server"
	userserver "github.com/mimecast/dtail/internal/user/server"

	gossh "golang.org/x/crypto/ssh"
)

// TestHandleAuthKeyCommandSuccess verifies the happy path of the AUTHKEY
// command: with the feature enabled and a valid base64-encoded SSH public key
// as argument, the handler must invoke the completion callback, reply with
// "AUTHKEY OK" and register the key for the handler's user in the server auth
// key store.
func TestHandleAuthKeyCommandSuccess(t *testing.T) {
	handler := newAuthKeyTestHandler("authkey-success-user", true)
	key := handlerTestPublicKey(t, 31)
	keyArg := base64.StdEncoding.EncodeToString(key.Marshal())

	// The store is obtained through a package-level accessor, so it is
	// presumably shared with other tests in this binary. Register the removal
	// up front so the key is dropped even when an assertion below aborts the
	// test via t.Fatalf; previously a failing assertion skipped the cleanup.
	t.Cleanup(func() {
		store := sshserver.ServerAuthKeyStore()
		// Only remove what is actually there: if the handler never stored
		// the key there is nothing to undo.
		if store.Has(handler.user.Name, key) {
			store.Remove(handler.user.Name, key)
		}
	})

	commandFinished := false
	handler.handleAuthKeyCommand(context.Background(), lcontext.LContext{}, 2,
		[]string{"AUTHKEY", keyArg}, func() {
			commandFinished = true
		})

	if !commandFinished {
		t.Fatalf("Expected commandFinished callback to be called")
	}
	if message := readServerMessage(t, handler.serverMessages); message != "AUTHKEY OK\n" {
		t.Fatalf("Unexpected response: %q", message)
	}
	if !sshserver.ServerAuthKeyStore().Has(handler.user.Name, key) {
		t.Fatalf("Expected key to be stored for user")
	}
}

// TestHandleAuthKeyCommandFeatureDisabled checks that an otherwise valid
// AUTHKEY command is rejected with "AUTHKEY ERR feature disabled" when the
// server configuration has AuthKeyEnabled set to false, and that the key is
// not present in the server auth key store afterwards.
func TestHandleAuthKeyCommandFeatureDisabled(t *testing.T) {
	h := newAuthKeyTestHandler("authkey-disabled-user", false)
	pubKey := handlerTestPublicKey(t, 32)

	// The command carries the wire encoding of the key, base64-encoded.
	encodedKey := base64.StdEncoding.EncodeToString(pubKey.Marshal())
	args := []string{"AUTHKEY", encodedKey}
	noop := func() {}

	h.handleAuthKeyCommand(context.Background(), lcontext.LContext{}, 2, args, noop)

	reply := readServerMessage(t, h.serverMessages)
	if reply != "AUTHKEY ERR feature disabled\n" {
		t.Fatalf("Unexpected response: %q", reply)
	}

	stored := sshserver.ServerAuthKeyStore().Has(h.user.Name, pubKey)
	if stored {
		t.Fatalf("Expected no key to be stored while feature is disabled")
	}
}

// TestHandleAuthKeyCommandInvalidPayload exercises the two rejection paths for
// malformed AUTHKEY arguments on a handler with the feature enabled: an
// argument that is not valid base64, and an argument that decodes fine but
// does not contain an SSH public key.
func TestHandleAuthKeyCommandInvalidPayload(t *testing.T) {
	h := newAuthKeyTestHandler("authkey-invalid-user", true)

	// send issues one AUTHKEY command with the given payload on the shared
	// handler and returns the next message the handler emitted.
	send := func(payload string) string {
		h.handleAuthKeyCommand(context.Background(), lcontext.LContext{}, 2,
			[]string{"AUTHKEY", payload}, func() {})
		return readServerMessage(t, h.serverMessages)
	}

	if reply := send("not-base64"); reply != "AUTHKEY ERR invalid base64\n" {
		t.Fatalf("Unexpected response for invalid base64: %q", reply)
	}

	// Well-formed base64 whose decoded bytes are not an SSH public key.
	garbageKey := base64.StdEncoding.EncodeToString([]byte("not-an-ssh-key"))
	if reply := send(garbageKey); reply != "AUTHKEY ERR invalid public key\n" {
		t.Fatalf("Unexpected response for invalid key bytes: %q", reply)
	}
}

// newAuthKeyTestHandler builds a minimal ServerHandler for the AUTHKEY tests.
// Only the fields set here are populated: a done signal, a server message
// channel, a user with the given name, and a server config whose
// AuthKeyEnabled flag is authKeyEnabled. Everything else keeps its zero value.
func newAuthKeyTestHandler(userName string, authKeyEnabled bool) *ServerHandler {
	handler := &ServerHandler{}

	handler.baseHandler.done = internal.NewDone()
	// Capacity 4 lets replies be queued on the channel before a test reads them.
	handler.baseHandler.serverMessages = make(chan string, 4)
	handler.baseHandler.user = &userserver.User{Name: userName}

	handler.serverCfg = &config.ServerConfig{AuthKeyEnabled: authKeyEnabled}

	return handler
}

// handlerTestPublicKey returns a deterministic ed25519 SSH public key for use
// in tests. The private key is derived from a seed in which every byte equals
// seedByte, so distinct seedByte values give distinct, reproducible keys. The
// test is failed immediately if the key cannot be converted to an SSH public
// key.
func handlerTestPublicKey(t *testing.T, seedByte byte) gossh.PublicKey {
	t.Helper()

	// Fill a fixed-size seed with the repeated byte.
	var seed [ed25519.SeedSize]byte
	for i := 0; i < len(seed); i++ {
		seed[i] = seedByte
	}

	private := ed25519.NewKeyFromSeed(seed[:])

	sshKey, err := gossh.NewPublicKey(private.Public())
	if err != nil {
		t.Fatalf("Unable to build ssh public key: %s", err.Error())
	}
	return sshKey
}

// readServerMessage receives the next message from messages and returns it.
// If nothing arrives within one second the test is failed with a timeout
// error instead of blocking forever.
func readServerMessage(t *testing.T, messages <-chan string) string {
	t.Helper()

	deadline := time.NewTimer(time.Second)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		t.Fatalf("Timed out waiting for server message")
	case received := <-messages:
		return received
	}

	// Only reached syntactically: t.Fatalf does not return.
	return ""
}