summaryrefslogtreecommitdiff
path: root/internal/clients/connectors/serverconnection.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/connectors/serverconnection.go')
-rw-r--r--internal/clients/connectors/serverconnection.go32
1 files changed, 25 insertions, 7 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 649fe30..1136bf9 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -12,13 +12,23 @@ import (
"time"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/ssh/client"
"golang.org/x/crypto/ssh"
)
+// SSHSettings provides the connection settings needed by ServerConnection.
+type SSHSettings interface {
+ SSHPort() int
+ SSHConnectTimeout() time.Duration
+}
+
+const (
+ defaultSSHConnectTimeout = 2 * time.Second
+ defaultSSHPort = 2222
+)
+
// ServerConnection represents a connection to a single remote dtail server via
// SSH protocol.
type ServerConnection struct {
@@ -43,12 +53,20 @@ var _ Connector = (*ServerConnection)(nil)
func NewServerConnection(server string, userName string,
authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback,
handler handlers.Handler, commands []string, authKeyPath string,
- authKeyDisabled bool) *ServerConnection {
+ authKeyDisabled bool, settings SSHSettings) *ServerConnection {
dlog.Client.Debug(server, "Creating new connection", server, handler, commands)
- sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond
+ sshConnectTimeout := defaultSSHConnectTimeout
+ defaultPort := defaultSSHPort
+ if settings != nil {
+ sshConnectTimeout = settings.SSHConnectTimeout()
+ defaultPort = settings.SSHPort()
+ }
if sshConnectTimeout <= 0 {
- sshConnectTimeout = 2 * time.Second
+ sshConnectTimeout = defaultSSHConnectTimeout
+ }
+ if defaultPort <= 0 {
+ defaultPort = defaultSSHPort
}
c := ServerConnection{
@@ -66,7 +84,7 @@ func NewServerConnection(server string, userName string,
},
}
- c.initServerPort()
+ c.initServerPort(defaultPort)
return &c
}
@@ -77,11 +95,11 @@ func (c *ServerConnection) Server() string { return c.server }
func (c *ServerConnection) Handler() handlers.Handler { return c.handler }
// Attempt to parse the server port address from the provided server FQDN.
-func (c *ServerConnection) initServerPort() {
+func (c *ServerConnection) initServerPort(defaultPort int) {
parts := strings.Split(c.server, ":")
if len(parts) == 1 {
c.hostname = c.server
- c.port = config.Common.SSHPort
+ c.port = defaultPort
return
}