diff options
| -rw-r--r-- | internal/ssh/client/knownhostscallback.go | 40 | ||||
| -rw-r--r-- | internal/ssh/client/knownhostscallback_test.go | 157 |
2 files changed, 191 insertions, 6 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go index ac2ec92..174f6aa 100644 --- a/internal/ssh/client/knownhostscallback.go +++ b/internal/ssh/client/knownhostscallback.go @@ -11,6 +11,7 @@ import ( "time" "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/fs" "github.com/mimecast/dtail/internal/io/prompt" "golang.org/x/crypto/ssh" @@ -38,6 +39,7 @@ type unknownHost struct { // unknown hosts in a single batch to the known_hosts file. type KnownHostsCallback struct { knownHostsPath string + knownHostsFile fs.RootedPath unknownCh chan unknownHost throttleCh chan struct{} trustAllHostsCh chan struct{} @@ -51,11 +53,16 @@ var _ HostKeyCallback = (*KnownHostsCallback)(nil) func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) { - os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + knownHostsFile, err := fs.NewRootedPath(knownHostsPath) + if err != nil { + return nil, err + } + ensureKnownHostsFile(knownHostsFile) untrustedHosts := make(map[string]bool) c := KnownHostsCallback{ knownHostsPath: knownHostsPath, + knownHostsFile: knownHostsFile, unknownCh: make(chan unknownHost), trustAllHostsCh: make(chan struct{}), throttleCh: throttleCh, @@ -68,6 +75,20 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, return &c, nil } +func ensureKnownHostsFile(knownHostsFile fs.RootedPath) { + root, err := knownHostsFile.OpenRoot() + if err != nil { + return + } + defer root.Close() + + fd, err := root.OpenFile(knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o666) + if err != nil { + return + } + fd.Close() +} + // Wrap the host key callback. func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback { return func(server string, remote net.Addr, key ssh.PublicKey) error { @@ -229,18 +250,25 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) { } func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error { + root, err := c.knownHostsFile.OpenRoot() + if err != nil { + return err + } + defer root.Close() + + tmpKnownHostsName := fmt.Sprintf("%s.tmp", c.knownHostsFile.Name()) tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) cleanupTmp := func() { - if err := os.Remove(tmpKnownHostsPath); err != nil && !os.IsNotExist(err) { + if err := root.Remove(tmpKnownHostsName); err != nil && !os.IsNotExist(err) { dlog.Client.Debug("Unable to remove temporary known hosts file", tmpKnownHostsPath, err) } } - newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + newFd, err := root.OpenFile(tmpKnownHostsName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600) if err != nil { return fmt.Errorf("open temp known hosts file %s: %w", tmpKnownHostsPath, err) } - if err := newFd.Chmod(0600); err != nil { + if err := newFd.Chmod(0o600); err != nil { newFd.Close() cleanupTmp() return fmt.Errorf("chmod temp known hosts file %s: %w", tmpKnownHostsPath, err) @@ -268,7 +296,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error { } // Read old known hosts file, to see which are old and new entries - oldFd, err := os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0600) + oldFd, err := root.OpenFile(c.knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o600) if err != nil { newFd.Close() cleanupTmp() @@ -308,7 +336,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error { } // Now, replace old known hosts file - if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil { + if err := root.Rename(tmpKnownHostsName, c.knownHostsFile.Name()); err != nil { cleanupTmp() return fmt.Errorf("replace known_hosts file %s: %w", c.knownHostsPath, err) } diff --git a/internal/ssh/client/knownhostscallback_test.go b/internal/ssh/client/knownhostscallback_test.go new file mode 100644 index 0000000..596aea8 --- /dev/null +++ b/internal/ssh/client/knownhostscallback_test.go @@ -0,0 +1,157 @@ +package client + +import ( + "net" + "os" + "path/filepath" + "strings" + "testing" + + "golang.org/x/crypto/ssh/knownhosts" +) + +func TestTrustHostsAppendsDistinctExistingEntries(t *testing.T) { + knownHostsPath := filepath.Join(t.TempDir(), "known_hosts") + existingLine := knownhosts.Line([]string{"old.example:2222"}, &mockPublicKey{id: "old"}) + if err := os.WriteFile(knownHostsPath, []byte(existingLine+"\n"), 0o600); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + callback := testKnownHostsCallback(t, knownHostsPath) + unknown := testUnknownHost("new.example:2222", "127.0.0.1:2222", "new") + + if err := callback.trustHosts([]unknownHost{unknown}); err != nil { + t.Fatalf("trustHosts failed: %v", err) + } + + got, err := os.ReadFile(knownHostsPath) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + + want := strings.Join([]string{ + unknown.hostLine, + unknown.ipLine, + existingLine, + "", + }, "\n") + if string(got) != want { + t.Fatalf("trustHosts wrote:\n%s\nwant:\n%s", got, want) + } + + if response := <-unknown.responseCh; response != trustHost { + t.Fatalf("unexpected trust response: %v", response) + } +} + +func TestTrustHostsReplacesExistingEntriesForSameHostAndIP(t *testing.T) { + knownHostsPath := filepath.Join(t.TempDir(), "known_hosts") + oldUnknown := testUnknownHost("replace.example:2222", "127.0.0.1:2222", "old") + keepLine := knownhosts.Line([]string{"keep.example:2222"}, &mockPublicKey{id: "keep"}) + initialContents := strings.Join([]string{ + oldUnknown.hostLine, + oldUnknown.ipLine, + keepLine, + "", + }, "\n") + if err := os.WriteFile(knownHostsPath, []byte(initialContents), 0o600); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + callback := testKnownHostsCallback(t, knownHostsPath) + newUnknown := testUnknownHost("replace.example:2222", "127.0.0.1:2222", "new") + + if err := callback.trustHosts([]unknownHost{newUnknown}); err != nil { + t.Fatalf("trustHosts failed: %v", err) + } + + got, err := os.ReadFile(knownHostsPath) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + + want := strings.Join([]string{ + newUnknown.hostLine, + newUnknown.ipLine, + keepLine, + "", + }, "\n") + if string(got) != want { + t.Fatalf("trustHosts wrote:\n%s\nwant:\n%s", got, want) + } + + if response := <-newUnknown.responseCh; response != trustHost { + t.Fatalf("unexpected trust response: %v", response) + } +} + +func TestTrustHostsRejectsEscapingKnownHostsSymlink(t *testing.T) { + rootDir := filepath.Join(t.TempDir(), "ssh") + if err := os.MkdirAll(rootDir, 0o755); err != nil { + t.Fatalf("MkdirAll failed: %v", err) + } + + outsidePath := filepath.Join(filepath.Dir(rootDir), "outside_known_hosts") + if err := os.WriteFile(outsidePath, nil, 0o600); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + knownHostsPath := filepath.Join(rootDir, "known_hosts") + if err := os.Symlink(filepath.Join("..", "outside_known_hosts"), knownHostsPath); err != nil { + t.Fatalf("Symlink failed: %v", err) + } + + callback := testKnownHostsCallback(t, knownHostsPath) + unknown := testUnknownHost("escape.example:2222", "127.0.0.1:2222", "new") + + if err := callback.trustHosts([]unknownHost{unknown}); err == nil { + t.Fatalf("trustHosts succeeded for escaping known_hosts symlink") + } +} + +func testKnownHostsCallback(t *testing.T, knownHostsPath string) *KnownHostsCallback { + t.Helper() + + throttleCh := make(chan struct{}, 1) + throttleCh <- struct{}{} + + callback, err := NewKnownHostsCallback(knownHostsPath, false, throttleCh) + if err != nil { + t.Fatalf("NewKnownHostsCallback failed: %v", err) + } + + knownHostsCallback, ok := callback.(*KnownHostsCallback) + if !ok { + t.Fatalf("unexpected callback type %T", callback) + } + + return knownHostsCallback +} + +func testUnknownHost(server, remoteAddr, keyID string) unknownHost { + key := &mockPublicKey{id: keyID} + remote := testTCPAddr(remoteAddr) + + return unknownHost{ + server: server, + remote: remote, + key: key, + hostLine: knownhosts.Line([]string{server}, key), + ipLine: knownhosts.Line([]string{remote.String()}, key), + responseCh: make(chan response, 1), + } +} + +func testTCPAddr(address string) *net.TCPAddr { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + panic(err) + } + + port, err := net.LookupPort("tcp", portStr) + if err != nil { + panic(err) + } + + return &net.TCPAddr{IP: net.ParseIP(host), Port: port} +} |
