diff options
Diffstat (limited to 'internal/ssh/client/knownhostscallback.go')
| -rw-r--r-- | internal/ssh/client/knownhostscallback.go | 40 |
1 files changed, 34 insertions, 6 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go index ac2ec92..174f6aa 100644 --- a/internal/ssh/client/knownhostscallback.go +++ b/internal/ssh/client/knownhostscallback.go @@ -11,6 +11,7 @@ import ( "time" "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/fs" "github.com/mimecast/dtail/internal/io/prompt" "golang.org/x/crypto/ssh" @@ -38,6 +39,7 @@ type unknownHost struct { // unknown hosts in a single batch to the known_hosts file. type KnownHostsCallback struct { knownHostsPath string + knownHostsFile fs.RootedPath unknownCh chan unknownHost throttleCh chan struct{} trustAllHostsCh chan struct{} @@ -51,11 +53,16 @@ var _ HostKeyCallback = (*KnownHostsCallback)(nil) func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) { - os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + knownHostsFile, err := fs.NewRootedPath(knownHostsPath) + if err != nil { + return nil, err + } + ensureKnownHostsFile(knownHostsFile) untrustedHosts := make(map[string]bool) c := KnownHostsCallback{ knownHostsPath: knownHostsPath, + knownHostsFile: knownHostsFile, unknownCh: make(chan unknownHost), trustAllHostsCh: make(chan struct{}), throttleCh: throttleCh, @@ -68,6 +75,20 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, return &c, nil } +func ensureKnownHostsFile(knownHostsFile fs.RootedPath) { + root, err := knownHostsFile.OpenRoot() + if err != nil { + return + } + defer root.Close() + + fd, err := root.OpenFile(knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o666) + if err != nil { + return + } + fd.Close() +} + // Wrap the host key callback. func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback { return func(server string, remote net.Addr, key ssh.PublicKey) error { @@ -229,18 +250,25 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) { } func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error { + root, err := c.knownHostsFile.OpenRoot() + if err != nil { + return err + } + defer root.Close() + + tmpKnownHostsName := fmt.Sprintf("%s.tmp", c.knownHostsFile.Name()) tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) cleanupTmp := func() { - if err := os.Remove(tmpKnownHostsPath); err != nil && !os.IsNotExist(err) { + if err := root.Remove(tmpKnownHostsName); err != nil && !os.IsNotExist(err) { dlog.Client.Debug("Unable to remove temporary known hosts file", tmpKnownHostsPath, err) } } - newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + newFd, err := root.OpenFile(tmpKnownHostsName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600) if err != nil { return fmt.Errorf("open temp known hosts file %s: %w", tmpKnownHostsPath, err) } - if err := newFd.Chmod(0600); err != nil { + if err := newFd.Chmod(0o600); err != nil { newFd.Close() cleanupTmp() return fmt.Errorf("chmod temp known hosts file %s: %w", tmpKnownHostsPath, err) @@ -268,7 +296,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error { } // Read old known hosts file, to see which are old and new entries - oldFd, err := os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0600) + oldFd, err := root.OpenFile(c.knownHostsFile.Name(), os.O_RDONLY|os.O_CREATE, 0o600) if err != nil { newFd.Close() cleanupTmp() @@ -308,7 +336,7 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error { } // Now, replace old known hosts file - if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil { + if err := root.Rename(tmpKnownHostsName, c.knownHostsFile.Name()); err != nil { cleanupTmp() return fmt.Errorf("replace known_hosts file %s: %w", c.knownHostsPath, err) } |
