summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2020-02-28 17:29:22 +0000
committerPaul Buetow <pbuetow@mimecast.com>2020-02-28 17:29:22 +0000
commit7911b102171309dfc43bc2faccac6de9e490f175 (patch)
treeef489750dbc3e0c31402a88dcdadddd8533377ee /internal/server
parent3cdc86e20cbd311fb9c85cef63876a2f39e5e74d (diff)
parent1922e448e84e218cc39d4394e9b4becfa6f0a83d (diff)
merge master
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/mapcommand.go10
-rw-r--r--internal/server/scheduler.go22
-rw-r--r--internal/server/server.go36
3 files changed, 45 insertions, 23 deletions
diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go
index 10372da..c3e600e 100644
--- a/internal/server/handlers/mapcommand.go
+++ b/internal/server/handlers/mapcommand.go
@@ -15,18 +15,16 @@ type mapCommand struct {
// NewMapCommand returns a new server side mapreduce command.
func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) {
- mapCommand := mapCommand{
- server: serverHandler,
- }
+ m := mapCommand{server: serverHandler}
queryStr := strings.Join(args[1:], " ")
aggregate, err := server.NewAggregate(queryStr)
if err != nil {
- return mapCommand, nil, err
+ return m, nil, err
}
- mapCommand.aggregate = aggregate
- return mapCommand, aggregate, nil
+ m.aggregate = aggregate
+ return m, aggregate, nil
}
diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go
index 41b0f41..db49d1b 100644
--- a/internal/server/scheduler.go
+++ b/internal/server/scheduler.go
@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
- "math/rand"
"os"
"strconv"
"strings"
@@ -21,23 +20,17 @@ const authLength = 64
const authCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$%^&*()_+[]"
type scheduler struct {
- authPayload string
}
func newScheduler() *scheduler {
- seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
-
- b := make([]byte, authLength)
- for i := range b {
- b[i] = authCharset[seededRand.Intn(len(authCharset))]
- }
-
- return &scheduler{
- authPayload: string(b),
- }
+ return &scheduler{}
}
func (s *scheduler) start(ctx context.Context) {
+ // First run after just 10s!
+ time.Sleep(time.Second * 10)
+ s.runJobs(ctx)
+
for {
select {
case <-time.After(time.Minute):
@@ -75,7 +68,7 @@ func (s *scheduler) runJobs(ctx context.Context) {
continue
}
- servers := scheduled.Servers
+ servers := strings.Join(scheduled.Servers, ",")
if servers == "" {
servers = config.Server.SSHBindAddress
}
@@ -88,7 +81,8 @@ func (s *scheduler) runJobs(ctx context.Context) {
Mode: omode.MapClient,
UserName: config.ScheduleUser,
}
- args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(s.authPayload))
+
+ args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(scheduled.Name))
tmpOutfile := fmt.Sprintf("%s.tmp", outfile)
query := fmt.Sprintf("%s outfile %s", scheduled.Query, tmpOutfile)
diff --git a/internal/server/server.go b/internal/server/server.go
index 5ec46e7..eb0cdd7 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
+ "strings"
"time"
"github.com/mimecast/dtail/internal/config"
@@ -238,16 +239,45 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
func (s *Server) backgroundUserCallback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
user := user.New(c.User(), c.RemoteAddr().String())
+ authInfo := string(authPayload)
- if user.Name == config.ControlUser && string(authPayload) == config.ControlUser {
+ if user.Name == config.ControlUser && authInfo == config.ControlUser {
logger.Debug(user, "Granting permissions to control user")
return nil, nil
}
- if user.Name == config.ScheduleUser && string(authPayload) == s.sched.authPayload {
- logger.Debug(user, "Granting permissions to schedule user")
+ if user.Name == config.ScheduleUser && s.schedueleUserCanHaveSSHSession(c.RemoteAddr().String(), user, authInfo) {
+ logger.Debug(user, "Granting SSH connection to schedule user")
return nil, nil
}
return nil, fmt.Errorf("user %s not authorized", user)
}
+
+func (s *Server) schedueleUserCanHaveSSHSession(addr string, user *user.User, jobName string) bool {
+ logger.Debug("schedueleUserCanHaveSSHSession", user, jobName)
+ splitted := strings.Split(addr, ":")
+ ip := splitted[0]
+
+ for _, job := range config.Server.Schedule {
+ if job.Name != jobName {
+ continue
+ }
+ for _, myAddr := range job.AllowFrom {
+ myIps, err := net.LookupIP(myAddr)
+ if err != nil {
+ logger.Error(user, myAddr, err)
+ continue
+ }
+
+ for _, myIp := range myIps {
+ logger.Debug("schedueleUserCanHaveSSHSession", "Comparing IP addresses", ip, myIp.String())
+ if ip == myIp.String() {
+ return true
+ }
+ }
+ }
+ }
+
+ return false
+}