summaryrefslogtreecommitdiff
path: root/internal/ssh/server/publickeycallback.go
blob: c4624f49f6c6b48f1f3a8f436f7bd315d734f7f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package server

import (
	"fmt"
	"os"
	goUser "os/user"

	"github.com/mimecast/dtail/internal/config"
	"github.com/mimecast/dtail/internal/io/dlog"
	user "github.com/mimecast/dtail/internal/user/server"

	gossh "golang.org/x/crypto/ssh"
)

// PublicKeyCallback is for the server to check whether a public SSH key is
// authorized or not.
//
// When config.Server.AuthKeyEnabled is set, the in-memory auth key store is
// consulted first. Otherwise, or when the store does not know the key, the
// authorized keys file resolved by authorizedKeysFile is read and the offered
// key is compared against its entries.
//
// It returns the permissions for the offered key on success, or nil
// permissions and an error when the user cannot be created, no authorized
// keys file can be located or read, or the key is not authorized.
func PublicKeyCallback(c gossh.ConnMetadata,
	offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {

	// Deliberately not named "user", so the imported user package stays
	// reachable within this function.
	connUser, err := user.New(c.User(), c.RemoteAddr().String())
	if err != nil {
		return nil, err
	}
	dlog.Server.Info(connUser, "Incoming authorization")

	if config.Server != nil && config.Server.AuthKeyEnabled {
		if permissions := authKeyStorePermissions(connUser.Name, offeredPubKey); permissions != nil {
			dlog.Server.Info(connUser, "Authorized by in-memory auth key store")
			return permissions, nil
		}
	}

	keysPath, err := authorizedKeysFile(connUser)
	if err != nil {
		return nil, err
	}

	dlog.Server.Info(connUser, "Reading", keysPath)
	keysBytes, err := os.ReadFile(keysPath)
	if err != nil {
		// %w renders the same text as err.Error() did, but keeps the cause
		// available to errors.Is / errors.As.
		return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%w",
			keysPath, connUser, err)
	}

	return verifyAuthorizedKeys(connUser, keysBytes, offeredPubKey)
}

// verifyAuthorizedKeys checks whether offeredPubKey is one of the public keys
// contained in authorizedKeysBytes (content in authorized_keys format).
//
// It returns the permissions for the offered key when it is listed. An error
// is returned when no key at all can be parsed from authorizedKeysBytes or
// when the offered key is not among the parsed keys.
func verifyAuthorizedKeys(user *user.User, authorizedKeysBytes []byte,
	offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {

	authorizedKeysMap := map[string]bool{}
	for len(authorizedKeysBytes) > 0 {
		authorizedPubKey, _, _, restBytes, err := gossh.ParseAuthorizedKey(authorizedKeysBytes)
		if err != nil {
			if len(authorizedKeysMap) == 0 {
				// Not a single key could be parsed: report the parse failure.
				return nil, fmt.Errorf("unable to parse authorized keys bytes|%s|%w",
					user, err)
			}
			// ParseAuthorizedKey skips blank, comment and malformed lines and
			// only fails once the remaining input contains no further key.
			// Trailing blank lines or comments after the last key must not
			// invalidate the keys that were already parsed.
			dlog.Server.Debug(user, "No further authorized keys found", err.Error())
			break
		}
		authorizedKeysMap[string(authorizedPubKey.Marshal())] = true
		authorizedKeysBytes = restBytes
		dlog.Server.Debug(user, "Authorized public key fingerprint",
			gossh.FingerprintSHA256(authorizedPubKey))
	}

	dlog.Server.Debug(user, "Offered public key fingerprint", gossh.FingerprintSHA256(offeredPubKey))
	if authorizedKeysMap[string(offeredPubKey.Marshal())] {
		return permissionsFromPublicKey(offeredPubKey), nil
	}

	return nil, fmt.Errorf("%s|public key of user not authorized", user)
}

// authKeyStorePermissions returns the permissions for offeredPubKey when the
// in-memory auth key store has that key registered for userName, and nil
// otherwise.
func authKeyStorePermissions(userName string, offeredPubKey gossh.PublicKey) *gossh.Permissions {
	if authKeyStore.Has(userName, offeredPubKey) {
		return permissionsFromPublicKey(offeredPubKey)
	}

	return nil
}

// permissionsFromPublicKey builds the permissions handed back for an accepted
// key. The SHA256 fingerprint of the key is stored in the "pubkey-fp"
// extension.
func permissionsFromPublicKey(offeredPubKey gossh.PublicKey) *gossh.Permissions {
	fingerprint := gossh.FingerprintSHA256(offeredPubKey)
	extensions := map[string]string{"pubkey-fp": fingerprint}

	return &gossh.Permissions{Extensions: extensions}
}

// authorizedKeysFile resolves the path of the authorized keys file to use for
// the given user.
//
// Candidates, in order:
//  1. "./id_rsa.pub" when config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE")
//     reports true.
//  2. "<cwd>/<config.Common.CacheDir>/<user.Name>.authorized_keys", if it can
//     be stat'ed.
//  3. "<home>/.ssh/authorized_keys" of the OS user named user.Name, if it can
//     be stat'ed.
//
// An error is returned when user.Name contains a path separator, when the
// working directory or the OS user cannot be determined, or when none of the
// candidate files is found.
func authorizedKeysFile(user *user.User) (string, error) {
	if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
		// In this case, we expect a pub key in the current directory.
		return "./id_rsa.pub", nil
	}

	// The name becomes part of a file path below and is handed in by the
	// connecting client (see PublicKeyCallback). Refuse path separators so
	// that a name such as "../../tmp/x" cannot point outside of the cache
	// directory.
	for i := 0; i < len(user.Name); i++ {
		if os.IsPathSeparator(user.Name[i]) {
			return "", fmt.Errorf("user name must not contain path separators|%q", user.Name)
		}
	}

	cwd, err := os.Getwd()
	if err != nil {
		return "", err
	}

	// Check for cached version in the dserver directory.
	keysPath := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd,
		config.Common.CacheDir, user.Name)
	if _, err = os.Stat(keysPath); err == nil {
		return keysPath, nil
	}

	// As the last option, check the regular SSH path.
	osUser, err := goUser.Lookup(user.Name)
	if err != nil {
		return "", err
	}
	keysPath = fmt.Sprintf("%s/.ssh/authorized_keys", osUser.HomeDir)
	if _, err = os.Stat(keysPath); err == nil {
		return keysPath, nil
	}

	return "", fmt.Errorf("unable to find a any authorized keys file")
}