diff options
Diffstat (limited to 'internal/clients/connectors/serverconnection.go')
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 649fe30..1136bf9 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -12,13 +12,23 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/ssh/client" "golang.org/x/crypto/ssh" ) +// SSHSettings provides the connection settings needed by ServerConnection. +type SSHSettings interface { + SSHPort() int + SSHConnectTimeout() time.Duration +} + +const ( + defaultSSHConnectTimeout = 2 * time.Second + defaultSSHPort = 2222 +) + // ServerConnection represents a connection to a single remote dtail server via // SSH protocol. type ServerConnection struct { @@ -43,12 +53,20 @@ var _ Connector = (*ServerConnection)(nil) func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, handler handlers.Handler, commands []string, authKeyPath string, - authKeyDisabled bool) *ServerConnection { + authKeyDisabled bool, settings SSHSettings) *ServerConnection { dlog.Client.Debug(server, "Creating new connection", server, handler, commands) - sshConnectTimeout := time.Duration(config.Common.SSHConnectTimeoutMs) * time.Millisecond + sshConnectTimeout := defaultSSHConnectTimeout + defaultPort := defaultSSHPort + if settings != nil { + sshConnectTimeout = settings.SSHConnectTimeout() + defaultPort = settings.SSHPort() + } if sshConnectTimeout <= 0 { - sshConnectTimeout = 2 * time.Second + sshConnectTimeout = defaultSSHConnectTimeout + } + if defaultPort <= 0 { + defaultPort = defaultSSHPort } c := ServerConnection{ @@ -66,7 +84,7 @@ func NewServerConnection(server string, userName string, }, } - c.initServerPort() + c.initServerPort(defaultPort) return &c } @@ -77,11 +95,11 @@ func (c *ServerConnection) Server() string { return c.server } func (c *ServerConnection) Handler() handlers.Handler { return c.handler } // Attempt to parse the server port address from the provided server FQDN. -func (c *ServerConnection) initServerPort() { +func (c *ServerConnection) initServerPort(defaultPort int) { parts := strings.Split(c.server, ":") if len(parts) == 1 { c.hostname = c.server - c.port = config.Common.SSHPort + c.port = defaultPort return } |
