summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile4
-rw-r--r--cmd/dtail/main.go7
-rw-r--r--go.mod7
-rw-r--r--go.sum6
-rw-r--r--internal/clients/handlers/maprhandler.go7
-rw-r--r--internal/clients/maprclient.go10
-rw-r--r--internal/config/config.go9
-rw-r--r--internal/config/server.go6
-rw-r--r--internal/mapr/aggregateset.go14
-rw-r--r--internal/mapr/client/aggregate.go10
-rw-r--r--internal/mapr/fieldtypes.go29
-rw-r--r--internal/mapr/funcs/function.go66
-rw-r--r--internal/mapr/funcs/function_test.go45
-rw-r--r--internal/mapr/funcs/maskdigits.go14
-rw-r--r--internal/mapr/funcs/md5sum.go12
-rw-r--r--internal/mapr/logformat/default.go1
-rw-r--r--internal/mapr/logformat/default_test.go2
-rw-r--r--internal/mapr/logformat/parser.go3
-rw-r--r--internal/mapr/query.go103
-rw-r--r--internal/mapr/query_test.go40
-rw-r--r--internal/mapr/server/aggregate.go68
-rw-r--r--internal/mapr/setclause.go20
-rw-r--r--internal/mapr/setcondition.go93
-rw-r--r--internal/mapr/token.go2
-rw-r--r--internal/mapr/whereclause.go77
-rw-r--r--internal/mapr/wherecondition.go27
-rw-r--r--internal/server/continuous.go121
-rw-r--r--internal/server/handlers/serverhandler.go4
-rw-r--r--internal/server/scheduler.go17
-rw-r--r--internal/server/server.go67
-rw-r--r--internal/user/server/user.go2
-rw-r--r--internal/version/version.go6
32 files changed, 642 insertions, 257 deletions
diff --git a/Makefile b/Makefile
index c358d8e..75b9333 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
GO ?= go
-all: build
+all: test build
build:
${GO} build -o dserver ./cmd/dserver/main.go
${GO} build -o dcat ./cmd/dcat/main.go
@@ -29,3 +29,5 @@ lint:
echo ${GOPATH}/bin/golint $$dir; \
${GOPATH}/bin/golint $$dir; \
done
+test:
+ ${GO} test ./... -v
diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go
index b8bcd06..3cec667 100644
--- a/cmd/dtail/main.go
+++ b/cmd/dtail/main.go
@@ -71,9 +71,12 @@ func main() {
version.PrintAndExit()
}
- ctx, _ := context.WithCancel(context.Background())
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
if shutdownAfter > 0 {
- ctx, _ = context.WithTimeout(ctx, time.Duration(shutdownAfter)*time.Second)
+ ctx, cancel = context.WithTimeout(ctx, time.Duration(shutdownAfter)*time.Second)
+ defer cancel()
}
if checkHealth {
diff --git a/go.mod b/go.mod
index d5c65b7..e95da7d 100644
--- a/go.mod
+++ b/go.mod
@@ -1,9 +1,10 @@
module github.com/mimecast/dtail
-go 1.13
+go 1.15
require (
- github.com/DataDog/zstd v1.4.4
- golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876
+ github.com/DataDog/zstd v1.4.5
+ golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de
golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect
+ golang.org/x/sys v0.0.0-20200812155832-6a926be9bd1d // indirect
)
diff --git a/go.sum b/go.sum
index b40dedd..00c6f54 100644
--- a/go.sum
+++ b/go.sum
@@ -1,9 +1,13 @@
github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE=
github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
+github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ=
+github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc=
golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de h1:ikNHVSjEfnvz6sxdSPCaPt572qowuyMDMJLLm3Db3ig=
+golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/lint v0.0.0-20200130185559-910be7a94367 h1:0IiAsCRByjO2QjX7ZPkw5oU9x+n1YqRL802rjC0c3Aw=
golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k=
@@ -15,6 +19,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200812155832-6a926be9bd1d h1:QQrM/CCYEzTs91GZylDCQjGHudbPTxF/1fvXdVh5lMo=
+golang.org/x/sys v0.0.0-20200812155832-6a926be9bd1d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7 h1:EBZoQjiKKPaLbPrbpssUfuHtwM6KV/vb4U85g/cigFY=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
index 44daf7d..b908f3b 100644
--- a/internal/clients/handlers/maprhandler.go
+++ b/internal/clients/handlers/maprhandler.go
@@ -59,16 +59,11 @@ func (h *MaprHandler) Write(p []byte) (n int, err error) {
// related data.
func (h *MaprHandler) handleAggregateMessage(message string) {
h.count++
- parts := strings.Split(message, "|")
+ parts := strings.Split(message, "➔")
// Index 0 contains 'AGGREGATE', 1 contains server host.
// Aggregation data begins from index 2.
logger.Debug("Received aggregate data", h.server, h.count, parts)
- /*
- for k, v := range parts {
- logger.Debug(k, v)
- }
- */
h.aggregate.Aggregate(parts[2:])
logger.Debug("Aggregated aggregate data", h.server, h.count)
}
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index e28dadb..c6c341b 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -14,11 +14,15 @@ import (
"github.com/mimecast/dtail/internal/omode"
)
+// MaprClientMode determines whether to use cumulative mode or not.
type MaprClientMode int
const (
- DefaultMode MaprClientMode = iota
- CumulativeMode MaprClientMode = iota
+ // DefaultMode behaviour
+ DefaultMode MaprClientMode = iota
+ // CumulativeMode means results are added to prev interval
+ CumulativeMode MaprClientMode = iota
+ // NonCumulativeMode means results are from 0 for each interval
NonCumulativeMode MaprClientMode = iota
)
@@ -60,6 +64,8 @@ func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (*
cumulative = args.Mode == omode.MapClient || query.HasOutfile()
}
+ logger.Debug("Cumulative mapreduce mode?", cumulative)
+
c := MaprClient{
baseClient: baseClient{
Args: args,
diff --git a/internal/config/config.go b/internal/config/config.go
index 39149bc..dc96d6b 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -7,10 +7,13 @@ import (
)
// ControlUser is used for various DTail specific operations.
-const ControlUser string = "DTAIL-CONTROL-USER"
+const ControlUser string = "DTAIL-CONTROL"
-// BackgroundUser is used for non-interactive scheduled queries and log monitoring and such.
-const BackgroundUser string = "DTAIL-BACKGROUND-USER"
+// ScheduleUser is used for non-interactive scheduled mapreduce queries.
+const ScheduleUser string = "DTAIL-SCHEDULE"
+
+// ContinuousUser is used for non-interactive continuous mapreduce queries.
+const ContinuousUser string = "DTAIL-CONTINUOUS"
// Client holds a DTail client configuration.
var Client *ClientConfig
diff --git a/internal/config/server.go b/internal/config/server.go
index 4166fd3..83ff45f 100644
--- a/internal/config/server.go
+++ b/internal/config/server.go
@@ -14,7 +14,7 @@ type Permissions struct {
}
// JobCommons summarises common job fields
-type JobCommons struct {
+type jobCommons struct {
Name string
Enable bool
Files string
@@ -27,13 +27,13 @@ type JobCommons struct {
// Scheduled allows to configure scheduled mapreduce jobs.
type Scheduled struct {
- JobCommons
+ jobCommons
TimeRange [2]int
}
// Continuous allows to configure continuous running mapreduce jobs.
type Continuous struct {
- JobCommons
+ jobCommons
RestartOnDayChange bool `json:",omitempty"`
}
diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go
index d8705bd..a6cc6eb 100644
--- a/internal/mapr/aggregateset.go
+++ b/internal/mapr/aggregateset.go
@@ -2,7 +2,6 @@ package mapr
import (
"context"
- "encoding/base64"
"fmt"
"strconv"
"strings"
@@ -71,25 +70,20 @@ func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<-
var sb strings.Builder
sb.WriteString(groupKey)
- sb.WriteString("|")
- sb.WriteString(fmt.Sprintf("%d|", s.Samples))
+ sb.WriteString("➔")
+ sb.WriteString(fmt.Sprintf("%d➔", s.Samples))
for k, v := range s.FValues {
sb.WriteString(k)
sb.WriteString("=")
- sb.WriteString(fmt.Sprintf("%v|", v))
+ sb.WriteString(fmt.Sprintf("%v➔", v))
}
for k, v := range s.SValues {
sb.WriteString(k)
sb.WriteString("=")
- if k == "$line" {
- sb.WriteString(base64.StdEncoding.EncodeToString([]byte(v)))
- sb.WriteString("|")
- continue
- }
sb.WriteString(v)
- sb.WriteString("|")
+ sb.WriteString("➔")
}
select {
diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go
index e7fcdc6..10b34d4 100644
--- a/internal/mapr/client/aggregate.go
+++ b/internal/mapr/client/aggregate.go
@@ -1,7 +1,6 @@
package client
import (
- "encoding/base64"
"strconv"
"strings"
@@ -75,15 +74,6 @@ func (a *Aggregate) makeFields(parts []string) map[string]string {
if len(kv) < 2 {
continue
}
- if kv[0] == "$line" {
- decoded, err := base64.StdEncoding.DecodeString(kv[1])
- if err != nil {
- logger.Error("Unable to decode $line", kv[1], err)
- continue
- }
- fields[kv[0]] = string(decoded)
- continue
- }
fields[kv[0]] = kv[1]
}
diff --git a/internal/mapr/fieldtypes.go b/internal/mapr/fieldtypes.go
new file mode 100644
index 0000000..a64efd1
--- /dev/null
+++ b/internal/mapr/fieldtypes.go
@@ -0,0 +1,29 @@
+package mapr
+
+import "fmt"
+
+type fieldType int
+
+// The possible field types.
+const (
+ UndefFieldType fieldType = iota
+ Field fieldType = iota
+ String fieldType = iota
+ Float fieldType = iota
+ FunctionStack fieldType = iota
+)
+
+func (w fieldType) String() string {
+ switch w {
+ case Field:
+ return fmt.Sprintf("Field")
+ case String:
+ return fmt.Sprintf("String")
+ case Float:
+ return fmt.Sprintf("Float")
+ case FunctionStack:
+ return fmt.Sprintf("FunctionStack")
+ default:
+ return fmt.Sprintf("UndefFieldType")
+ }
+}
diff --git a/internal/mapr/funcs/function.go b/internal/mapr/funcs/function.go
new file mode 100644
index 0000000..52aaa98
--- /dev/null
+++ b/internal/mapr/funcs/function.go
@@ -0,0 +1,66 @@
+package funcs
+
+import (
+ "fmt"
+ "strings"
+)
+
+// CallbackFunc is a function which can be executed by the mapreduce engine
+type CallbackFunc func(text string) string
+
+// Function embeds the function name into the callback function
+type Function struct {
+ // Name of the callback function
+ Name string
+ call CallbackFunc
+}
+
+// FunctionStack is a list of functions stacked on each other
+type FunctionStack []Function
+
+// NewFunctionStack parses the input string, e.g. foo(bar("arg")) and returns a corresponding function stack.
+func NewFunctionStack(in string) (FunctionStack, string, error) {
+ var fs FunctionStack
+
+ getCallback := func(name string) (CallbackFunc, error) {
+ var cb CallbackFunc
+
+ switch name {
+ case "md5sum":
+ return Md5Sum, nil
+ case "maskdigits":
+ return MaskDigits, nil
+ default:
+ return cb, fmt.Errorf("unknown function '%s'", name)
+ }
+ }
+
+ aux := in
+ for strings.HasSuffix(aux, ")") {
+ index := strings.Index(aux, "(")
+ if index <= 0 {
+ return fs, "", fmt.Errorf("unable to parse function '%s' at '%s'", in, aux)
+ }
+ name := aux[0:index]
+
+ call, err := getCallback(name)
+ if err != nil {
+ return fs, "", err
+ }
+ fs = append(fs, Function{name, call})
+ aux = aux[index+1 : len(aux)-1]
+ }
+
+ return fs, aux, nil
+}
+
+// Call the function stack.
+func (fs FunctionStack) Call(str string) string {
+ for i := len(fs) - 1; i >= 0; i-- {
+ //logger.Debug("Call", fs[i].Name, str)
+ str = fs[i].call(str)
+ //logger.Debug("Call.result", fs[i].Name, str)
+ }
+
+ return str
+}
diff --git a/internal/mapr/funcs/function_test.go b/internal/mapr/funcs/function_test.go
new file mode 100644
index 0000000..415683c
--- /dev/null
+++ b/internal/mapr/funcs/function_test.go
@@ -0,0 +1,45 @@
+package funcs
+
+import "testing"
+
+func TestFunction(t *testing.T) {
+ input := "md5sum($line)"
+ fs, arg, err := NewFunctionStack(input)
+ if err != nil {
+ t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs)
+ }
+ if arg != "$line" {
+ t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs)
+ }
+ t.Log(input, fs, arg)
+
+ result := fs.Call(input)
+ if result != "b38699013d79e50d9d122433753959c1" {
+ t.Errorf("error executing function stack '%s': expected result 'b38699013d79e50d9d122433753959c1' but got '%s' (%v)\n", input, result, fs)
+ }
+
+ input = "maskdigits(md5sum(maskdigits($line)))"
+ fs, arg, err = NewFunctionStack(input)
+ if err != nil {
+ t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs)
+ }
+ if arg != "$line" {
+ t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs)
+ }
+ t.Log(input, fs, arg)
+
+ result = fs.Call(input)
+ if result != ".fac.bbe..bb.........d...a.c..b." {
+ t.Errorf("error executing function stack '%s': expected result '.fac.bbe..bb.........d...a.c..b.' but got '%s' (%v)\n", input, result, fs)
+ }
+
+ input = "md5sum$line)"
+ if fs, _, err := NewFunctionStack(input); err == nil {
+ t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs)
+ }
+
+ input = "md5sum(makedigits$line))"
+ if fs, _, err := NewFunctionStack(input); err == nil {
+ t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs)
+ }
+}
diff --git a/internal/mapr/funcs/maskdigits.go b/internal/mapr/funcs/maskdigits.go
new file mode 100644
index 0000000..d51f3d8
--- /dev/null
+++ b/internal/mapr/funcs/maskdigits.go
@@ -0,0 +1,14 @@
+package funcs
+
+// MaskDigits masks all digits (replaces them with .)
+func MaskDigits(input string) string {
+ s := []byte(input)
+
+ for i, b := range s {
+ if '0' <= b && b <= '9' {
+ s[i] = '.'
+ }
+ }
+
+ return string(s)
+}
diff --git a/internal/mapr/funcs/md5sum.go b/internal/mapr/funcs/md5sum.go
new file mode 100644
index 0000000..e3cc7e6
--- /dev/null
+++ b/internal/mapr/funcs/md5sum.go
@@ -0,0 +1,12 @@
+package funcs
+
+import (
+ "crypto/md5"
+ "encoding/hex"
+)
+
+// Md5Sum returns the hex encoded MD5 checksum of a given input string.
+func Md5Sum(text string) string {
+ hash := md5.Sum([]byte(text))
+ return hex.EncodeToString(hash[:])
+}
diff --git a/internal/mapr/logformat/default.go b/internal/mapr/logformat/default.go
index 0dfdde0..44bf558 100644
--- a/internal/mapr/logformat/default.go
+++ b/internal/mapr/logformat/default.go
@@ -24,5 +24,6 @@ func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) {
}
fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1]
}
+
return fields, nil
}
diff --git a/internal/mapr/logformat/default_test.go b/internal/mapr/logformat/default_test.go
index a3c47fb..d7a4da4 100644
--- a/internal/mapr/logformat/default_test.go
+++ b/internal/mapr/logformat/default_test.go
@@ -5,7 +5,7 @@ import (
)
func TestDefaultLogFormat(t *testing.T) {
- parser, err := NewParser("default")
+ parser, err := NewParser("default", nil)
if err != nil {
t.Errorf("Unable to create parser: %s", err.Error())
}
diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go
index cc9c268..c53729a 100644
--- a/internal/mapr/logformat/parser.go
+++ b/internal/mapr/logformat/parser.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/mapr"
)
// Parser is used to parse the mapreduce information from the server log files.
@@ -22,7 +23,7 @@ type Parser struct {
}
// NewParser returns a new log parser.
-func NewParser(logFormatName string) (*Parser, error) {
+func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) {
hostname, err := os.Hostname()
if err != nil {
diff --git a/internal/mapr/query.go b/internal/mapr/query.go
index 6dff792..7f6b63c 100644
--- a/internal/mapr/query.go
+++ b/internal/mapr/query.go
@@ -20,6 +20,7 @@ type Query struct {
Select []selectCondition
Table string
Where []whereCondition
+ Set []setCondition
GroupBy []string
OrderBy string
ReverseOrder bool
@@ -29,13 +30,15 @@ type Query struct {
Outfile string
RawQuery string
tokens []token
+ LogFormat string
}
func (q Query) String() string {
- return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,GroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v)",
+ return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%v,GroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v,LogFormat:%s)",
q.Select,
q.Table,
q.Where,
+ q.Set,
q.GroupBy,
q.GroupKey,
q.OrderBy,
@@ -44,7 +47,8 @@ func (q Query) String() string {
q.Limit,
q.Outfile,
q.RawQuery,
- q.tokens)
+ q.tokens,
+ q.LogFormat)
}
// NewQuery returns a new mapreduce query.
@@ -68,10 +72,16 @@ func NewQuery(queryStr string) (*Query, error) {
return &q, err
}
+// HasOutfile returns true if query result will be written to a CSV output file.
func (q *Query) HasOutfile() bool {
return q.Outfile != ""
}
+// Has is a helper to determine whether a query contains a substring
+func (q *Query) Has(what string) bool {
+ return strings.Contains(q.RawQuery, what)
+}
+
func (q *Query) parse(tokens []token) error {
var found []token
var err error
@@ -86,14 +96,23 @@ func (q *Query) parse(tokens []token) error {
}
case "from":
tokens, found = tokensConsume(tokens[1:])
- if len(found) > 0 {
- q.Table = strings.ToUpper(found[0].str)
+ if len(found) == 0 {
+ return errors.New(invalidQuery + "expected table name after 'from'")
+ }
+ if len(found) > 1 {
+ return errors.New(invalidQuery + "expected only one table name after 'from'")
}
+ q.Table = strings.ToUpper(found[0].str)
case "where":
tokens, found = tokensConsume(tokens[1:])
if q.Where, err = makeWhereConditions(found); err != nil {
return err
}
+ case "set":
+ tokens, found = tokensConsume(tokens[1:])
+ if q.Set, err = makeSetConditions(found); err != nil {
+ return err
+ }
case "group":
tokens = tokensConsumeOptional(tokens[1:], "by")
if tokens == nil || len(tokens) < 1 {
@@ -147,6 +166,12 @@ func (q *Query) parse(tokens []token) error {
return errors.New(invalidQuery + unexpectedEnd)
}
q.Outfile = found[0].str
+ case "logformat":
+ tokens, found = tokensConsume(tokens[1:])
+ if len(found) == 0 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ q.LogFormat = found[0].str
default:
return errors.New(invalidQuery + "Unexpected keyword " + tokens[0].str)
}
@@ -181,73 +206,3 @@ func (q *Query) parse(tokens []token) error {
return nil
}
-
-// WhereClause interprets the where clause of the mapreduce query.
-func (q *Query) WhereClause(fields map[string]string) bool {
- floatValue := func(str string, float float64, t whereType) (float64, bool) {
- switch t {
- case Float:
- return float, true
- case Field:
- value, ok := fields[str]
- if !ok {
- return 0, false
- }
- f, err := strconv.ParseFloat(value, 64)
- if err != nil {
- return 0, false
- }
- return f, true
- default:
- logger.Error("Unexpected argument in 'where' clause", str, float, t)
- return 0, false
- }
- }
-
- stringValue := func(str string, t whereType) (string, bool) {
- switch t {
- case Field:
- value, ok := fields[str]
- if !ok {
- return str, false
- }
- return value, true
- case String:
- return str, true
- default:
- logger.Error("Unexpected argument in 'where' clause", str, t)
- return str, false
- }
- }
-
- for _, wc := range q.Where {
- var ok bool
-
- if wc.Operation > FloatOperation {
- var lValue, rValue float64
- if lValue, ok = floatValue(wc.lString, wc.lFloat, wc.lType); !ok {
- return false
- }
- if rValue, ok = floatValue(wc.rString, wc.rFloat, wc.rType); !ok {
- return false
- }
- if ok = wc.floatClause(lValue, rValue); !ok {
- return false
- }
- continue
- }
-
- var lValue, rValue string
- if lValue, ok = stringValue(wc.lString, wc.lType); !ok {
- return false
- }
- if rValue, ok = stringValue(wc.rString, wc.rType); !ok {
- return false
- }
- if ok = wc.stringClause(lValue, rValue); !ok {
- return false
- }
- }
-
- return true
-}
diff --git a/internal/mapr/query_test.go b/internal/mapr/query_test.go
index 6176461..b0b6c3a 100644
--- a/internal/mapr/query_test.go
+++ b/internal/mapr/query_test.go
@@ -8,13 +8,13 @@ import (
func TestParseQuerySimple(t *testing.T) {
errorQueries := []string{
"select",
- "select foo",
"select foo from",
"select foo from bar where baz",
"select foo from bar where baz <",
"select foo from bar where baz < 100 bay eq 12 group",
"select foo from bar where baz < 100 bay eq 12 group by foo order by",
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit set foo = bar;",
}
okQueries := []string{"select foo from bar",
"select foo from bar where",
@@ -24,6 +24,7 @@ func TestParseQuerySimple(t *testing.T) {
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo",
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23",
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\"",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\" set $foo = maskdigits(bar), $baz = 12, $bay = $foo;",
}
for _, queryStr := range errorQueries {
@@ -45,13 +46,8 @@ func TestParseQuerySimple(t *testing.T) {
func TestParseQueryDeep(t *testing.T) {
dialects := []string{
- "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23",
- "SELECT s1, `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23",
- "select s1, `from` count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23",
- "sElEct s1, `from` coUnt(s3) from taBle where w1 == 2 aNd w2 eq \"free beer\" Group By g1, g2 order bY count(s3) intervaL 10 LiMiT 23",
- "SELECT s1 `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP BY g1 g2 ORDER BY count(s3) INTERVAL 10 LIMIT 23",
- "select s1 `from` count(s3) from table where w1 == 2 w2 eq \"free beer\" group g1 g2 order count(s3) interval 10 limit 23",
- "limit 23 interval 10 order count(s3) group g1 g2 where w1 == 2 w2 eq \"free beer\" from table select s1 `from` count(s3)",
+ "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23 set $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
+ "SELECT s1, `from`, COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23 SET $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
}
for _, queryStr := range dialects {
@@ -60,6 +56,8 @@ func TestParseQueryDeep(t *testing.T) {
t.Errorf("%s: %s", err.Error(), queryStr)
}
+ t.Log(q)
+
// 'select' clause
if len(q.Select) != 3 {
t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", q.Select, queryStr, q)
@@ -145,5 +143,31 @@ func TestParseQueryDeep(t *testing.T) {
if q.Limit != 23 {
t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", q.Limit, queryStr, q)
}
+
+ // 'set' clause
+ if q.Set[0].lString != "$foo" {
+ t.Errorf("Expected '$foo' lvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].lString, queryStr, q)
+ }
+ if q.Set[0].rString != "bar" {
+ t.Errorf("Expected 'bar' rvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].rString, queryStr, q)
+ }
+
+ if q.Set[1].lString != "$baz" {
+ t.Errorf("Expected '$baz' lvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].lString, queryStr, q)
+ }
+ if q.Set[1].rString != "12" {
+ t.Errorf("Expected '12' rvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].rString, queryStr, q)
+ }
+
+ if q.Set[2].lString != "$bay" {
+ t.Errorf("Expected '$bay' lvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].lString, queryStr, q)
+ }
+ if q.Set[2].rString != "$foo" {
+ t.Errorf("Expected '$foo' rvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].rString, queryStr, q)
+ }
+
+ if q.LogFormat != "generic" {
+ t.Errorf("Expected 'generic' logformat got '%v': %s\n%v", q.LogFormat, queryStr, q)
+ }
}
}
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 80a464d..1028943 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -44,15 +44,24 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
}
s := strings.Split(fqdn, ".")
- parserName := config.Server.MapreduceLogFormat
- if query.Table == "" {
- parserName = "generic"
+ var parserName string
+ switch query.LogFormat {
+ case "":
+ parserName = config.Server.MapreduceLogFormat
+ if query.Table == "" {
+ parserName = "generic"
+ }
+ default:
+ parserName = query.LogFormat
}
- logger.Info("Creating mapr log format parser", parserName)
- logParser, err := logformat.NewParser(parserName)
+ logger.Info("Creating log format parser", parserName)
+ logParser, err := logformat.NewParser(parserName, query)
if err != nil {
- logger.FatalExit("Could not create mapr log format parser", err)
+ logger.Error("Could not create log format parser. Falling back to 'generic'", err)
+ if logParser, err = logformat.NewParser("generic", query); err != nil {
+ logger.FatalExit("Could not create log format parser", err)
+ }
}
ctx, cancel := context.WithCancel(context.Background())
@@ -76,6 +85,12 @@ func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
defer a.cancel()
fieldsCh := a.linesToFields(ctx)
+
+ // Add fields (e.g. via 'set' clause)
+ if len(a.query.Set) > 0 {
+ fieldsCh = a.addMoreFields(ctx, fieldsCh)
+ }
+
go a.fieldsToMaprLines(ctx, fieldsCh, maprLines)
a.periodicAggregateTimer(ctx)
}
@@ -99,10 +114,10 @@ func (a *Aggregate) periodicAggregateTimer(ctx context.Context) {
}
func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string {
- fieldsCh := make(chan map[string]string)
+ ch := make(chan map[string]string)
go func() {
- defer close(fieldsCh)
+ defer close(ch)
for {
select {
@@ -113,6 +128,7 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
maprLine := strings.TrimSpace(string(line.Content))
fields, err := a.parser.MakeFields(maprLine)
+ logger.Debug(fields, err)
if err != nil {
logger.Error(err)
@@ -123,7 +139,7 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
}
select {
- case fieldsCh <- fields:
+ case ch <- fields:
case <-ctx.Done():
}
case <-ctx.Done():
@@ -134,7 +150,33 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
}
}()
- return fieldsCh
+ return ch
+}
+
+func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string {
+ ch := make(chan map[string]string)
+
+ go func() {
+ defer close(ch)
+
+ for {
+ // fieldsCh will be closed via 'linesToFields' if ctx is done
+ fields, ok := <-fieldsCh
+ if !ok {
+ return
+ }
+ if err := a.query.SetClause(fields); err != nil {
+ logger.Error(err)
+ }
+
+ select {
+ case ch <- fields:
+ case <-ctx.Done():
+ }
+ }
+ }()
+
+ return ch
}
func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
@@ -192,12 +234,6 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
var addedSample bool
for _, sc := range a.query.Select {
if val, ok := fields[sc.Field]; ok {
- /*
- if sc.Field == "$line" {
- // Complete log line as to arrive untouched on the client side.
- val = base64.StdEncoding.EncodeToString([]byte(val))
- }
- */
if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil {
logger.Error(err)
continue
diff --git a/internal/mapr/setclause.go b/internal/mapr/setclause.go
new file mode 100644
index 0000000..b4c2f73
--- /dev/null
+++ b/internal/mapr/setclause.go
@@ -0,0 +1,20 @@
+package mapr
+
+// SetClause interprets the set clause of the mapreduce query.
+func (q *Query) SetClause(fields map[string]string) error {
+ for _, sc := range q.Set {
+ value, ok := fields[sc.rString]
+ if !ok {
+ continue
+ }
+
+ switch sc.rType {
+ case FunctionStack:
+ fields[sc.lString] = sc.functionStack.Call(value)
+ default:
+ fields[sc.lString] = value
+ }
+ }
+
+ return nil
+}
diff --git a/internal/mapr/setcondition.go b/internal/mapr/setcondition.go
new file mode 100644
index 0000000..8c5cfc9
--- /dev/null
+++ b/internal/mapr/setcondition.go
@@ -0,0 +1,93 @@
+package mapr
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/mapr/funcs"
+)
+
+// Represents a parsed "set" clause, used by mapr.Query
+type setCondition struct {
+ lString string
+
+ rType fieldType
+ rString string
+ rFloat float64
+
+ // For now only text functions are supported.
+ // Maybe in the future we can have typed functions too
+ // so that a float input/output is possible.
+ functionStack funcs.FunctionStack
+}
+
+func (sc *setCondition) String() string {
+ return fmt.Sprintf("setCondition(lString:%s,rString:%s,rType:%s,functionStack:%v)",
+ sc.lString, sc.rString, sc.rType.String(), sc.functionStack)
+}
+
+func makeSetConditions(tokens []token) (set []setCondition, err error) {
+ parse := func(tokens []token) (setCondition, []token, error) {
+ var sc setCondition
+ if len(tokens) < 3 {
+ return sc, nil, errors.New(invalidQuery + "Not enough arguments in 'set' clause")
+ }
+
+ setOp := strings.ToLower(tokens[1].str)
+ switch setOp {
+ case "=":
+ default:
+ return sc, nil, errors.New(invalidQuery + "Unknown operation in 'set' clause: " + setOp)
+ }
+
+ if !tokens[0].isBareword {
+ return sc, nil, errors.New(invalidQuery + "Expected bareword at 'set' clause's lValue: " + tokens[0].str)
+ }
+
+ sc.lString = tokens[0].str
+ if !strings.HasPrefix(sc.lString, "$") {
+ return sc, nil, errors.New(invalidQuery + "Expected field variable name (starting with $) at 'set' clause's lValue: " + tokens[0].str)
+ }
+ sc.rType = Field
+
+ rString := tokens[2].str
+ // Seems like a function call?
+ if strings.HasSuffix(rString, ")") {
+ functionStack, functionArg, err := funcs.NewFunctionStack(tokens[2].str)
+ if err != nil {
+ return sc, nil, err
+ }
+ sc.functionStack = functionStack
+ sc.rType = FunctionStack
+ sc.rString = functionArg
+ return sc, tokens[3:], nil
+ }
+
+ sc.rString = rString
+ if f, err := strconv.ParseFloat(sc.rString, 64); err == nil {
+ sc.rFloat = f
+ sc.rType = Float
+ } else {
+ sc.rType = Field
+ }
+
+ return sc, tokens[3:], nil
+ }
+
+ for len(tokens) > 0 {
+ var sc setCondition
+ var err error
+
+ sc, tokens, err = parse(tokens)
+ if err != nil {
+ return nil, err
+ }
+
+ set = append(set, sc)
+ tokens = tokensConsumeOptional(tokens, ",")
+ }
+
+ return
+}
diff --git a/internal/mapr/token.go b/internal/mapr/token.go
index b8be4da..d337bd2 100644
--- a/internal/mapr/token.go
+++ b/internal/mapr/token.go
@@ -4,7 +4,7 @@ import (
"strings"
)
-var keywords = [...]string{"select", "from", "where", "group", "rorder", "order", "interval", "limit", "outfile"}
+var keywords = [...]string{"select", "from", "where", "set", "group", "rorder", "order", "interval", "limit", "outfile", "logformat"}
// Represents a parsed token, used to parse the mapr query.
type token struct {
diff --git a/internal/mapr/whereclause.go b/internal/mapr/whereclause.go
new file mode 100644
index 0000000..cc1c164
--- /dev/null
+++ b/internal/mapr/whereclause.go
@@ -0,0 +1,77 @@
+package mapr
+
+import (
+ "strconv"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+)
+
+// WhereClause interprets the where clause of the mapreduce query.
+func (q *Query) WhereClause(fields map[string]string) bool {
+ for _, wc := range q.Where {
+ var ok bool
+
+ if wc.Operation > FloatOperation {
+ var lValue, rValue float64
+ if lValue, ok = whereClauseFloatValue(fields, wc.lString, wc.lFloat, wc.lType); !ok {
+ return false
+ }
+ if rValue, ok = whereClauseFloatValue(fields, wc.rString, wc.rFloat, wc.rType); !ok {
+ return false
+ }
+ if ok = wc.floatClause(lValue, rValue); !ok {
+ return false
+ }
+ continue
+ }
+
+ var lValue, rValue string
+ if lValue, ok = whereClauseStringValue(fields, wc.lString, wc.lType); !ok {
+ return false
+ }
+ if rValue, ok = whereClauseStringValue(fields, wc.rString, wc.rType); !ok {
+ return false
+ }
+ if ok = wc.stringClause(lValue, rValue); !ok {
+ return false
+ }
+ }
+
+ return true
+}
+
+func whereClauseFloatValue(fields map[string]string, str string, float float64, t fieldType) (float64, bool) {
+ switch t {
+ case Float:
+ return float, true
+ case Field:
+ value, ok := fields[str]
+ if !ok {
+ return 0, false
+ }
+ f, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return 0, false
+ }
+ return f, true
+ default:
+ logger.Error("Unexpected argument in 'where' clause", str, float, t)
+ return 0, false
+ }
+}
+
+func whereClauseStringValue(fields map[string]string, str string, t fieldType) (string, bool) {
+ switch t {
+ case Field:
+ value, ok := fields[str]
+ if !ok {
+ return str, false
+ }
+ return value, true
+ case String:
+ return str, true
+ default:
+ logger.Error("Unexpected argument in 'where' clause", str, t)
+ return str, false
+ }
+}
diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go
index 3ca9103..ff1b489 100644
--- a/internal/mapr/wherecondition.go
+++ b/internal/mapr/wherecondition.go
@@ -28,40 +28,17 @@ const (
FloatGe QueryOperation = iota
)
-type whereType int
-
-// The possible field types.
-const (
- UndefWhereType whereType = iota
- Field whereType = iota
- String whereType = iota
- Float whereType = iota
-)
-
-func (w whereType) String() string {
- switch w {
- case Field:
- return fmt.Sprintf("Field")
- case String:
- return fmt.Sprintf("String")
- case Float:
- return fmt.Sprintf("Float")
- default:
- return fmt.Sprintf("UndefWhereType")
- }
-}
-
// Represents a parsed "where" clause, used by mapr.Query
type whereCondition struct {
+ lType fieldType
lString string
lFloat float64
- lType whereType
Operation QueryOperation
+ rType fieldType
rString string
rFloat float64
- rType whereType
}
func (wc *whereCondition) String() string {
diff --git a/internal/server/continuous.go b/internal/server/continuous.go
index cf89cdd..f3993a1 100644
--- a/internal/server/continuous.go
+++ b/internal/server/continuous.go
@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
- "os"
"strings"
"time"
@@ -22,71 +21,97 @@ func newContinuous() *continuous {
return &continuous{}
}
-func (s *continuous) start(ctx context.Context) {
- // First run after just 10s!
+func (c *continuous) start(ctx context.Context) {
+ logger.Info("Starting continuous job runner after 10s")
time.Sleep(time.Second * 10)
- s.runJobs(ctx)
- for {
- select {
- case <-time.After(time.Minute):
- s.runJobs(ctx)
- case <-ctx.Done():
- return
- }
- }
+ c.runJobs(ctx)
}
-func (s *continuous) runJobs(ctx context.Context) {
- for _, job := range config.Server.Schedule {
+func (c *continuous) runJobs(ctx context.Context) {
+ for _, job := range config.Server.Continuous {
if !job.Enable {
logger.Debug(job.Name, "Not running job as not enabled")
continue
}
- files := fillDates(job.Files)
- outfile := fillDates(job.Outfile)
+ go func(job config.Continuous) {
+ c.runJob(ctx, job)
+ for {
+ select {
+				// Re-run the job after a minute (runs periodically, regardless of outcome).
+ case <-time.After(time.Minute):
+ c.runJob(ctx, job)
+ case <-ctx.Done():
+ return
+ }
+ }
+ }(job)
+ }
+}
- servers := strings.Join(job.Servers, ",")
- if servers == "" {
- servers = config.Server.SSHBindAddress
- }
+func (c *continuous) runJob(ctx context.Context, job config.Continuous) {
+ logger.Debug(job.Name, "Processing job")
- args := clients.Args{
- ConnectionsPerCPU: 10,
- Discovery: job.Discovery,
- ServersStr: servers,
- What: files,
- Mode: omode.MapClient,
- UserName: config.BackgroundUser,
- }
+ files := fillDates(job.Files)
+ outfile := fillDates(job.Outfile)
- args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name))
+ servers := strings.Join(job.Servers, ",")
+ if servers == "" {
+ servers = config.Server.SSHBindAddress
+ }
- tmpOutfile := fmt.Sprintf("%s.tmp", outfile)
- query := fmt.Sprintf("%s outfile %s", job.Query, tmpOutfile)
+ args := clients.Args{
+ ConnectionsPerCPU: 10,
+ Discovery: job.Discovery,
+ ServersStr: servers,
+ What: files,
+ Mode: omode.TailClient,
+ UserName: config.ContinuousUser,
+ }
- client, err := clients.NewMaprClient(args, query, clients.NonCumulativeMode)
- if err != nil {
- logger.Error(fmt.Sprintf("Unable to create job job %s", job.Name), err)
- continue
- }
+ args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name))
- jobCtx, cancel := context.WithCancel(ctx)
- defer cancel()
+ query := fmt.Sprintf("%s outfile %s", job.Query, outfile)
+ client, err := clients.NewMaprClient(args, query, clients.NonCumulativeMode)
+ if err != nil {
+ logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
+ return
+ }
- logger.Info(fmt.Sprintf("Starting job job %s", job.Name))
- status := client.Start(jobCtx)
- logMessage := fmt.Sprintf("Job exited with status %d", status)
+ jobCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
- if err := os.Rename(tmpOutfile, outfile); err == nil {
- logger.Info(job.Name, fmt.Sprintf("Renamed %s to %s", tmpOutfile, outfile))
- }
+ if job.RestartOnDayChange {
+ go func() {
+ if c.waitForDayChange(ctx) {
+ logger.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name))
+ cancel()
+ }
+ }()
+ }
- if status != 0 {
- logger.Warn(logMessage)
- continue
+ logger.Info(fmt.Sprintf("Starting job %s", job.Name))
+ status := client.Start(jobCtx)
+ logMessage := fmt.Sprintf("Job exited with status %d", status)
+
+ if status != 0 {
+ logger.Warn(logMessage)
+ return
+ }
+ logger.Info(logMessage)
+}
+
+func (c *continuous) waitForDayChange(ctx context.Context) bool {
+ startTime := time.Now()
+ for {
+ select {
+ case <-time.After(time.Second):
+ if time.Now().Day() != startTime.Day() {
+ return true
+ }
+ case <-ctx.Done():
+ return false
}
- logger.Info(logMessage)
}
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 939388c..9b52c85 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -101,7 +101,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
case message := <-h.aggregatedMessages:
// Send mapreduce-aggregated data as a message.
- data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message)
+ data := fmt.Sprintf("AGGREGATE➔%s➔%s\n", h.hostname, message)
wholePayload := []byte(data)
n = copy(p, wholePayload)
return
@@ -192,7 +192,7 @@ func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, err
}
if args[1] != version.ProtocolCompat {
- err := fmt.Errorf("server with protool version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
+ err := fmt.Errorf("server with protocol version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
return args, argc, err
}
diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go
index e75077e..3345d69 100644
--- a/internal/server/scheduler.go
+++ b/internal/server/scheduler.go
@@ -24,6 +24,7 @@ func newScheduler() *scheduler {
}
func (s *scheduler) start(ctx context.Context) {
+ logger.Info("Starting scheduled job runner after 10s")
// First run after just 10s!
time.Sleep(time.Second * 10)
s.runJobs(ctx)
@@ -47,7 +48,7 @@ func (s *scheduler) runJobs(ctx context.Context) {
hour, err := strconv.Atoi(time.Now().Format("15"))
if err != nil {
- logger.Error(job.Name, "Unable to create job job", err)
+ logger.Error(job.Name, "Unable to create job", err)
continue
}
@@ -76,31 +77,25 @@ func (s *scheduler) runJobs(ctx context.Context) {
ServersStr: servers,
What: files,
Mode: omode.MapClient,
- UserName: config.BackgroundUser,
+ UserName: config.ScheduleUser,
}
args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name))
- tmpOutfile := fmt.Sprintf("%s.tmp", outfile)
- query := fmt.Sprintf("%s outfile %s", job.Query, tmpOutfile)
-
+ query := fmt.Sprintf("%s outfile %s", job.Query, outfile)
client, err := clients.NewMaprClient(args, query, clients.CumulativeMode)
if err != nil {
- logger.Error(fmt.Sprintf("Unable to create job job %s", job.Name), err)
+ logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
continue
}
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
- logger.Info(fmt.Sprintf("Starting job job %s", job.Name))
+ logger.Info(fmt.Sprintf("Starting job %s", job.Name))
status := client.Start(jobCtx)
logMessage := fmt.Sprintf("Job exited with status %d", status)
- if err := os.Rename(tmpOutfile, outfile); err == nil {
- logger.Info(job.Name, fmt.Sprintf("Renamed %s to %s", tmpOutfile, outfile))
- }
-
if status != 0 {
logger.Warn(logMessage)
continue
diff --git a/internal/server/server.go b/internal/server/server.go
index 486b8c6..693c48d 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -54,7 +54,7 @@ func New() *Server {
background: background.New(),
}
- s.sshServerConfig.PasswordCallback = s.backgroundUserCallback
+ s.sshServerConfig.PasswordCallback = s.Callback
s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback
private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
@@ -241,44 +241,59 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
return nil
}
-func (s *Server) backgroundUserCallback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
+// Callback for SSH authentication.
+func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
user := user.New(c.User(), c.RemoteAddr().String())
authInfo := string(authPayload)
- if user.Name == config.ControlUser && authInfo == config.ControlUser {
- logger.Debug(user, "Granting permissions to control user")
- return nil, nil
- }
+ splitted := strings.Split(c.RemoteAddr().String(), ":")
+ remoteIP := splitted[0]
- if user.Name == config.BackgroundUser && s.backgroundJobUserCanHaveSSHSession(c.RemoteAddr().String(), user, authInfo) {
- logger.Debug(user, "Granting SSH connection to background user")
- return nil, nil
+ switch user.Name {
+ case config.ControlUser:
+ if authInfo == config.ControlUser {
+ logger.Debug(user, "Granting permissions to control user")
+ return nil, nil
+ }
+ case config.ScheduleUser:
+ for _, job := range config.Server.Schedule {
+ if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) {
+ logger.Debug(user, "Granting SSH connection")
+ return nil, nil
+ }
+ }
+ case config.ContinuousUser:
+ for _, job := range config.Server.Continuous {
+ if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) {
+ logger.Debug(user, "Granting SSH connection")
+ return nil, nil
+ }
+ }
+ default:
}
return nil, fmt.Errorf("user %s not authorized", user)
}
-func (s *Server) backgroundJobUserCanHaveSSHSession(addr string, user *user.User, jobName string) bool {
- logger.Debug("backgroundJobUserCanHaveSSHSession", user, jobName)
- splitted := strings.Split(addr, ":")
- ip := splitted[0]
+func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, allowedJobName string, allowFrom []string) bool {
+ logger.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom)
- for _, job := range config.Server.Schedule {
- if job.Name != jobName {
+ if jobName != allowedJobName {
+ logger.Debug(user, jobName, "backgroundCanSSH", "Job name does not match, skipping to next one...", allowedJobName)
+ return false
+ }
+
+ for _, myAddr := range allowFrom {
+ ips, err := net.LookupIP(myAddr)
+ if err != nil {
+ logger.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP address for allowed hosts lookup, skipping to next one...", myAddr, err)
continue
}
- for _, myAddr := range job.AllowFrom {
- myIPs, err := net.LookupIP(myAddr)
- if err != nil {
- logger.Error(user, myAddr, err)
- continue
- }
- for _, myIP := range myIPs {
- logger.Debug("backgroundJobUserCanHaveSSHSession", "Comparing IP addresses", ip, myIP.String())
- if ip == myIP.String() {
- return true
- }
+ for _, ip := range ips {
+ logger.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", remoteIP, ip.String())
+ if remoteIP == ip.String() {
+ return true
}
}
}
diff --git a/internal/user/server/user.go b/internal/user/server/user.go
index 29158df..c4e8b7b 100644
--- a/internal/user/server/user.go
+++ b/internal/user/server/user.go
@@ -41,7 +41,7 @@ func (u *User) String() string {
func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission bool) {
logger.Debug(u, filePath, permissionType, "Checking config permissions")
- if u.Name == config.BackgroundUser {
+ if u.Name == config.ScheduleUser || u.Name == config.ContinuousUser {
// Background user has same permissions as dtail process itself.
return true
}
diff --git a/internal/version/version.go b/internal/version/version.go
index f5844f9..f97c4bb 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -11,11 +11,11 @@ const (
// Name of DTail.
Name string = "DTail"
// Version of DTail.
- Version string = "2.5.0"
+ Version string = "3.0.0"
// Additional information for DTail
- Additional string = "develop"
+ Additional string = ""
// ProtocolCompat -ibility version.
- ProtocolCompat string = "2"
+ ProtocolCompat string = "3"
)
// String representation of the DTail version.