summaryrefslogtreecommitdiff
path: root/internal/mapr
diff options
context:
space:
mode:
Diffstat (limited to 'internal/mapr')
-rw-r--r--internal/mapr/aggregateset.go14
-rw-r--r--internal/mapr/client/aggregate.go10
-rw-r--r--internal/mapr/fieldtypes.go29
-rw-r--r--internal/mapr/funcs/function.go66
-rw-r--r--internal/mapr/funcs/function_test.go45
-rw-r--r--internal/mapr/funcs/maskdigits.go14
-rw-r--r--internal/mapr/funcs/md5sum.go12
-rw-r--r--internal/mapr/logformat/default.go1
-rw-r--r--internal/mapr/logformat/default_test.go2
-rw-r--r--internal/mapr/logformat/parser.go3
-rw-r--r--internal/mapr/query.go103
-rw-r--r--internal/mapr/query_test.go40
-rw-r--r--internal/mapr/server/aggregate.go68
-rw-r--r--internal/mapr/setclause.go20
-rw-r--r--internal/mapr/setcondition.go93
-rw-r--r--internal/mapr/token.go2
-rw-r--r--internal/mapr/whereclause.go77
-rw-r--r--internal/mapr/wherecondition.go27
18 files changed, 480 insertions, 146 deletions
diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go
index d8705bd..a6cc6eb 100644
--- a/internal/mapr/aggregateset.go
+++ b/internal/mapr/aggregateset.go
@@ -2,7 +2,6 @@ package mapr
import (
"context"
- "encoding/base64"
"fmt"
"strconv"
"strings"
@@ -71,25 +70,20 @@ func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<-
var sb strings.Builder
sb.WriteString(groupKey)
- sb.WriteString("|")
- sb.WriteString(fmt.Sprintf("%d|", s.Samples))
+ sb.WriteString("āž”")
+ sb.WriteString(fmt.Sprintf("%dāž”", s.Samples))
for k, v := range s.FValues {
sb.WriteString(k)
sb.WriteString("=")
- sb.WriteString(fmt.Sprintf("%v|", v))
+ sb.WriteString(fmt.Sprintf("%vāž”", v))
}
for k, v := range s.SValues {
sb.WriteString(k)
sb.WriteString("=")
- if k == "$line" {
- sb.WriteString(base64.StdEncoding.EncodeToString([]byte(v)))
- sb.WriteString("|")
- continue
- }
sb.WriteString(v)
- sb.WriteString("|")
+ sb.WriteString("āž”")
}
select {
diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go
index e7fcdc6..10b34d4 100644
--- a/internal/mapr/client/aggregate.go
+++ b/internal/mapr/client/aggregate.go
@@ -1,7 +1,6 @@
package client
import (
- "encoding/base64"
"strconv"
"strings"
@@ -75,15 +74,6 @@ func (a *Aggregate) makeFields(parts []string) map[string]string {
if len(kv) < 2 {
continue
}
- if kv[0] == "$line" {
- decoded, err := base64.StdEncoding.DecodeString(kv[1])
- if err != nil {
- logger.Error("Unable to decode $line", kv[1], err)
- continue
- }
- fields[kv[0]] = string(decoded)
- continue
- }
fields[kv[0]] = kv[1]
}
diff --git a/internal/mapr/fieldtypes.go b/internal/mapr/fieldtypes.go
new file mode 100644
index 0000000..a64efd1
--- /dev/null
+++ b/internal/mapr/fieldtypes.go
@@ -0,0 +1,29 @@
+package mapr
+
+import "fmt"
+
+type fieldType int
+
+// The possible field types.
+const (
+ UndefFieldType fieldType = iota
+ Field fieldType = iota
+ String fieldType = iota
+ Float fieldType = iota
+ FunctionStack fieldType = iota
+)
+
+func (w fieldType) String() string {
+ switch w {
+ case Field:
+ return fmt.Sprintf("Field")
+ case String:
+ return fmt.Sprintf("String")
+ case Float:
+ return fmt.Sprintf("Float")
+ case FunctionStack:
+ return fmt.Sprintf("FunctionStack")
+ default:
+ return fmt.Sprintf("UndefFieldType")
+ }
+}
diff --git a/internal/mapr/funcs/function.go b/internal/mapr/funcs/function.go
new file mode 100644
index 0000000..52aaa98
--- /dev/null
+++ b/internal/mapr/funcs/function.go
@@ -0,0 +1,66 @@
+package funcs
+
+import (
+ "fmt"
+ "strings"
+)
+
+// CallbackFunc is a function which can be executed by the mapreduce engine
+type CallbackFunc func(text string) string
+
+// Function embeddes the function name to the callback function
+type Function struct {
+ // Name of the callback function
+ Name string
+ call CallbackFunc
+}
+
+// FunctionStack is a list of functions stacked each other
+type FunctionStack []Function
+
+// NewFunctionStack parses the input string, e.g. foo(bar("arg")) and returns a corresponding function stack.
+func NewFunctionStack(in string) (FunctionStack, string, error) {
+ var fs FunctionStack
+
+ getCallback := func(name string) (CallbackFunc, error) {
+ var cb CallbackFunc
+
+ switch name {
+ case "md5sum":
+ return Md5Sum, nil
+ case "maskdigits":
+ return MaskDigits, nil
+ default:
+ return cb, fmt.Errorf("unknown function '%s'", name)
+ }
+ }
+
+ aux := in
+ for strings.HasSuffix(aux, ")") {
+ index := strings.Index(aux, "(")
+ if index <= 0 {
+ return fs, "", fmt.Errorf("unable to parse function '%s' at '%s'", in, aux)
+ }
+ name := aux[0:index]
+
+ call, err := getCallback(name)
+ if err != nil {
+ return fs, "", err
+ }
+ fs = append(fs, Function{name, call})
+ aux = aux[index+1 : len(aux)-1]
+ }
+
+ return fs, aux, nil
+}
+
+// Call the function stack.
+func (fs FunctionStack) Call(str string) string {
+ for i := len(fs) - 1; i >= 0; i-- {
+ //logger.Debug("Call", fs[i].Name, str)
+ str = fs[i].call(str)
+ //logger.Debug("Call.result", fs[i].Name, str)
+ }
+
+ return str
+}
diff --git a/internal/mapr/funcs/function_test.go b/internal/mapr/funcs/function_test.go
new file mode 100644
index 0000000..415683c
--- /dev/null
+++ b/internal/mapr/funcs/function_test.go
@@ -0,0 +1,45 @@
+package funcs
+
+import "testing"
+
+func TestFunction(t *testing.T) {
+ input := "md5sum($line)"
+ fs, arg, err := NewFunctionStack(input)
+ if err != nil {
+ t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs)
+ }
+ if arg != "$line" {
+ t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs)
+ }
+ t.Log(input, fs, arg)
+
+ result := fs.Call(input)
+ if result != "b38699013d79e50d9d122433753959c1" {
+ t.Errorf("error executing function stack '%s': expected result 'b38699013d79e50d9d122433753959c1' but got '%s' (%v)\n", input, result, fs)
+ }
+
+ input = "maskdigits(md5sum(maskdigits($line)))"
+ fs, arg, err = NewFunctionStack(input)
+ if err != nil {
+ t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs)
+ }
+ if arg != "$line" {
+ t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs)
+ }
+ t.Log(input, fs, arg)
+
+ result = fs.Call(input)
+ if result != ".fac.bbe..bb.........d...a.c..b." {
+ t.Errorf("error executing function stack '%s': expected result '.fac.bbe..bb.........d...a.c..b.' but got '%s' (%v)\n", input, result, fs)
+ }
+
+ input = "md5sum$line)"
+ if fs, _, err := NewFunctionStack(input); err == nil {
+ t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs)
+ }
+
+ input = "md5sum(makedigits$line))"
+ if fs, _, err := NewFunctionStack(input); err == nil {
+ t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs)
+ }
+}
diff --git a/internal/mapr/funcs/maskdigits.go b/internal/mapr/funcs/maskdigits.go
new file mode 100644
index 0000000..d51f3d8
--- /dev/null
+++ b/internal/mapr/funcs/maskdigits.go
@@ -0,0 +1,14 @@
+package funcs
+
+// MaskDigits masks all digits (replaces them with .)
+func MaskDigits(input string) string {
+ s := []byte(input)
+
+ for i, b := range s {
+ if '0' <= b && b <= '9' {
+ s[i] = '.'
+ }
+ }
+
+ return string(s)
+}
diff --git a/internal/mapr/funcs/md5sum.go b/internal/mapr/funcs/md5sum.go
new file mode 100644
index 0000000..e3cc7e6
--- /dev/null
+++ b/internal/mapr/funcs/md5sum.go
@@ -0,0 +1,12 @@
+package funcs
+
+import (
+ "crypto/md5"
+ "encoding/hex"
+)
+
+// Md5Sum returns the hex encoded MD5 checksum of a given input string.
+func Md5Sum(text string) string {
+ hash := md5.Sum([]byte(text))
+ return hex.EncodeToString(hash[:])
+}
diff --git a/internal/mapr/logformat/default.go b/internal/mapr/logformat/default.go
index 0dfdde0..44bf558 100644
--- a/internal/mapr/logformat/default.go
+++ b/internal/mapr/logformat/default.go
@@ -24,5 +24,6 @@ func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) {
}
fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1]
}
+
return fields, nil
}
diff --git a/internal/mapr/logformat/default_test.go b/internal/mapr/logformat/default_test.go
index a3c47fb..d7a4da4 100644
--- a/internal/mapr/logformat/default_test.go
+++ b/internal/mapr/logformat/default_test.go
@@ -5,7 +5,7 @@ import (
)
func TestDefaultLogFormat(t *testing.T) {
- parser, err := NewParser("default")
+ parser, err := NewParser("default", nil)
if err != nil {
t.Errorf("Unable to create parser: %s", err.Error())
}
diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go
index cc9c268..c53729a 100644
--- a/internal/mapr/logformat/parser.go
+++ b/internal/mapr/logformat/parser.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/mapr"
)
// Parser is used to parse the mapreduce information from the server log files.
@@ -22,7 +23,7 @@ type Parser struct {
}
// NewParser returns a new log parser.
-func NewParser(logFormatName string) (*Parser, error) {
+func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) {
hostname, err := os.Hostname()
if err != nil {
diff --git a/internal/mapr/query.go b/internal/mapr/query.go
index 6dff792..7f6b63c 100644
--- a/internal/mapr/query.go
+++ b/internal/mapr/query.go
@@ -20,6 +20,7 @@ type Query struct {
Select []selectCondition
Table string
Where []whereCondition
+ Set []setCondition
GroupBy []string
OrderBy string
ReverseOrder bool
@@ -29,13 +30,15 @@ type Query struct {
Outfile string
RawQuery string
tokens []token
+ LogFormat string
}
func (q Query) String() string {
- return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,GroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v)",
+ return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%vGroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v,LogFormat:%s)",
q.Select,
q.Table,
q.Where,
+ q.Set,
q.GroupBy,
q.GroupKey,
q.OrderBy,
@@ -44,7 +47,8 @@ func (q Query) String() string {
q.Limit,
q.Outfile,
q.RawQuery,
- q.tokens)
+ q.tokens,
+ q.LogFormat)
}
// NewQuery returns a new mapreduce query.
@@ -68,10 +72,16 @@ func NewQuery(queryStr string) (*Query, error) {
return &q, err
}
+// HasOutfile returns true if query result will be written to a CVS output file.
func (q *Query) HasOutfile() bool {
return q.Outfile != ""
}
+// Has is a helper to determine whether a query contains a substring
+func (q *Query) Has(what string) bool {
+ return strings.Contains(q.RawQuery, what)
+}
+
func (q *Query) parse(tokens []token) error {
var found []token
var err error
@@ -86,14 +96,23 @@ func (q *Query) parse(tokens []token) error {
}
case "from":
tokens, found = tokensConsume(tokens[1:])
- if len(found) > 0 {
- q.Table = strings.ToUpper(found[0].str)
+ if len(found) == 0 {
+ return errors.New(invalidQuery + "expected table name after 'from'")
+ }
+ if len(found) > 1 {
+ return errors.New(invalidQuery + "expected only one table name after 'from'")
}
+ q.Table = strings.ToUpper(found[0].str)
case "where":
tokens, found = tokensConsume(tokens[1:])
if q.Where, err = makeWhereConditions(found); err != nil {
return err
}
+ case "set":
+ tokens, found = tokensConsume(tokens[1:])
+ if q.Set, err = makeSetConditions(found); err != nil {
+ return err
+ }
case "group":
tokens = tokensConsumeOptional(tokens[1:], "by")
if tokens == nil || len(tokens) < 1 {
@@ -147,6 +166,12 @@ func (q *Query) parse(tokens []token) error {
return errors.New(invalidQuery + unexpectedEnd)
}
q.Outfile = found[0].str
+ case "logformat":
+ tokens, found = tokensConsume(tokens[1:])
+ if len(found) == 0 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ q.LogFormat = found[0].str
default:
return errors.New(invalidQuery + "Unexpected keyword " + tokens[0].str)
}
@@ -181,73 +206,3 @@ func (q *Query) parse(tokens []token) error {
return nil
}
-
-// WhereClause interprets the where clause of the mapreduce query.
-func (q *Query) WhereClause(fields map[string]string) bool {
- floatValue := func(str string, float float64, t whereType) (float64, bool) {
- switch t {
- case Float:
- return float, true
- case Field:
- value, ok := fields[str]
- if !ok {
- return 0, false
- }
- f, err := strconv.ParseFloat(value, 64)
- if err != nil {
- return 0, false
- }
- return f, true
- default:
- logger.Error("Unexpected argument in 'where' clause", str, float, t)
- return 0, false
- }
- }
-
- stringValue := func(str string, t whereType) (string, bool) {
- switch t {
- case Field:
- value, ok := fields[str]
- if !ok {
- return str, false
- }
- return value, true
- case String:
- return str, true
- default:
- logger.Error("Unexpected argument in 'where' clause", str, t)
- return str, false
- }
- }
-
- for _, wc := range q.Where {
- var ok bool
-
- if wc.Operation > FloatOperation {
- var lValue, rValue float64
- if lValue, ok = floatValue(wc.lString, wc.lFloat, wc.lType); !ok {
- return false
- }
- if rValue, ok = floatValue(wc.rString, wc.rFloat, wc.rType); !ok {
- return false
- }
- if ok = wc.floatClause(lValue, rValue); !ok {
- return false
- }
- continue
- }
-
- var lValue, rValue string
- if lValue, ok = stringValue(wc.lString, wc.lType); !ok {
- return false
- }
- if rValue, ok = stringValue(wc.rString, wc.rType); !ok {
- return false
- }
- if ok = wc.stringClause(lValue, rValue); !ok {
- return false
- }
- }
-
- return true
-}
diff --git a/internal/mapr/query_test.go b/internal/mapr/query_test.go
index 6176461..b0b6c3a 100644
--- a/internal/mapr/query_test.go
+++ b/internal/mapr/query_test.go
@@ -8,13 +8,13 @@ import (
func TestParseQuerySimple(t *testing.T) {
errorQueries := []string{
"select",
- "select foo",
"select foo from",
"select foo from bar where baz",
"select foo from bar where baz <",
"select foo from bar where baz < 100 bay eq 12 group",
"select foo from bar where baz < 100 bay eq 12 group by foo order by",
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit set foo = bar;",
}
okQueries := []string{"select foo from bar",
"select foo from bar where",
@@ -24,6 +24,7 @@ func TestParseQuerySimple(t *testing.T) {
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo",
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23",
"select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\"",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\" set $foo = maskdigits(bar), $baz = 12, $bay = $foo;",
}
for _, queryStr := range errorQueries {
@@ -45,13 +46,8 @@ func TestParseQuerySimple(t *testing.T) {
func TestParseQueryDeep(t *testing.T) {
dialects := []string{
- "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23",
- "SELECT s1, `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23",
- "select s1, `from` count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23",
- "sElEct s1, `from` coUnt(s3) from taBle where w1 == 2 aNd w2 eq \"free beer\" Group By g1, g2 order bY count(s3) intervaL 10 LiMiT 23",
- "SELECT s1 `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP BY g1 g2 ORDER BY count(s3) INTERVAL 10 LIMIT 23",
- "select s1 `from` count(s3) from table where w1 == 2 w2 eq \"free beer\" group g1 g2 order count(s3) interval 10 limit 23",
- "limit 23 interval 10 order count(s3) group g1 g2 where w1 == 2 w2 eq \"free beer\" from table select s1 `from` count(s3)",
+ "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23 set $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
+ "SELECT s1, `from`, COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23 SET $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
}
for _, queryStr := range dialects {
@@ -60,6 +56,8 @@ func TestParseQueryDeep(t *testing.T) {
t.Errorf("%s: %s", err.Error(), queryStr)
}
+ t.Log(q)
+
// 'select' clause
if len(q.Select) != 3 {
t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", q.Select, queryStr, q)
@@ -145,5 +143,31 @@ func TestParseQueryDeep(t *testing.T) {
if q.Limit != 23 {
t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", q.Limit, queryStr, q)
}
+
+ // 'set' clause
+ if q.Set[0].lString != "$foo" {
+ t.Errorf("Expected '$foo' lvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].lString, queryStr, q)
+ }
+ if q.Set[0].rString != "bar" {
+ t.Errorf("Expected 'bar' rvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].rString, queryStr, q)
+ }
+
+ if q.Set[1].lString != "$baz" {
+ t.Errorf("Expected '$baz' lvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].lString, queryStr, q)
+ }
+ if q.Set[1].rString != "12" {
+ t.Errorf("Expected '12' rvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].rString, queryStr, q)
+ }
+
+ if q.Set[2].lString != "$bay" {
+ t.Errorf("Expected '$bay' lvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].lString, queryStr, q)
+ }
+ if q.Set[2].rString != "$foo" {
+ t.Errorf("Expected '$foo' rvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].rString, queryStr, q)
+ }
+
+ if q.LogFormat != "generic" {
+ t.Errorf("Expected 'generic' logformat got '%v': %s\n%v", q.LogFormat, queryStr, q)
+ }
}
}
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 80a464d..1028943 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -44,15 +44,24 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
}
s := strings.Split(fqdn, ".")
- parserName := config.Server.MapreduceLogFormat
- if query.Table == "" {
- parserName = "generic"
+ var parserName string
+ switch query.LogFormat {
+ case "":
+ parserName = config.Server.MapreduceLogFormat
+ if query.Table == "" {
+ parserName = "generic"
+ }
+ default:
+ parserName = query.LogFormat
}
- logger.Info("Creating mapr log format parser", parserName)
- logParser, err := logformat.NewParser(parserName)
+ logger.Info("Creating log format parser", parserName)
+ logParser, err := logformat.NewParser(parserName, query)
if err != nil {
- logger.FatalExit("Could not create mapr log format parser", err)
+ logger.Error("Could not create log format parser. Falling back to 'generic'", err)
+ if logParser, err = logformat.NewParser("generic", query); err != nil {
+ logger.FatalExit("Could not create log format parser", err)
+ }
}
ctx, cancel := context.WithCancel(context.Background())
@@ -76,6 +85,12 @@ func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
defer a.cancel()
fieldsCh := a.linesToFields(ctx)
+
+ // Add fields (e.g. via 'set' clause)
+ if len(a.query.Set) > 0 {
+ fieldsCh = a.addMoreFields(ctx, fieldsCh)
+ }
+
go a.fieldsToMaprLines(ctx, fieldsCh, maprLines)
a.periodicAggregateTimer(ctx)
}
@@ -99,10 +114,10 @@ func (a *Aggregate) periodicAggregateTimer(ctx context.Context) {
}
func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string {
- fieldsCh := make(chan map[string]string)
+ ch := make(chan map[string]string)
go func() {
- defer close(fieldsCh)
+ defer close(ch)
for {
select {
@@ -113,6 +128,7 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
maprLine := strings.TrimSpace(string(line.Content))
fields, err := a.parser.MakeFields(maprLine)
+ logger.Debug(fields, err)
if err != nil {
logger.Error(err)
@@ -123,7 +139,7 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
}
select {
- case fieldsCh <- fields:
+ case ch <- fields:
case <-ctx.Done():
}
case <-ctx.Done():
@@ -134,7 +150,33 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
}
}()
- return fieldsCh
+ return ch
+}
+
+func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string {
+ ch := make(chan map[string]string)
+
+ go func() {
+ defer close(ch)
+
+ for {
+ // fieldsCh will be closed via 'linesToFields' if ctx is done
+ fields, ok := <-fieldsCh
+ if !ok {
+ return
+ }
+ if err := a.query.SetClause(fields); err != nil {
+ logger.Error(err)
+ }
+
+ select {
+ case ch <- fields:
+ case <-ctx.Done():
+ }
+ }
+ }()
+
+ return ch
}
func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
@@ -192,12 +234,6 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
var addedSample bool
for _, sc := range a.query.Select {
if val, ok := fields[sc.Field]; ok {
- /*
- if sc.Field == "$line" {
- // Complete log line as to arrive untouched on the client side.
- val = base64.StdEncoding.EncodeToString([]byte(val))
- }
- */
if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil {
logger.Error(err)
continue
diff --git a/internal/mapr/setclause.go b/internal/mapr/setclause.go
new file mode 100644
index 0000000..b4c2f73
--- /dev/null
+++ b/internal/mapr/setclause.go
@@ -0,0 +1,20 @@
+package mapr
+
+// SetClause interprets the set clause of the mapreduce query.
+func (q *Query) SetClause(fields map[string]string) error {
+ for _, sc := range q.Set {
+ value, ok := fields[sc.rString]
+ if !ok {
+ continue
+ }
+
+ switch sc.rType {
+ case FunctionStack:
+ fields[sc.lString] = sc.functionStack.Call(value)
+ default:
+ fields[sc.lString] = value
+ }
+ }
+
+ return nil
+}
diff --git a/internal/mapr/setcondition.go b/internal/mapr/setcondition.go
new file mode 100644
index 0000000..8c5cfc9
--- /dev/null
+++ b/internal/mapr/setcondition.go
@@ -0,0 +1,93 @@
+package mapr
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/mapr/funcs"
+)
+
+// Represent a parsed "set" clause, used by mapr.Query
+type setCondition struct {
+ lString string
+
+ rType fieldType
+ rString string
+ rFloat float64
+
+ // For now only text functions are supported.
+ // Maybe in the future we can have typed functions too
+ // so that a float input/output is possible.
+ functionStack funcs.FunctionStack
+}
+
+func (sc *setCondition) String() string {
+ return fmt.Sprintf("setCondition(lString:%s,rString:%s,rType:%s,functionStack:%v)",
+ sc.lString, sc.rString, sc.rType.String(), sc.functionStack)
+}
+
+func makeSetConditions(tokens []token) (set []setCondition, err error) {
+ parse := func(tokens []token) (setCondition, []token, error) {
+ var sc setCondition
+ if len(tokens) < 3 {
+ return sc, nil, errors.New(invalidQuery + "Not enough arguments in 'set' clause")
+ }
+
+ setOp := strings.ToLower(tokens[1].str)
+ switch setOp {
+ case "=":
+ default:
+ return sc, nil, errors.New(invalidQuery + "Unknown operation in 'set' clause: " + setOp)
+ }
+
+ if !tokens[0].isBareword {
+ return sc, nil, errors.New(invalidQuery + "Expected bareword at 'set' clause's lValue: " + tokens[0].str)
+ }
+
+ sc.lString = tokens[0].str
+ if !strings.HasPrefix(sc.lString, "$") {
+ return sc, nil, errors.New(invalidQuery + "Expected field variable name (starting with $) at 'set' clause's lValue: " + tokens[0].str)
+ }
+ sc.rType = Field
+
+ rString := tokens[2].str
+ // Seems like a function call?
+ if strings.HasSuffix(rString, ")") {
+ functionStack, functionArg, err := funcs.NewFunctionStack(tokens[2].str)
+ if err != nil {
+ return sc, nil, err
+ }
+ sc.functionStack = functionStack
+ sc.rType = FunctionStack
+ sc.rString = functionArg
+ return sc, tokens[3:], nil
+ }
+
+ sc.rString = rString
+ if f, err := strconv.ParseFloat(sc.rString, 64); err == nil {
+ sc.rFloat = f
+ sc.rType = Float
+ } else {
+ sc.rType = Field
+ }
+
+ return sc, tokens[3:], nil
+ }
+
+ for len(tokens) > 0 {
+ var sc setCondition
+ var err error
+
+ sc, tokens, err = parse(tokens)
+ if err != nil {
+ return nil, err
+ }
+
+ set = append(set, sc)
+ tokens = tokensConsumeOptional(tokens, ",")
+ }
+
+ return
+}
diff --git a/internal/mapr/token.go b/internal/mapr/token.go
index b8be4da..d337bd2 100644
--- a/internal/mapr/token.go
+++ b/internal/mapr/token.go
@@ -4,7 +4,7 @@ import (
"strings"
)
-var keywords = [...]string{"select", "from", "where", "group", "rorder", "order", "interval", "limit", "outfile"}
+var keywords = [...]string{"select", "from", "where", "set", "group", "rorder", "order", "interval", "limit", "outfile", "logformat"}
// Represents a parsed token, used to parse the mapr query.
type token struct {
diff --git a/internal/mapr/whereclause.go b/internal/mapr/whereclause.go
new file mode 100644
index 0000000..cc1c164
--- /dev/null
+++ b/internal/mapr/whereclause.go
@@ -0,0 +1,77 @@
+package mapr
+
+import (
+ "strconv"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+)
+
+// WhereClause interprets the where clause of the mapreduce query.
+func (q *Query) WhereClause(fields map[string]string) bool {
+ for _, wc := range q.Where {
+ var ok bool
+
+ if wc.Operation > FloatOperation {
+ var lValue, rValue float64
+ if lValue, ok = whereClauseFloatValue(fields, wc.lString, wc.lFloat, wc.lType); !ok {
+ return false
+ }
+ if rValue, ok = whereClauseFloatValue(fields, wc.rString, wc.rFloat, wc.rType); !ok {
+ return false
+ }
+ if ok = wc.floatClause(lValue, rValue); !ok {
+ return false
+ }
+ continue
+ }
+
+ var lValue, rValue string
+ if lValue, ok = whereClauseStringValue(fields, wc.lString, wc.lType); !ok {
+ return false
+ }
+ if rValue, ok = whereClauseStringValue(fields, wc.rString, wc.rType); !ok {
+ return false
+ }
+ if ok = wc.stringClause(lValue, rValue); !ok {
+ return false
+ }
+ }
+
+ return true
+}
+
+func whereClauseFloatValue(fields map[string]string, str string, float float64, t fieldType) (float64, bool) {
+ switch t {
+ case Float:
+ return float, true
+ case Field:
+ value, ok := fields[str]
+ if !ok {
+ return 0, false
+ }
+ f, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return 0, false
+ }
+ return f, true
+ default:
+ logger.Error("Unexpected argument in 'where' clause", str, float, t)
+ return 0, false
+ }
+}
+
+func whereClauseStringValue(fields map[string]string, str string, t fieldType) (string, bool) {
+ switch t {
+ case Field:
+ value, ok := fields[str]
+ if !ok {
+ return str, false
+ }
+ return value, true
+ case String:
+ return str, true
+ default:
+ logger.Error("Unexpected argument in 'where' clause", str, t)
+ return str, false
+ }
+}
diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go
index 3ca9103..ff1b489 100644
--- a/internal/mapr/wherecondition.go
+++ b/internal/mapr/wherecondition.go
@@ -28,40 +28,17 @@ const (
FloatGe QueryOperation = iota
)
-type whereType int
-
-// The possible field types.
-const (
- UndefWhereType whereType = iota
- Field whereType = iota
- String whereType = iota
- Float whereType = iota
-)
-
-func (w whereType) String() string {
- switch w {
- case Field:
- return fmt.Sprintf("Field")
- case String:
- return fmt.Sprintf("String")
- case Float:
- return fmt.Sprintf("Float")
- default:
- return fmt.Sprintf("UndefWhereType")
- }
-}
-
// Represent a parsed "where" clause, used by mapr.Query
type whereCondition struct {
+ lType fieldType
lString string
lFloat float64
- lType whereType
Operation QueryOperation
+ rType fieldType
rString string
rFloat float64
- rType whereType
}
func (wc *whereCondition) String() string {