summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/ssh/client/knownhostscallback.go40
-rw-r--r--internal/ssh/client/knownhostscallback_test.go157
2 files changed, 191 insertions, 6 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
index ac2ec92..174f6aa 100644
--- a/internal/ssh/client/knownhostscallback.go
+++ b/internal/ssh/client/knownhostscallback.go
@@ -11,6 +11,7 @@ import (
"time"
"github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/fs"
"github.com/mimecast/dtail/internal/io/prompt"
"golang.org/x/crypto/ssh"
@@ -38,6 +39,7 @@ type unknownHost struct {
// unknown hosts in a single batch to the known_hosts file.
type KnownHostsCallback struct {
knownHostsPath string
+ knownHostsFile fs.RootedPath
unknownCh chan unknownHost
throttleCh chan struct{}
trustAllHostsCh chan struct{}
@@ -51,11 +53,16 @@ var _ HostKeyCallback = (*KnownHostsCallback)(nil)
func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
throttleCh chan struct{}) (HostKeyCallback, error) {
- os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+ knownHostsFile, err := fs.NewRootedPath(knownHostsPath)
+ if err != nil {
+ return nil, err
+ }
+ ensureKnownHostsFile(knownHostsFile)
untrustedHosts := make(map[string]bool)
c := KnownHostsCallback{
knownHostsPath: knownHostsPath,
+ knownHostsFile: knownHostsFile,
unknownCh: make(chan unknownHost),
trustAllHostsCh: make(chan struct{}),
throttleCh: throttleCh,
@@ -68,6 +75,20 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
return &c, nil
}
+func ensureKnownHostsFile(knownHostsFile fs.RootedPath) {
+ root, err := knownHostsFile.OpenRoot()
+ if err != nil {
+ return
+ }
+ defer root.Close()
+
+ fd, err := root.OpenFile(knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o666)
+ if err != nil {
+ return
+ }
+ fd.Close()
+}
+
// Wrap the host key callback.
func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback {
return func(server string, remote net.Addr, key ssh.PublicKey) error {
@@ -229,18 +250,25 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
}
func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error {
+ root, err := c.knownHostsFile.OpenRoot()
+ if err != nil {
+ return err
+ }
+ defer root.Close()
+
+ tmpKnownHostsName := fmt.Sprintf("%s.tmp", c.knownHostsFile.Name())
tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
cleanupTmp := func() {
- if err := os.Remove(tmpKnownHostsPath); err != nil && !os.IsNotExist(err) {
+ if err := root.Remove(tmpKnownHostsName); err != nil && !os.IsNotExist(err) {
dlog.Client.Debug("Unable to remove temporary known hosts file", tmpKnownHostsPath, err)
}
}
- newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+ newFd, err := root.OpenFile(tmpKnownHostsName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600)
if err != nil {
return fmt.Errorf("open temp known hosts file %s: %w", tmpKnownHostsPath, err)
}
- if err := newFd.Chmod(0600); err != nil {
+ if err := newFd.Chmod(0o600); err != nil {
newFd.Close()
cleanupTmp()
return fmt.Errorf("chmod temp known hosts file %s: %w", tmpKnownHostsPath, err)
@@ -268,7 +296,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error {
}
// Read old known hosts file, to see which are old and new entries
- oldFd, err := os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0600)
+ oldFd, err := root.OpenFile(c.knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o600)
if err != nil {
newFd.Close()
cleanupTmp()
@@ -308,7 +336,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error {
}
// Now, replace old known hosts file
- if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil {
+ if err := root.Rename(tmpKnownHostsName, c.knownHostsFile.Name()); err != nil {
cleanupTmp()
return fmt.Errorf("replace known_hosts file %s: %w", c.knownHostsPath, err)
}
diff --git a/internal/ssh/client/knownhostscallback_test.go b/internal/ssh/client/knownhostscallback_test.go
new file mode 100644
index 0000000..596aea8
--- /dev/null
+++ b/internal/ssh/client/knownhostscallback_test.go
@@ -0,0 +1,157 @@
+package client
+
+import (
+ "net"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "golang.org/x/crypto/ssh/knownhosts"
+)
+
+func TestTrustHostsAppendsDistinctExistingEntries(t *testing.T) {
+ knownHostsPath := filepath.Join(t.TempDir(), "known_hosts")
+ existingLine := knownhosts.Line([]string{"old.example:2222"}, &mockPublicKey{id: "old"})
+ if err := os.WriteFile(knownHostsPath, []byte(existingLine+"\n"), 0o600); err != nil {
+ t.Fatalf("WriteFile failed: %v", err)
+ }
+
+ callback := testKnownHostsCallback(t, knownHostsPath)
+ unknown := testUnknownHost("new.example:2222", "127.0.0.1:2222", "new")
+
+ if err := callback.trustHosts([]unknownHost{unknown}); err != nil {
+ t.Fatalf("trustHosts failed: %v", err)
+ }
+
+ got, err := os.ReadFile(knownHostsPath)
+ if err != nil {
+ t.Fatalf("ReadFile failed: %v", err)
+ }
+
+ want := strings.Join([]string{
+ unknown.hostLine,
+ unknown.ipLine,
+ existingLine,
+ "",
+ }, "\n")
+ if string(got) != want {
+ t.Fatalf("trustHosts wrote:\n%s\nwant:\n%s", got, want)
+ }
+
+ if response := <-unknown.responseCh; response != trustHost {
+ t.Fatalf("unexpected trust response: %v", response)
+ }
+}
+
+func TestTrustHostsReplacesExistingEntriesForSameHostAndIP(t *testing.T) {
+ knownHostsPath := filepath.Join(t.TempDir(), "known_hosts")
+ oldUnknown := testUnknownHost("replace.example:2222", "127.0.0.1:2222", "old")
+ keepLine := knownhosts.Line([]string{"keep.example:2222"}, &mockPublicKey{id: "keep"})
+ initialContents := strings.Join([]string{
+ oldUnknown.hostLine,
+ oldUnknown.ipLine,
+ keepLine,
+ "",
+ }, "\n")
+ if err := os.WriteFile(knownHostsPath, []byte(initialContents), 0o600); err != nil {
+ t.Fatalf("WriteFile failed: %v", err)
+ }
+
+ callback := testKnownHostsCallback(t, knownHostsPath)
+ newUnknown := testUnknownHost("replace.example:2222", "127.0.0.1:2222", "new")
+
+ if err := callback.trustHosts([]unknownHost{newUnknown}); err != nil {
+ t.Fatalf("trustHosts failed: %v", err)
+ }
+
+ got, err := os.ReadFile(knownHostsPath)
+ if err != nil {
+ t.Fatalf("ReadFile failed: %v", err)
+ }
+
+ want := strings.Join([]string{
+ newUnknown.hostLine,
+ newUnknown.ipLine,
+ keepLine,
+ "",
+ }, "\n")
+ if string(got) != want {
+ t.Fatalf("trustHosts wrote:\n%s\nwant:\n%s", got, want)
+ }
+
+ if response := <-newUnknown.responseCh; response != trustHost {
+ t.Fatalf("unexpected trust response: %v", response)
+ }
+}
+
+func TestTrustHostsRejectsEscapingKnownHostsSymlink(t *testing.T) {
+ rootDir := filepath.Join(t.TempDir(), "ssh")
+ if err := os.MkdirAll(rootDir, 0o755); err != nil {
+ t.Fatalf("MkdirAll failed: %v", err)
+ }
+
+ outsidePath := filepath.Join(filepath.Dir(rootDir), "outside_known_hosts")
+ if err := os.WriteFile(outsidePath, nil, 0o600); err != nil {
+ t.Fatalf("WriteFile failed: %v", err)
+ }
+
+ knownHostsPath := filepath.Join(rootDir, "known_hosts")
+ if err := os.Symlink(filepath.Join("..", "outside_known_hosts"), knownHostsPath); err != nil {
+ t.Fatalf("Symlink failed: %v", err)
+ }
+
+ callback := testKnownHostsCallback(t, knownHostsPath)
+ unknown := testUnknownHost("escape.example:2222", "127.0.0.1:2222", "new")
+
+ if err := callback.trustHosts([]unknownHost{unknown}); err == nil {
+ t.Fatalf("trustHosts succeeded for escaping known_hosts symlink")
+ }
+}
+
+func testKnownHostsCallback(t *testing.T, knownHostsPath string) *KnownHostsCallback {
+ t.Helper()
+
+ throttleCh := make(chan struct{}, 1)
+ throttleCh <- struct{}{}
+
+ callback, err := NewKnownHostsCallback(knownHostsPath, false, throttleCh)
+ if err != nil {
+ t.Fatalf("NewKnownHostsCallback failed: %v", err)
+ }
+
+ knownHostsCallback, ok := callback.(*KnownHostsCallback)
+ if !ok {
+ t.Fatalf("unexpected callback type %T", callback)
+ }
+
+ return knownHostsCallback
+}
+
+func testUnknownHost(server, remoteAddr, keyID string) unknownHost {
+ key := &mockPublicKey{id: keyID}
+ remote := testTCPAddr(remoteAddr)
+
+ return unknownHost{
+ server: server,
+ remote: remote,
+ key: key,
+ hostLine: knownhosts.Line([]string{server}, key),
+ ipLine: knownhosts.Line([]string{remote.String()}, key),
+ responseCh: make(chan response, 1),
+ }
+}
+
+func testTCPAddr(address string) *net.TCPAddr {
+ host, portStr, err := net.SplitHostPort(address)
+ if err != nil {
+ panic(err)
+ }
+
+ port, err := net.LookupPort("tcp", portStr)
+ if err != nil {
+ panic(err)
+ }
+
+ return &net.TCPAddr{IP: net.ParseIP(host), Port: port}
+}