summaryrefslogtreecommitdiff
path: root/internal/io/run/run.go
blob: 4d57f9fe40bf13830e4879ad8815a8f00b7aa652 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package run

import (
	"bufio"
	"context"
	"io"
	"os/exec"
	"sync"
	"syscall"
	"time"

	"github.com/mimecast/dtail/internal/io/line"
	"github.com/mimecast/dtail/internal/io/logger"
)

// Run executes an external command and streams its output.
//
// Construct it with New: the zero value has a nil pgroupKilled channel and is
// therefore not usable.
type Run struct {
	// command is the executable to run; args are passed to it as-is.
	command string
	args    []string

	// cmd is the process handle. StartBackground assigns it on its own
	// (value-receiver) copy of the Run.
	cmd *exec.Cmd

	// pgroupKilled is closed by killPgroup to signal that the command's
	// process group has been dealt with; StartBackground's waiter blocks on it.
	pgroupKilled chan struct{}
}

// New returns a Run that will execute command with args. The returned value
// has its pgroupKilled channel allocated. Each Run should be started at most
// once, because killPgroup closes that channel.
func New(command string, args []string) Run {
	var r Run
	r.command = command
	r.args = args
	r.pgroupKilled = make(chan struct{})
	return r
}

// StartBackground starts running the command in background.
//
// Output read from the command's stdout and stderr is forwarded to lines by
// pipeToLines. Once the command has exited, its status is sent to ec:
//   - 0 when the command succeeded;
//   - exec.ExitError.ExitCode() when it failed (-1 if it was terminated by a
//     signal, e.g. after ctx was cancelled);
//   - 255 when Wait failed in any other way.
//
// wg.Done is called exactly once on every path. It is called immediately if
// the pipes cannot be created or the command cannot be started. Otherwise it
// is called after the command has exited, its process group has been handled
// by killPgroup, and both pipe readers have returned. The caller presumably
// calls wg.Add(1) beforehand.
//
// It returns the pid of the started process (-1 if none was started) and any
// error from creating the pipes or starting the command.
func (r Run) StartBackground(ctx context.Context, wg *sync.WaitGroup, ec chan<- int, lines chan<- line.Line) (pid int, err error) {
	pid = -1

	if len(r.args) > 0 {
		logger.Debug(r.command, r.args, " ")
	} else {
		logger.Debug(r.command)
	}
	// An empty args slice yields the same exec.Cmd as omitting it.
	r.cmd = exec.CommandContext(ctx, r.command, r.args...)

	// Create a new process group, so that kill() will only kill this command + pgroup.
	r.cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

	stdoutPipe, err := r.cmd.StdoutPipe()
	if err != nil {
		wg.Done()
		return pid, err
	}

	stderrPipe, err := r.cmd.StderrPipe()
	if err != nil {
		wg.Done()
		return pid, err
	}

	if err = r.cmd.Start(); err != nil {
		wg.Done()
		return pid, err
	}

	if r.cmd.Process != nil {
		pid = r.cmd.Process.Pid
	}

	commandExited := make(chan struct{})

	var pipeWg sync.WaitGroup
	pipeWg.Add(2)

	go r.killPgroup(ctx, commandExited, pid)
	go r.pipeToLines(commandExited, &pipeWg, pid, stdoutPipe, "STDOUT", lines)
	go r.pipeToLines(commandExited, &pipeWg, pid, stderrPipe, "STDERR", lines)

	go func() {
		// NOTE(review): os/exec documents that Wait closes the StdoutPipe and
		// StderrPipe readers once the command exits, and that calling Wait
		// before all reads have completed is incorrect. Output not yet read at
		// that point may be lost. Reading both pipes to EOF first would
		// presumably hang while background children keep the pipes open, so
		// the existing ordering is kept.
		exitCode := 0
		if waitErr := r.cmd.Wait(); waitErr != nil {
			// Wait errors other than *exec.ExitError carry no exit code.
			exitCode = 255
			if exitError, ok := waitErr.(*exec.ExitError); ok {
				exitCode = exitError.ExitCode()
			}
		}
		ec <- exitCode

		// Tell pipes we are done
		close(commandExited)
		// Wait for process group to be killed
		<-r.pgroupKilled
		// Wait for the pipes to flush the contents
		pipeWg.Wait()
		// Now the job is truly done
		wg.Done()
	}()

	return pid, nil
}

// pipeToLines reads newline-delimited output from reader and forwards each
// line to lines as a line.Line. Each line is tagged with the command's pid
// (in Count) and the stream name what (in SourceID, e.g. "STDOUT" or
// "STDERR"). Lines keep their trailing '\n', if any. wg.Done is called when
// the function returns.
//
// It returns once commandExited is closed, in either of two situations:
//   - while it is blocked trying to hand a line to a consumer;
//   - after reader has reported an error, typically io.EOF or a read on a
//     pipe that cmd.Wait has closed.
func (r Run) pipeToLines(commandExited chan struct{}, wg *sync.WaitGroup, pid int, reader io.Reader, what string, lines chan<- line.Line) {
	defer wg.Done()
	bufReader := bufio.NewReader(reader)

	for {
		// ReadBytes blocks until a full line (or an error) is available and
		// returns a freshly allocated slice, so it can be handed over as-is.
		content, err := bufReader.ReadBytes('\n')

		// On error ReadBytes still returns the bytes read before it, e.g. a
		// final line without a trailing newline; forward them instead of
		// dropping them.
		if len(content) > 0 {
			newLine := line.Line{
				Content:         content,
				Count:           uint64(pid),
				TransmittedPerc: 100,
				SourceID:        what,
			}

			select {
			case lines <- newLine:
			case <-commandExited:
				return
			}
		}

		if err != nil {
			// Nothing more can be read. Stay alive until the command has
			// exited, then stop.
			<-commandExited
			return
		}
	}
}

// killPgroup waits until ctx is cancelled or commandExited is closed. It then
// sends SIGKILL to the process group of the command with the given pid.
//
// pid is -1 when no process was started; there is nothing to kill then.
//
// r.pgroupKilled is closed on every path. The waiter goroutine in
// StartBackground blocks on it, so leaving it open would leak that goroutine
// and keep its sync.WaitGroup from ever completing.
func (r Run) killPgroup(ctx context.Context, commandExited chan struct{}, pid int) {
	defer close(r.pgroupKilled)

	if pid == -1 {
		return
	}

	// The command is started with Setpgid, so its process group ID equals its
	// pid. Getpgid fails if the group leader has already exited and been
	// reaped by cmd.Wait. Fall back to pid so that any remaining members of
	// the group are still killed.
	pgid, err := syscall.Getpgid(pid)
	if err != nil {
		pgid = pid
	}

	// Kill process group when done
	select {
	case <-ctx.Done():
	case <-commandExited:
	}

	// Best effort: an error (e.g. ESRCH) means no process in the group is left.
	_ = syscall.Kill(-pgid, syscall.SIGKILL)
}