summaryrefslogtreecommitdiff
path: root/internal/ssh/client/knownhostscallback.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ssh/client/knownhostscallback.go')
-rw-r--r--internal/ssh/client/knownhostscallback.go40
1 files changed, 34 insertions, 6 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
index ac2ec92..174f6aa 100644
--- a/internal/ssh/client/knownhostscallback.go
+++ b/internal/ssh/client/knownhostscallback.go
@@ -11,6 +11,7 @@ import (
"time"
"github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/fs"
"github.com/mimecast/dtail/internal/io/prompt"
"golang.org/x/crypto/ssh"
@@ -38,6 +39,7 @@ type unknownHost struct {
// unknown hosts in a single batch to the known_hosts file.
type KnownHostsCallback struct {
knownHostsPath string
+ knownHostsFile fs.RootedPath
unknownCh chan unknownHost
throttleCh chan struct{}
trustAllHostsCh chan struct{}
@@ -51,11 +53,16 @@ var _ HostKeyCallback = (*KnownHostsCallback)(nil)
func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
throttleCh chan struct{}) (HostKeyCallback, error) {
- os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+ knownHostsFile, err := fs.NewRootedPath(knownHostsPath)
+ if err != nil {
+ return nil, err
+ }
+ ensureKnownHostsFile(knownHostsFile)
untrustedHosts := make(map[string]bool)
c := KnownHostsCallback{
knownHostsPath: knownHostsPath,
+ knownHostsFile: knownHostsFile,
unknownCh: make(chan unknownHost),
trustAllHostsCh: make(chan struct{}),
throttleCh: throttleCh,
@@ -68,6 +75,20 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
return &c, nil
}
+func ensureKnownHostsFile(knownHostsFile fs.RootedPath) {
+ root, err := knownHostsFile.OpenRoot()
+ if err != nil {
+ return
+ }
+ defer root.Close()
+
+ fd, err := root.OpenFile(knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o666)
+ if err != nil {
+ return
+ }
+ fd.Close()
+}
+
// Wrap the host key callback.
func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback {
return func(server string, remote net.Addr, key ssh.PublicKey) error {
@@ -229,18 +250,25 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
}
func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error {
+ root, err := c.knownHostsFile.OpenRoot()
+ if err != nil {
+ return err
+ }
+ defer root.Close()
+
+ tmpKnownHostsName := fmt.Sprintf("%s.tmp", c.knownHostsFile.Name())
tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
cleanupTmp := func() {
- if err := os.Remove(tmpKnownHostsPath); err != nil && !os.IsNotExist(err) {
+ if err := root.Remove(tmpKnownHostsName); err != nil && !os.IsNotExist(err) {
dlog.Client.Debug("Unable to remove temporary known hosts file", tmpKnownHostsPath, err)
}
}
- newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+ newFd, err := root.OpenFile(tmpKnownHostsName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600)
if err != nil {
return fmt.Errorf("open temp known hosts file %s: %w", tmpKnownHostsPath, err)
}
- if err := newFd.Chmod(0600); err != nil {
+ if err := newFd.Chmod(0o600); err != nil {
newFd.Close()
cleanupTmp()
return fmt.Errorf("chmod temp known hosts file %s: %w", tmpKnownHostsPath, err)
@@ -268,7 +296,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error {
}
// Read old known hosts file, to see which are old and new entries
- oldFd, err := os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0600)
+ oldFd, err := root.OpenFile(c.knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o600)
if err != nil {
newFd.Close()
cleanupTmp()
@@ -308,7 +336,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error {
}
// Now, replace old known hosts file
- if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil {
+ if err := root.Rename(tmpKnownHostsName, c.knownHostsFile.Name()); err != nil {
cleanupTmp()
return fmt.Errorf("replace known_hosts file %s: %w", c.knownHostsPath, err)
}