diff options
87 files changed, 7746 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..79c20cf --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*_proprietary.go +cache/ +log/ diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..fa046e4 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,81 @@ +Code of Conduct +=============== + +Our Pledge +---------- +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of experience, +nationality, personal appearance, race, religion, or sexual identity and +orientation. + + +Our Standards +------------- +Examples of behaviour that contributes to creating a positive environment +include: + +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members + +Examples of unacceptable behaviour by participants include: + +- The use of sexualized language or imagery and unwelcome sexual attention or +advances +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or electronic + address, without explicit permission +- Other conduct which could reasonably be considered inappropriate in a + professional setting + + +Our Responsibilities +-------------------- +Project maintainers are responsible for clarifying the standards of acceptable +behaviour and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behaviour. 
+ +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviours that they deem inappropriate, +threatening, offensive, or harmful. + + +Scope +----- +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + + +Enforcement +----------- +Instances of abusive, harassing, or otherwise unacceptable behaviour may be +reported by contacting the project team on our [mailing list][mailinglist]. +All complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. 
- When you feel that a certain code quality rule is not applicable, make sure to limit your +warning suppression as strictly as possible so as not to suppress other rules that should apply.
+ + +Cosmetic changes +---------------- +- Changes that are cosmetic in nature and do not add anything substantial to the stability, +functionality, or testability will generally not be accepted. + + +New features +------------ + +- Suggest your change(s) to our [mailing list][mailinglist] before writing code. +This will allow us to ensure we do not have a race condition with other contributors. + +- Do not open an issue on GitHub until you have collected positive feedback about the change. +GitHub issues are primarily intended for bug reports and fixes. + + +Questions +--------- + +- Email any question to our [mailing list][mailinglist]. +We will endeavour to answer, but please excuse us if we don't. +The support for this project is dependent on the availability of spare time for our staff. + + +Documentation +------------- + +- DTail's code is documented to a large extent and additional usage documentation is provided +in this project's [doc/](doc/) directory. + +- If you feel that certain areas are lacking and wish to contribute please follow the +**writing a patch** instructions. + + +Thank you +--------- + +Thank you for showing interest in DTail! + +Mimecast Team + +[githubissues]: https://github.com/mimecast/dtail/issues +[githubnewissue]: https://github.com/mimecast/dtail/issues/new +[mailinglist]: mailto:opensource@mimecast.com @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
+The DTail binary operates in either client or server mode. The DTail server must be installed on all server boxes involved.
+ +Installation and Usage +====================== + +* For simplest setup please follow the [Quick Starting Guide](doc/quickstart.md). +* For a more sustainable setup please follow the [Installation Guide](doc/installation.md). +* Please also have a look at the [Usage Examples](doc/examples.md). + +More +==== + +* [How to contribute](CONTRIBUTING.md) +* [Code of conduct](CODE_OF_CONDUCT.md) +* [License](LICENSE) + +Credits +======= + +* DTail was created by Paul Buetow. + +* Thank you to Vlad-Marian Marian for creating the DTail logo. diff --git a/clients/args.go b/clients/args.go new file mode 100644 index 0000000..4d5a029 --- /dev/null +++ b/clients/args.go @@ -0,0 +1,26 @@ +package clients + +import ( + "dtail/omode" +) + +// Args is a helper struct to summarize common client arguments. +type Args struct { + // The operating mode (tail, grep, ...) + Mode omode.Mode + // The raw server string + ServersStr string + // SSH user name (e.g. 'pbuetow') + UserName string + // The files to follow. + Files string + // Regex for filtering. + Regex string + // Trust all unknown host keys? + TrustAllHosts bool + // Server discovery method + Discovery string + MaxInitConnections int + // Server ping timeout (0 means pings disabled) + PingTimeout int +} diff --git a/clients/baseclient.go b/clients/baseclient.go new file mode 100644 index 0000000..3a1b8f0 --- /dev/null +++ b/clients/baseclient.go @@ -0,0 +1,139 @@ +package clients + +import ( + "dtail/clients/remote" + "dtail/discovery" + "dtail/logger" + "dtail/omode" + "dtail/ssh/client" + "regexp" + "sync" + "time" + + gossh "golang.org/x/crypto/ssh" +) + +// This is the main client data structure. +type baseClient struct { + Args + // To display client side stats + stats *stats + // List of remote servers to connect to. + servers []string + // We have one connection per remote server. + connections []*remote.Connection + // SSH auth methods to use to connect to the remote servers. 
+ sshAuthMethods []gossh.AuthMethod + // To deal with SSH host keys + hostKeyCallback *client.HostKeyCallback + // To stop the client. + stop chan struct{} + // To indicate that the client has stopped. + stopped chan struct{} + // Throttle how fast we initiate SSH connections concurrently + throttleCh chan struct{} + // Retry connection upon failure? + retry bool + // Connection helper. + maker connectionMaker +} + +func (c *baseClient) init(maker connectionMaker) { + logger.Info("Initiating base client") + + c.maker = maker + //c.connections = make(map[string]*remote.Connection) + c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh) + + // Retrieve a shuffled list of remote dtail servers. + shuffleServers := true + discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers) + for _, server := range discoveryService.ServerList() { + c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + } + + if _, err := regexp.Compile(c.Regex); err != nil { + logger.FatalExit(c.Regex, "Can't test compile regex", err) + } + + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(c.stop) + + // Periodically print out connection stats to the client. 
+ c.stats = newTailStats(len(c.connections)) + go c.stats.periodicLogStats(c.throttleCh, c.stop) +} + +func (c *baseClient) Start(wg *sync.WaitGroup) (status int) { + if wg != nil { + defer wg.Done() + } + active := make(chan struct{}, len(c.connections)) + + var wg2 sync.WaitGroup + wg2.Add(len(c.connections)) + + for i, conn := range c.connections { + go func(i int, conn *remote.Connection) { + active <- struct{}{} + defer func() { + logger.Debug(conn.Server, "Disconnected completely...") + <-active + }() + wg2.Done() + + for { + conn.Start(c.throttleCh, c.stats.connectionsEstCh) + if !c.retry { + return + } + time.Sleep(time.Second * 2) + logger.Debug(conn.Server, "Reconencting") + conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + c.connections[i] = conn + } + }(i, conn) + } + + wg2.Wait() + c.waitUntilDone(active) + + return +} + +func (c *baseClient) waitUntilDone(active chan struct{}) { + defer close(c.stopped) + + if c.Mode != omode.TailClient { + c.waitUntilZero(active) + logger.Info("All connections stopped") + return + } + + <-c.stop + logger.Info("Stopping client") + for _, conn := range c.connections { + conn.Stop() + } + + c.waitUntilZero(active) +} + +func (c *baseClient) waitUntilZero(active chan struct{}) { + for { + logger.Debug("Active connections", len(active)) + if len(active) == 0 { + return + } + time.Sleep(time.Second) + } +} + +func (c *baseClient) Stop() { + close(c.stop) + <-c.WaitC() +} + +func (c *baseClient) WaitC() <-chan struct{} { + return c.stopped +} diff --git a/clients/catclient.go b/clients/catclient.go new file mode 100644 index 0000000..e3b873c --- /dev/null +++ b/clients/catclient.go @@ -0,0 +1,49 @@ +package clients + +import ( + "dtail/clients/handlers" + "dtail/clients/remote" + "dtail/ssh/client" + "errors" + "fmt" + "strings" + + gossh "golang.org/x/crypto/ssh" +) + +// CatClient is a client for returning a whole file from the beginning to the end. 
+type CatClient struct { + baseClient +} + +// NewCatClient returns a new cat client. +func NewCatClient(args Args) (*CatClient, error) { + if args.Regex != "" { + return nil, errors.New("Can't use regex with 'cat' operating mode") + } + + args.Regex = "." + + c := CatClient{ + baseClient: baseClient{ + Args: args, + stop: make(chan struct{}), + stopped: make(chan struct{}), + throttleCh: make(chan struct{}, args.MaxInitConnections), + retry: false, + }, + } + + c.init(c) + + return &c, nil +} + +func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) + for _, file := range strings.Split(c.Files, ",") { + conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) + } + return conn +} diff --git a/clients/client.go b/clients/client.go new file mode 100644 index 0000000..e58f51d --- /dev/null +++ b/clients/client.go @@ -0,0 +1,9 @@ +package clients + +import "sync" + +// Client is the interface for the end user command line client. 
+type Client interface { + Start(wg *sync.WaitGroup) int + Stop() +} diff --git a/clients/connectionmaker.go b/clients/connectionmaker.go new file mode 100644 index 0000000..9e08c2b --- /dev/null +++ b/clients/connectionmaker.go @@ -0,0 +1,12 @@ +package clients + +import ( + "dtail/clients/remote" + "dtail/ssh/client" + + gossh "golang.org/x/crypto/ssh" +) + +type connectionMaker interface { + makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection +} diff --git a/clients/grepclient.go b/clients/grepclient.go new file mode 100644 index 0000000..dbae96c --- /dev/null +++ b/clients/grepclient.go @@ -0,0 +1,49 @@ +package clients + +import ( + "dtail/clients/handlers" + "dtail/clients/remote" + "dtail/ssh/client" + "errors" + "fmt" + "strings" + + gossh "golang.org/x/crypto/ssh" +) + +// GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed. +type GrepClient struct { + baseClient +} + +// NewGrepClient creates a new grep client. 
+func NewGrepClient(args Args) (*GrepClient, error) { + if args.Regex == "" { + return nil, errors.New("No regex specified, use '-regex' flag") + } + + c := GrepClient{ + baseClient: baseClient{ + Args: args, + stop: make(chan struct{}), + stopped: make(chan struct{}), + throttleCh: make(chan struct{}, args.MaxInitConnections), + retry: false, + }, + } + + c.init(c) + + return &c, nil +} + +func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) + + for _, file := range strings.Split(c.Files, ",") { + conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) + } + + return conn +} diff --git a/clients/handlers/basehandler.go b/clients/handlers/basehandler.go new file mode 100644 index 0000000..ce82aa2 --- /dev/null +++ b/clients/handlers/basehandler.go @@ -0,0 +1,134 @@ +package handlers + +import ( + "dtail/logger" + "errors" + "fmt" + "io" + "strings" + "time" +) + +type baseHandler struct { + server string + shellStarted bool + commands chan string + pong chan struct{} + receiveBuf []byte + stop chan struct{} + pingTimeout int +} + +func (h *baseHandler) Server() string { + return h.server +} + +// Used to determine whether server is still responding to requests or not. 
+	// Hidden server commands start with a dot "."
+ if logger.Mode == logger.SilentMode { + if h.receiveBuf[0] == 'R' { + logger.Raw(message) + } + h.receiveBuf = h.receiveBuf[:0] + return + } + logger.Raw(message) + h.receiveBuf = h.receiveBuf[:0] +} + +// Handle messages received from server which are not meant to be displayed +// to the end user. +func (h *baseHandler) handleHiddenMessage(message string) { + switch { + case strings.HasPrefix(message, ".pong"): + h.pong <- struct{}{} + case strings.HasPrefix(message, ".syn close connection"): + h.SendCommand("ack close connection") + } +} + +// Stop the handler. +func (h *baseHandler) Stop() { + select { + case <-h.stop: + default: + logger.Debug("Stopping base handler", h.server) + close(h.stop) + } +} diff --git a/clients/handlers/clienthandler.go b/clients/handlers/clienthandler.go new file mode 100644 index 0000000..e818b52 --- /dev/null +++ b/clients/handlers/clienthandler.go @@ -0,0 +1,26 @@ +package handlers + +import ( + "dtail/logger" +) + +// ClientHandler is the basic client handler interface. +type ClientHandler struct { + baseHandler +} + +// NewClientHandler creates a new client handler. +func NewClientHandler(server string, pingTimeout int) *ClientHandler { + logger.Debug(server, "Creating new client handler") + + return &ClientHandler{ + baseHandler{ + server: server, + shellStarted: false, + commands: make(chan string), + pong: make(chan struct{}, 1), + stop: make(chan struct{}), + pingTimeout: pingTimeout, + }, + } +} diff --git a/clients/handlers/handler.go b/clients/handlers/handler.go new file mode 100644 index 0000000..2013be0 --- /dev/null +++ b/clients/handlers/handler.go @@ -0,0 +1,12 @@ +package handlers + +import "io" + +// Handler provides all methods which can be run on any client handler. 
+type Handler interface { + io.ReadWriter + Ping() error + Stop() + SendCommand(command string) error + Server() string +} diff --git a/clients/handlers/healthhandler.go b/clients/handlers/healthhandler.go new file mode 100644 index 0000000..4051e2c --- /dev/null +++ b/clients/handlers/healthhandler.go @@ -0,0 +1,75 @@ +package handlers + +import ( + "errors" + "fmt" + "time" +) + +// HealthHandler implements the handler required for health checks. +type HealthHandler struct { + // Buffer of incoming data from server. + receiveBuf []byte + // To send commands to the server. + commands chan string + // To receive messages from the server. + receive chan<- string + // The remote server address + server string +} + +// NewHealthHandler returns a new health check handler. +func NewHealthHandler(server string, receive chan<- string) *HealthHandler { + h := HealthHandler{ + server: server, + receive: receive, + commands: make(chan string), + } + + return &h +} + +// Server returns the remote server name. +func (h *HealthHandler) Server() string { + return h.server +} + +// Stop is not of use for health check handler. +func (h *HealthHandler) Stop() { + // Nothing done here. +} + +// Ping is not of use for health check handler. +func (h *HealthHandler) Ping() error { + return nil +} + +// SendCommand send a DTail command to the server. +func (h *HealthHandler) SendCommand(command string) error { + select { + case h.commands <- fmt.Sprintf("%s;", command): + case <-time.NewTimer(time.Second * 10).C: + return errors.New("Timed out sending command " + command) + } + + return nil +} + +// Server writes byte stream to client. +func (h *HealthHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + h.receiveBuf = append(h.receiveBuf, b) + if b == '\n' { + h.receive <- string(h.receiveBuf) + h.receiveBuf = h.receiveBuf[:0] + } + } + + return len(p), nil +} + +// Server reads byte stream from client. 
+func (h *HealthHandler) Read(p []byte) (n int, err error) { + n = copy(p, []byte(<-h.commands)) + return +} diff --git a/clients/handlers/maprhandler.go b/clients/handlers/maprhandler.go new file mode 100644 index 0000000..830a142 --- /dev/null +++ b/clients/handlers/maprhandler.go @@ -0,0 +1,74 @@ +package handlers + +import ( + "dtail/logger" + "dtail/mapr" + "dtail/mapr/client" + "strings" +) + +// MaprHandler is the handler used on the client side for running mapreduce aggregations. +type MaprHandler struct { + baseHandler + aggregate *client.Aggregate + query *mapr.Query + count uint64 +} + +// NewMaprHandler returns a new mapreduce client handler. +func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler { + return &MaprHandler{ + baseHandler: baseHandler{ + server: server, + shellStarted: false, + commands: make(chan string), + pong: make(chan struct{}, 1), + stop: make(chan struct{}), + pingTimeout: pingTimeout, + }, + query: query, + aggregate: client.NewAggregate(server, query, globalGroup), + } +} + +// Read data from the dtail server via Writer interface. +func (h *MaprHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + h.baseHandler.receiveBuf = append(h.baseHandler.receiveBuf, b) + if b == '\n' { + if len(h.baseHandler.receiveBuf) == 0 { + continue + } + message := string(h.baseHandler.receiveBuf) + + if h.baseHandler.receiveBuf[0] == 'A' { + h.handleAggregateMessage(strings.TrimSpace(message)) + h.baseHandler.receiveBuf = h.baseHandler.receiveBuf[:0] + continue + } + h.baseHandler.handleMessageType(message) + } + } + + return len(p), nil +} + +// Handle a message received from server including mapr aggregation +// related data. +func (h *MaprHandler) handleAggregateMessage(message string) { + h.count++ + parts := strings.Split(message, "|") + + // Index 0 contains 'AGGREGATE', 1 contains server host. + // Aggregation data begins from index 2. 
+	logger.Debug("Received aggregate data", h.server, h.count)
+	h.aggregate.Aggregate(parts[2:])
+	logger.Debug("Aggregated aggregate data", h.server, h.count)
+}
+
+// Stop stops the mapreduce client handler.
+func (h *MaprHandler) Stop() {
+	logger.Debug("Stopping mapreduce handler", h.server)
+	h.aggregate.Stop()
+	h.baseHandler.Stop()
+}
diff --git a/clients/healthclient.go b/clients/healthclient.go
new file mode 100644
index 0000000..1fae99c
--- /dev/null
+++ b/clients/healthclient.go
@@ -0,0 +1,96 @@
+package clients
+
+import (
+	"dtail/clients/handlers"
+	"dtail/clients/remote"
+	"dtail/config"
+	"dtail/omode"
+	"fmt"
+	"runtime"
+	"strings"
+	"sync"
+	"time"
+
+	gossh "golang.org/x/crypto/ssh"
+)
+
+// HealthClient is used for health checking (e.g. via Nagios)
+type HealthClient struct {
+	// Client operating mode
+	mode omode.Mode
+	// The remote server address
+	server string
+	// SSH user name
+	userName string
+	// SSH auth methods to use to connect to the remote servers.
+	sshAuthMethods []gossh.AuthMethod
+}
+
+// NewHealthClient returns a new health client.
+func NewHealthClient(mode omode.Mode) (*HealthClient, error) {
+	c := HealthClient{
+		mode:     mode,
+		server:   fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort),
+		userName: config.ControlUser,
+	}
+	c.initSSHAuthMethods()
+
+	return &c, nil
+}
+
+// Start the health client.
+func (c *HealthClient) Start(wg *sync.WaitGroup) (status int) {
+	defer wg.Done()
+	receive := make(chan string)
+
+	throttleCh := make(chan struct{}, runtime.NumCPU())
+	statsCh := make(chan struct{}, 1)
+
+	conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods)
+	conn.Handler = handlers.NewHealthHandler(c.server, receive)
+	conn.Commands = []string{c.mode.String()}
+
+	go conn.Start(throttleCh, statsCh)
+	defer conn.Stop()
+
+	for {
+		select {
+		case data := <-receive:
+			// Parse received data.
+ s := strings.Split(data, "|") + message := s[len(s)-1] + if strings.HasPrefix(message, "done;") { + return + } + + // Set severity. + s = strings.Split(message, ":") + switch s[0] { + case "OK": + case "WARNING": + if status < 1 { + status = 1 + } + case "CRITICAL": + status = 2 + case "UNKNOWN": + status = 3 + default: + fmt.Printf("CRITICAL: Unexpected server response: '%s'\n", message) + status = 2 + return + } + fmt.Print(message) + + case <-time.After(time.Second * 2): + status = 2 + fmt.Println("CRITICAL: Could not communicate with DTail server") + return + } + } +} + +// Initialize SSH auth methods. +func (c *HealthClient) initSSHAuthMethods() { + c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.ControlUser)) +} diff --git a/clients/maprclient.go b/clients/maprclient.go new file mode 100644 index 0000000..ad707c9 --- /dev/null +++ b/clients/maprclient.go @@ -0,0 +1,153 @@ +package clients + +import ( + "dtail/clients/handlers" + "dtail/clients/remote" + "dtail/logger" + "dtail/mapr" + "dtail/omode" + "dtail/ssh/client" + "errors" + "fmt" + "strings" + "sync" + "time" + + gossh "golang.org/x/crypto/ssh" +) + +// MaprClient is used for running mapreduce aggregations on remote files. +type MaprClient struct { + baseClient + // Query string for mapr aggregations + queryStr string + // Global group set for merged mapr aggregation results + globalGroup *mapr.GlobalGroupSet + // The query object (constructed from queryStr) + query *mapr.Query + // Additative result or new result every run? + additative bool +} + +// NewMaprClient returns a new mapreduce client. 
+// NewMaprClient returns a new mapreduce client. It parses the raw query
+// string, derives the server-side filter regex from the query's table name,
+// and wires up the shared global group set used to merge results from all
+// connected servers. Returns an error when no query string was provided.
+func NewMaprClient(args Args, queryStr string) (*MaprClient, error) {
+	if queryStr == "" {
+		return nil, errors.New("No mapreduce query specified, use '-query' flag")
+	}
+
+	c := MaprClient{
+		baseClient: baseClient{
+			Args:       args,
+			stop:       make(chan struct{}),
+			stopped:    make(chan struct{}),
+			throttleCh: make(chan struct{}, args.MaxInitConnections),
+			retry:      args.Mode == omode.TailClient,
+		},
+		queryStr:   queryStr,
+		additative: args.Mode == omode.MapClient,
+	}
+
+	query, err := mapr.NewQuery(c.queryStr)
+	if err != nil {
+		logger.FatalExit(c.queryStr, "Can't parse mapr query", err)
+	}
+
+	c.query = query
+
+	switch c.query.Table {
+	case "*":
+		// Match any MAPREDUCE-tagged log line regardless of table name.
+		// Plain string assignment; fmt.Sprintf without verbs is flagged
+		// by "go vet" as a constant-string format call.
+		c.Regex = "\\|MAPREDUCE:\\|"
+	case ".":
+		// "." means: match every log line.
+		c.Regex = "."
+	default:
+		c.Regex = fmt.Sprintf("\\|MAPREDUCE:%s\\|", c.query.Table)
+	}
+
+	c.globalGroup = mapr.NewGlobalGroupSet()
+	c.baseClient.init(c)
+
+	return &c, nil
+}
+
+// makeConnection creates a server connection preloaded with the mapreduce
+// handler and the DTail commands to run remotely: one "map" command for the
+// query plus one "tail" (live) or "cat" (historic, additative mode) command
+// per requested file.
+func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+	conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+	conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout)
+
+	conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery))
+	commandStr := "tail"
+	if c.additative {
+		commandStr = "cat"
+	}
+
+	for _, file := range strings.Split(c.Files, ",") {
+		conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex))
+	}
+
+	return conn
+}
+
+// Start starts the mapreduce client.
+func (c *MaprClient) Start(wg *sync.WaitGroup) (status int) { + defer wg.Done() + + if c.query.Outfile == "" { + // Only print out periodic results if we don't write an outfile + go c.periodicPrintResults() + } + + status = c.baseClient.Start(nil) + if c.additative { + c.recievedFinalResult() + } + c.baseClient.Stop() + + return +} + +func (c *MaprClient) recievedFinalResult() { + logger.Info("Received final mapreduce result") + + if c.query.Outfile == "" { + c.printResults() + return + } + + logger.Info(fmt.Sprintf("Writing final mapreduce result to '%s'", c.query.Outfile)) + err := c.globalGroup.WriteResult(c.query) + if err != nil { + logger.FatalExit(err) + return + } + logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile)) +} + +func (c *MaprClient) periodicPrintResults() { + for { + select { + case <-time.After(c.query.Interval): + logger.Info("Gathering interim mapreduce result") + c.printResults() + case <-c.baseClient.stop: + return + } + } +} + +func (c *MaprClient) printResults() { + var result string + var err error + var numLines int + + if c.additative { + result, numLines, err = c.globalGroup.Result(c.query) + } else { + result, numLines, err = c.globalGroup.SwapOut().Result(c.query) + } + if err != nil { + logger.FatalExit(err) + } + if numLines > 0 { + logger.Raw(fmt.Sprintf("%s\n", c.query.RawQuery)) + logger.Raw(result) + } +} diff --git a/clients/remote/connection.go b/clients/remote/connection.go new file mode 100644 index 0000000..bd93239 --- /dev/null +++ b/clients/remote/connection.go @@ -0,0 +1,230 @@ +package remote + +import ( + "dtail/clients/handlers" + "dtail/config" + "dtail/logger" + "dtail/ssh/client" + "fmt" + "io" + "strconv" + "strings" + "time" + + "golang.org/x/crypto/ssh" +) + +// Connection represents a client connection connection to a single server. +type Connection struct { + // The remote server's hostname connected to. + Server string + // The remote server's port connected to. 
+ port int + // The SSH client configuration used. + config *ssh.ClientConfig + // The SSH client handler to use. + Handler handlers.Handler + // DTail commands sent from client to server. When client loses + // connection to the server it re-connects automatically and sends the + // same commands again. + Commands []string + // Is it a persistent connection or a one-off? + isOneOff bool + // Used to stop the connection + stop chan struct{} + // To deal with SSH server host keys + hostKeyCallback *client.HostKeyCallback +} + +// NewConnection returns a new connection. +func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *Connection { + logger.Debug(server, "Creating new connection") + + c := Connection{ + hostKeyCallback: hostKeyCallback, + config: &ssh.ClientConfig{ + User: userName, + Auth: authMethods, + HostKeyCallback: hostKeyCallback.Wrap(), + Timeout: time.Second * 3, + }, + stop: make(chan struct{}), + } + + c.initServerPort(server) + + return &c +} + +// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit). +func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection { + c := Connection{ + config: &ssh.ClientConfig{ + User: userName, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }, + stop: make(chan struct{}), + isOneOff: true, + } + + c.initServerPort(server) + + return &c +} + +// Attempt to parse the server port address from the provided server FQDN. 
+func (c *Connection) initServerPort(server string) {
+	c.Server = server
+	c.port = config.Common.SSHPort
+	parts := strings.Split(server, ":")
+
+	if len(parts) == 2 {
+		logger.Debug("Parsing port from hostname", parts)
+		port, err := strconv.Atoi(parts[1])
+		if err != nil {
+			logger.FatalExit("Unable to parse client port", server, parts, err)
+		}
+		c.Server = parts[0]
+		c.port = port
+	}
+}
+
+// Start the server connection. Build up SSH session and send some DTail commands.
+func (c *Connection) Start(throttleCh, statsCh chan struct{}) {
+	select {
+	case <-c.stop:
+		logger.Info(c.Server, c.port, "Disconnecting client")
+		return
+	default:
+	}
+
+	// Wait for SSH connection throttler
+	throttleCh <- struct{}{}
+
+	// Wait until connection has been initiated or an error occurred
+	// during initialization.
+	throttleStopCh := make(chan struct{}, 2)
+	go func() {
+		<-throttleStopCh
+		<-throttleCh
+	}()
+
+	if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil {
+		logger.Warn(c.Server, c.port, err)
+		throttleStopCh <- struct{}{}
+
+		if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
+			logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port)
+			return
+		}
+	}
+}
+
+// Dial into a new SSH connection. Close connection in case of an error.
+func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error {
+	statsCh <- struct{}{}
+	defer func() { <-statsCh }()
+
+	logger.Debug(host, "dial")
+	address := fmt.Sprintf("%s:%d", host, port)
+
+	client, err := ssh.Dial("tcp", address, c.config)
+	if err != nil {
+		return err
+	}
+	defer client.Close()
+
+	return c.session(client, throttleStopCh)
+}
+
+// Create the SSH session. Close the session in case of an error. 
+func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error { + logger.Debug(c.Server, "session") + + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + return c.handle(session, throttleStopCh) +} + +// Handle the SSH session. Also send periodic pings to the server in order +// to determine that session is still intact. +func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error { + defer c.Handler.Stop() + + logger.Debug(c.Server, "handle") + + stdinPipe, err := session.StdinPipe() + if err != nil { + return err + } + + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return err + } + + if err := session.Shell(); err != nil { + return err + } + + // Establish Bi-directional pipe between SSH session and client handler. + brokenStdinPipe := make(chan struct{}) + go func() { + defer close(brokenStdinPipe) + io.Copy(stdinPipe, c.Handler) + }() + + brokenStdoutPipe := make(chan struct{}) + go func() { + defer close(brokenStdoutPipe) + io.Copy(c.Handler, stdoutPipe) + }() + + // SSH session established, other goroutine can initiate session now. + throttleStopCh <- struct{}{} + + // Send all commands to client. + for _, command := range c.Commands { + logger.Debug(command) + c.Handler.SendCommand(command) + } + + if !c.isOneOff { + return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe) + } + + <-c.stop + + // Normal shutdown, all fine + return nil +} + +// Periodically check whether connection is still alive or not. 
+func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error { + for { + select { + case <-time.After(time.Second * 3): + if err := c.Handler.Ping(); err != nil { + return err + } + case <-brokenStdinPipe: + logger.Debug("Broken stdin pipe", c.Server, c.port) + return nil + case <-brokenStdoutPipe: + logger.Debug("Broken stdout pipe", c.Server, c.port) + return nil + case <-c.stop: + return nil + } + } +} + +// Stop the connection. +func (c *Connection) Stop() { + close(c.stop) +} diff --git a/clients/stats.go b/clients/stats.go new file mode 100644 index 0000000..e5b9bed --- /dev/null +++ b/clients/stats.go @@ -0,0 +1,81 @@ +package clients + +import ( + "dtail/logger" + "fmt" + "runtime" + "sync" + "time" +) + +// Used to collect and display various client stats. +type stats struct { + // Total amount servers to connect to. + connectionsTotal int + // To keep track of what connected and disconnected + connectionsEstCh chan struct{} + // Amount of servers connections are established. + connected int + // To synchronize concurrent access. 
+ mutex sync.Mutex +} + +func newTailStats(connectionsTotal int) *stats { + return &stats{ + connectionsTotal: connectionsTotal, + connectionsEstCh: make(chan struct{}, connectionsTotal), + connected: 0, + } +} + +func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) { + connectedLast := 0 + statsInterval := 5 + + for { + select { + case <-time.After(time.Second * time.Duration(statsInterval)): + case <-stop: + return + } + + connected := len(s.connectionsEstCh) + throttle := len(throttleCh) + + newConnections := connected - connectedLast + connectionsPerSecond := float64(newConnections) / float64(statsInterval) + s.log(connected, newConnections, connectionsPerSecond, throttle) + + connectedLast = connected + + s.mutex.Lock() + s.connected = connected + s.mutex.Unlock() + } +} + +func (s *stats) numConnected() int { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.connected +} + +func (s *stats) log(connected, newConnections int, connectionsPerSecond float64, throttle int) { + percConnected := percentOf(float64(s.connectionsTotal), float64(connected)) + + connectedStr := fmt.Sprintf("connected=%d/%d(%d%%)", connected, s.connectionsTotal, int(percConnected)) + newConnStr := fmt.Sprintf("new=%d", newConnections) + rateStr := fmt.Sprintf("rate=%2.2f/s", connectionsPerSecond) + throttleStr := fmt.Sprintf("throttle=%d", throttle) + cpusGoroutinesStr := fmt.Sprintf("cpus/goroutines=%d/%d", runtime.NumCPU(), runtime.NumGoroutine()) + + logger.Info("stats", connectedStr, newConnStr, rateStr, throttleStr, cpusGoroutinesStr) +} + +func percentOf(total float64, value float64) float64 { + if total == 0 || total == value { + return 100 + } + return value / (total / 100.0) +} diff --git a/clients/tailclient.go b/clients/tailclient.go new file mode 100644 index 0000000..cb93258 --- /dev/null +++ b/clients/tailclient.go @@ -0,0 +1,44 @@ +package clients + +import ( + "dtail/clients/handlers" + "dtail/clients/remote" + "dtail/ssh/client" + "fmt" + 
"strings" + + gossh "golang.org/x/crypto/ssh" +) + +// TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines). +type TailClient struct { + baseClient +} + +// NewTailClient returns a new TailClient. +func NewTailClient(args Args) (*TailClient, error) { + c := TailClient{ + baseClient: baseClient{ + Args: args, + stop: make(chan struct{}), + stopped: make(chan struct{}), + throttleCh: make(chan struct{}, args.MaxInitConnections), + retry: true, + }, + } + + c.init(c) + + return &c, nil +} + +func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) + + for _, file := range strings.Split(c.Files, ",") { + conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) + } + + return conn +} diff --git a/color/color.go b/color/color.go new file mode 100644 index 0000000..64e0d7f --- /dev/null +++ b/color/color.go @@ -0,0 +1,75 @@ +// Package color is used to prettify console output via ANSII terminal colors. +package color + +import ( + "fmt" +) + +// Color name. +type Color string + +// Attribute of a color. +type Attribute string + +// The possible color variations. 
+const ( + escape = "\x1b" + reset = escape + "[0m" + seq string = "%s%s%s" + + Gray Color = escape + "[30m" + Red Color = escape + "[31m" + Green Color = escape + "[32m" + Orange Color = escape + "[33m" + Blue Color = escape + "[34m" + Magenta Color = escape + "[35m" + Yellow Color = escape + "[36m" + LightGray Color = escape + "[37m" + + BgGray Color = escape + "[40m" + BgRed Color = escape + "[41m" + BgGreen Color = escape + "[42m" + BgOrange Color = escape + "[43m" + BgBlue Color = escape + "[44m" + BgMagenta Color = escape + "[45m" + BgYellow Color = escape + "[46m" + BgLightGray Color = escape + "[47m" + + Bold Attribute = escape + "[1m" + Italic Attribute = escape + "[3m" + Underline Attribute = escape + "[4m" + ReverseColor Attribute = escape + "[7m" + + resetBold = escape + "[22m" + resetItalic = escape + "[23m" + resetUnderline = escape + "[24m" + + Test Color = BgYellow + TestAttr Attribute = Bold +) + +// Colored DTail client output enabled. +var Colored bool + +// Init whether we want colored output or not. +func Init(colored bool) { + Colored = colored +} + +// Paint a given string in a given color. +func Paint(c Color, s string) string { + return fmt.Sprintf(seq, c, s, reset) +} + +// Attr adds a given attribute to a given string, such as "bold" or "italic". +func Attr(c Attribute, s string) string { + switch c { + case Bold: + return fmt.Sprintf(seq, Bold, s, resetBold) + case Italic: + return fmt.Sprintf(seq, Italic, s, resetItalic) + case Underline: + return fmt.Sprintf(seq, Underline, s, resetUnderline) + } + panic("Unknown attribute") +} diff --git a/color/colorfy.go b/color/colorfy.go new file mode 100644 index 0000000..9ae46f5 --- /dev/null +++ b/color/colorfy.go @@ -0,0 +1,58 @@ +package color + +import ( + "fmt" + "strings" +) + +// Add some color to log lines received from remote servers. 
+func paintRemote(line string) string { + splitted := strings.Split(line, "|") + if splitted[2] == "100" { + splitted[2] = Paint(BgGreen, splitted[2]) + } else { + splitted[2] = Paint(BgRed, splitted[2]) + } + info := strings.Join(splitted[0:5], "|") + log := strings.Join(splitted[5:], "|") + + if strings.HasPrefix(log, "WARN") { + log = Paint(BgYellow, log) + } else if strings.HasPrefix(log, "ERROR") { + log = Paint(BgRed, log) + } else if strings.HasPrefix(log, "FATAL") { + log = Attr(Bold, Paint(BgRed, log)) + } else { + log = Paint(Blue, log) + } + + return fmt.Sprintf("%s|%s", info, log) +} + +// Add some color to stats generated by the client. +func paintClientStats(line string) string { + splitted := strings.Split(line, "|") + first := strings.Join(splitted[0:4], "|") + connected := Paint(BgBlue, splitted[4]) + last := strings.Join(splitted[5:], "|") + + return fmt.Sprintf("%s|%s|%s", first, connected, last) +} + +// Colorfy a given line based on the line's content. +func Colorfy(line string) string { + if strings.HasPrefix(line, "REMOTE") { + return paintRemote(line) + } + if strings.HasPrefix(line, "CLIENT") && strings.Contains(line, "|stats|") { + return paintClientStats(line) + } + if strings.Contains(line, "ERROR") { + return Paint(Magenta, line) + } + if strings.Contains(line, "WARN") { + return Paint(Magenta, line) + } + + return line +} diff --git a/config/client.go b/config/client.go new file mode 100644 index 0000000..1515aae --- /dev/null +++ b/config/client.go @@ -0,0 +1,11 @@ +package config + +// ClientConfig represents a DTail client configuration (empty as of now as there +// are no available config options yet, but that may changes in the future). +type ClientConfig struct { +} + +// Create a new default client configuration. 
+func newDefaultClientConfig() *ClientConfig { + return &ClientConfig{} +} diff --git a/config/common.go b/config/common.go new file mode 100644 index 0000000..8c07710 --- /dev/null +++ b/config/common.go @@ -0,0 +1,42 @@ +package config + +// CommonConfig stores configuration keys shared by DTail server and client. +type CommonConfig struct { + // The SSH server port number. + SSHPort int + // Enable experimental features. + ExperimentalFeaturesEnable bool `json:",omitempty"` + // Enable extra debug logging (used for deevlopment or debugging purpes only). + DebugEnable bool `json:",omitempty"` + // Enable extra trace logging (used for deevlopment or debugging purpes only). + TraceEnable bool `json:",omitempty"` + // The log strategy to use, one of + // stdout: only log to stdout (useful when used with systemd) + // daily: create a log file for every day + LogStrategy string + // The log directory + LogDir string + // The cache directory + CacheDir string + // Do we want to enable pperf http server? + PProfEnable bool `json:",omitempty"` + // The HTTP port used by PProf + PProfPort int `json:",omitempty"` + // The PProf HTTP server bind address + PProfBindAddress string `json:",omitempty"` +} + +// Create a new default configuration. +func newDefaultCommonConfig() *CommonConfig { + return &CommonConfig{ + SSHPort: 2222, + DebugEnable: false, + TraceEnable: false, + ExperimentalFeaturesEnable: false, + LogDir: "log", + CacheDir: "cache", + PProfEnable: false, + PProfPort: 6060, + PProfBindAddress: "0.0.0.0", + } +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..5463c5f --- /dev/null +++ b/config/config.go @@ -0,0 +1,72 @@ +package config + +import ( + "encoding/json" + "io/ioutil" + "os" +) + +// ControlUser is used for various DTail specific operations. +const ControlUser string = "DTAIL-CONTROL-USER" + +// Client holds a DTail client configuration. +var Client *ClientConfig + +// Server holds a DTail server configuration. 
+var Server *ServerConfig + +// Common holds common configs of both both, client and server. +var Common *CommonConfig + +// Used to initialize the configuration. +type configInitializer struct { + Common *CommonConfig + Server *ServerConfig + Client *ClientConfig +} + +// Parse and read a given config file in JSON format. +func (c *configInitializer) parseConfig(configFile string) { + fd, err := os.Open(configFile) + if err != nil { + panic(err) + } + defer fd.Close() + + cfgBytes, err := ioutil.ReadAll(fd) + if err != nil { + panic(err) + } + + err = json.Unmarshal([]byte(cfgBytes), c) + if err != nil { + panic(err) + } +} + +// Init the DTail configuration. +func Init(configFile string) { + initializer := configInitializer{ + Common: newDefaultCommonConfig(), + Server: newDefaultServerConfig(), + Client: newDefaultClientConfig(), + } + + if configFile == "" { + configFile = "./cfg/dtail.json" + } + + if _, err := os.Stat(configFile); !os.IsNotExist(err) { + initializer.parseConfig(configFile) + } + + // Assign pointers to global variables, so that we can access the + // configuration from any place of the program. + Common = initializer.Common + Server = initializer.Server + Client = initializer.Client + + if Server.MapreduceLogFormat == "" { + Server.MapreduceLogFormat = "default" + } +} diff --git a/config/server.go b/config/server.go new file mode 100644 index 0000000..7883b33 --- /dev/null +++ b/config/server.go @@ -0,0 +1,66 @@ +package config + +import ( + "errors" +) + +// Permissions map. Each SSH user has a list of permissions which +// log files it is allowed to follow and which ones not. +type Permissions struct { + // The default user permissions. + Default []string + // The per user special permissions. + Users map[string][]string +} + +// ServerConfig represents the server configuration. +type ServerConfig struct { + // The SSH server bind port. 
+ SSHBindAddress string + // The max amount of concurrent user connection allowed to connect to the server. + MaxConnections int + // The max amount of concurrent cats per server. + MaxConcurrentCats int + // The max amount of concurrent tails per server. + MaxConcurrentTails int + // The user permissions. + Permissions Permissions `json:",omitempty"` + // The mapr log format + MapreduceLogFormat string `json:",omitempty"` + // The default path of the server host key + HostKeyFile string + // The host key size in bits + HostKeyBits int +} + +// Create a new default server configuration. +func newDefaultServerConfig() *ServerConfig { + defaultPermissions := []string{"^/.*"} + defaultBindAddress := "0.0.0.0" + + return &ServerConfig{ + SSHBindAddress: defaultBindAddress, + MaxConnections: 10, + MaxConcurrentCats: 2, + MaxConcurrentTails: 50, + HostKeyFile: "./cache/ssh_host_key", + HostKeyBits: 4096, + Permissions: Permissions{ + Default: defaultPermissions, + }, + } +} + +// ServerUserPermissions retrieves the permission set of a given user. +func ServerUserPermissions(userName string) (permissions []string, err error) { + permissions = Server.Permissions.Default + if p, ok := Server.Permissions.Users[userName]; ok { + permissions = p + } + + if len(permissions) == 0 { + err = errors.New("Empty set of permission, user won't be able to open any files") + } + + return +} diff --git a/discovery/comma.go b/discovery/comma.go new file mode 100644 index 0000000..c7c9d75 --- /dev/null +++ b/discovery/comma.go @@ -0,0 +1,12 @@ +package discovery + +import ( + "dtail/logger" + "strings" +) + +// ServerListFromCOMMA retrieves a list of servers from comma separated input list. 
+func (d *Discovery) ServerListFromCOMMA() []string { + logger.Debug("Retrieving server list from comma separated list", d.server) + return strings.Split(d.server, ",") +} diff --git a/discovery/discovery.go b/discovery/discovery.go new file mode 100644 index 0000000..4150fd9 --- /dev/null +++ b/discovery/discovery.go @@ -0,0 +1,173 @@ +package discovery + +import ( + "dtail/logger" + "fmt" + "math/rand" + "os" + "reflect" + "regexp" + "strings" + "time" +) + +// Discovery method for discovering a list of available DTail servers. +type Discovery struct { + // To plug in a custom server discovery module. + module string + // To specifiy optional server discovery module options. + options string + // To either filter a server list or to secify an exact list. + server string + // To filter server list. + regex *regexp.Regexp + // To shuffle resulting server list. + shuffle bool +} + +// New returns a new discovery method. +func New(method, server string, shuffle bool) *Discovery { + module := method + options := "" + + if strings.Contains(module, ":") { + s := strings.Split(module, ":") + if len(s) != 2 { + logger.FatalExit("Unable to parse discovery module", module) + } + module = s[0] + options = s[1] + } + + d := Discovery{ + module: strings.ToUpper(module), + options: options, + server: server, + shuffle: shuffle, + } + + if strings.HasPrefix(server, "/") && strings.HasSuffix(server, "/") { + d.initRegex() + } + + return &d +} + +func (d *Discovery) initRegex() { + var runes []rune + last := len(d.server) - 1 + for i, char := range d.server { + if i != 0 && i != last { + runes = append(runes, char) + } + } + + regexStr := string(runes) + logger.Debug("Using filter regex", regexStr) + + regex, err := regexp.Compile(regexStr) + if err != nil { + logger.FatalExit("Could not compile regex", regexStr, err) + } + + d.regex = regex + d.server = "" +} + +// ServerList to connect to via DTail client. 
+func (d *Discovery) ServerList() []string { + servers := d.serverListFromModule() + + if d.regex != nil { + servers = d.filterList(servers) + } + + servers = d.dedupList(servers) + + if d.shuffle { + servers = d.shuffleList(servers) + } + + logger.Debug("Discovered servers", len(servers), servers) + return servers +} + +func (d *Discovery) serverListFromModule() []string { + if d.module != "" { + return d.serverListFromReflectedModule() + } + + if _, err := os.Stat(d.server); err == nil { + // Appears to be a file name, now try to read from that file. + return d.ServerListFromFILE() + } + + // Appears to be a list of FQDNs (or a single FQDN) + return d.ServerListFromCOMMA() +} + +// The aim of this is that everyone can plug in their own server discovery +// method to DTail. Just add a method ServerListFrommMODULENAME to type +// Discovery. Whereas MODULENAME must be a upeprcase string. +func (d *Discovery) serverListFromReflectedModule() []string { + methodName := fmt.Sprintf("ServerListFrom%s", d.module) + + rt := reflect.TypeOf(d) + reflectedMethod, ok := rt.MethodByName(methodName) + if !ok { + logger.FatalExit("No such server discovery module", d.module, methodName) + } + + inputValues := make([]reflect.Value, 1) + // Thist input value is method receiver. + inputValues[0] = reflect.ValueOf(d) + returnValues := reflectedMethod.Func.Call(inputValues) + + // First return value is server list. + return returnValues[0].Interface().([]string) +} + +// Filter server list based on a regexp. +func (d *Discovery) filterList(servers []string) (filtered []string) { + logger.Debug("Filtering server list") + + for _, server := range servers { + if d.regex.MatchString(server) { + filtered = append(filtered, server) + } + } + + return +} + +// Deduplicate the server list. 
+func (d *Discovery) dedupList(servers []string) (deduped []string) { + serverMap := make(map[string]struct{}, len(servers)) + + for _, server := range servers { + if _, ok := serverMap[server]; !ok { + serverMap[server] = struct{}{} + deduped = append(deduped, server) + } + } + + logger.Info("Deduped server list", len(servers), len(deduped)) + return +} + +// Randomly shuffle the server list. +func (d *Discovery) shuffleList(servers []string) []string { + logger.Debug("Shuffling server list") + + r := rand.New(rand.NewSource(time.Now().Unix())) + shuffled := make([]string, len(servers)) + n := len(servers) + + for i := 0; i < n; i++ { + randIndex := r.Intn(len(servers)) + shuffled[i] = servers[randIndex] + servers = append(servers[:randIndex], servers[randIndex+1:]...) + } + + return shuffled +} diff --git a/discovery/file.go b/discovery/file.go new file mode 100644 index 0000000..e02d6b4 --- /dev/null +++ b/discovery/file.go @@ -0,0 +1,28 @@ +package discovery + +import ( + "bufio" + "dtail/logger" + "os" +) + +// ServerListFromFILE retrieves a list of servers from a file. 
+func (d *Discovery) ServerListFromFILE() (servers []string) { + logger.Debug("Retrieving server list from file", d.server) + + file, err := os.Open(d.server) + if err != nil { + logger.FatalExit(d.server, err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + servers = append(servers, scanner.Text()) + } + if err := scanner.Err(); err != nil { + logger.FatalExit(d.server, err) + } + + return +} diff --git a/doc/dcat.gif b/doc/dcat.gif Binary files differnew file mode 100644 index 0000000..a6eae0f --- /dev/null +++ b/doc/dcat.gif diff --git a/doc/dgrep.gif b/doc/dgrep.gif Binary files differnew file mode 100644 index 0000000..e2f2ac6 --- /dev/null +++ b/doc/dgrep.gif diff --git a/doc/dmap.gif b/doc/dmap.gif Binary files differnew file mode 100644 index 0000000..eb75e8d --- /dev/null +++ b/doc/dmap.gif diff --git a/doc/dtail-map.gif b/doc/dtail-map.gif Binary files differnew file mode 100644 index 0000000..c1a74fe --- /dev/null +++ b/doc/dtail-map.gif diff --git a/doc/dtail.gif b/doc/dtail.gif Binary files differnew file mode 100644 index 0000000..8f6b56b --- /dev/null +++ b/doc/dtail.gif diff --git a/doc/examples.md b/doc/examples.md new file mode 100644 index 0000000..78a9caf --- /dev/null +++ b/doc/examples.md @@ -0,0 +1,67 @@ +Examples +======== + +This page demonstrate the basic usage of DTail. Please also see ``dtail --help`` for more available options. + +# How to use ``dtail`` + +## Tailing logs + +The following example demonstrates how to follow logs of multiple servers at once. The server list is provided as a flat text file. The example filters all logs containing the string ``STAT``. Any other Go compatible regular expression can be used instead of ``STAT``. + +```shell +workstation01 ~ % dtail --servers serverlist.txt --files "/var/log/service/*.log" --regex STAT +``` + + + +## Aggregating logs + +To run ad-hoc mapreduce aggregations on newly written log lines you also must add a query. 
This example follows all remote log lines and prints out every 5 seconds the top 10 servers with most average free memory according to the logs. To run a mapreduce query across log lines written in the past please use the ``dmap`` command instead. + +```shell +workstation01 ~ % dtail --servers serverlist.txt \ + --query 'select avg(memfree), $hostname from MCVMSTATS group by $hostname order by avg(memfree) limit 10 interval 5' \ + --files '/var/log/service/*.log' +``` + +In order for mapreduce queries to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``mapr/logformat`` or add an extension to support a custom log format. + + + +# How to use ``dcat`` + +The following example demonstrates how to cat files (display the whole content of the files) of multiple servers at once. The servers are provided as a comma separated list this time. + +```shell +workstation01 ~ % dcat --servers serv-011.lan.example.org,serv-012.lan.example.org,serv-013.lan.example.org \ + --files /etc/hostname +``` + + + +# How to use ``dgrep`` + +The following example demonstrates how to grep files (display only the lines which match a given regular expression) of multiple servers at once. In this example we look after the swap partition in ``/etc/fstab``. We do that only on the first 20 servers from ``serverlist.txt``. ``dgrep`` is also very useful for searching log files of the past. + +```shell +workstation01 ~ % dgrep --servers <(head -n 20 serverlist.txt) \ + --files /etc/fstab \ + --regex swap +``` + + + +# How to use ``dmap`` + +To run a mapreduce aggregation over logs written in the past the ``dmap`` command can be used. For example the following command aggregates all mapreduce fields of all the logs and calculates the average memory free grouped by day of the month, hour, minute and the server hostname. ``dmap`` will print interim results every few seconds. 
The final result however will be written to file ``mapreduce.csv``. + +```shell +dmap --servers serv-011.lan.example.org,serv-012.lan.example.org,serv-013.lan.example.org,serv-021.lan.example.org,serv-022.lan.example.org,serv-023.lan.example.org \ + --query 'select avg(memfree), $day, $hour, $minute, $hostname from MCVMSTATS group by $day, $hour, $minute, $hostname order by avg(memfree) limit 10 outfile mapreduce.csv' \ + --files "/var/log/service/*.log" +``` + +Remember: In order for that to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``mapr/logformat`` or add an extension to support a custom log format. + + diff --git a/doc/installation.md b/doc/installation.md new file mode 100644 index 0000000..d1ecf9f --- /dev/null +++ b/doc/installation.md @@ -0,0 +1,83 @@ +DTail Installation Guide +======================== + +The following installation guide has been tested successfully on CentOS 7. You may need to adjust accordingly depending on the distribution you use. + +This guide also assumes that you know how to use ``systemd`` and how to configure a service there. If you are unsure please consult the documentation of your distribution. + +This guide also assumes that you know how to add a new Nagios check to your monitoring infrastructure. + +# Compile it + +Please check the [Quick Starting Guide](quickstart.md) for instructions how to compile DTail. It is recommended to automate the build process via your build pipeline (e.g. produce a deployable RPM via Jenkins). But that is out of scope of this documentation. + +# Install it + +It is recommended to automate all the installation process outlined here. You could use a configuration management system such as Puppet, Chef or Ansible. However, that relies heavily on how your infrastructure is managed and is out of scope of this documentation. +1. The ``dserver`` binary has to be installed on all machines (server boxes) involved. 
A good location for the binary would be ``/usr/local/bin/dserver`` with permissions set as follows: + +```console +serv-001 ~ % sudo chown root:root /usr/local/bin/dserver +serv-001 ~ % sudo chmod 0755 /usr/local/bin/dserver +``` + +2. Create the ``dserver`` run user and group. The user could look like this: + +```console +serv-001 ~ % id dserver +uid=670(dserver) gid=670(dserver) groups=670(dserver) +``` + +3. Create the required file system structure and set the correct permissions: + +```console +serv-001 ~ % sudo mkdir -p /etc/dserver /var/run/dserver +serv-001 ~ % sudo chown -R dserver:dserver /var/run/dserver +``` + +4. Install the ``dtail.json`` config to ``/etc/dserver/dtail.json``. An example can be found [here](../samples/dtail.json.sample). + +5. It is recommended to configure DTail server as a service to ``systemd``. An example unit file for ``systemd`` can be found [here](../samples/dserver.service.sample). + +# Start it + +To start the DTail server via ``systemd`` run: + +```console +serv-001 ~ % sudo systemctl start dserver +serv-001 ~ % sudo systemctl status dserver +● dserver.service - DTail server + Loaded: loaded (/etc/systemd/system/dserver.service; disabled; vendor preset: disabled) + Active: active (running) since Fri 2019-12-06 13:21:24 GMT; 2s ago + Main PID: 12296 (dserver) + Memory: 1.5M + CGroup: /dserver.slice/dserver.service + └─12296 /usr/local/bin/dserver -cfg /etc/dserver/dtail.json + + Dec 06 13:21:24 serv-001.lan.example.org systemd[1]: Started DTail server. 
+ Dec 06 13:21:24 serv-001.lan.example.org dserver[12296]: SERVER|serv-001|INFO|Launching server|server|DTail 1.0.0 + Dec 06 13:21:24 serv-001.lan.example.org dserver[12296]: SERVER|serv-001|INFO|Creating server|DTail 1.0.0 + Dec 06 13:21:24 serv-001.lan.example.org dserver[12296]: SERVER|serv-001|INFO|Reading private server RSA host key from file|cache/ssh_host_key + Dec 06 13:21:24 serv-001.lan.example.org dserver[12296]: SERVER|serv-001|INFO|Starting server + Dec 06 13:21:24 serv-001.lan.example.org dserver[12296]: SERVER|serv-001|INFO|Binding server|1.2.3.4:2222 +``` + +# Register SSH public keys in DTail server + +The DTail server now runs as a ``systemd`` service under system user ``dserver``. The system user ``dserver`` however has no permissions to read the SSH public keys from ``/home/USER/.ssh/authorized_keys``. Therefore, no user would be able to establish a SSH session to DTail server. As an alternative path DTail server also checks for public SSH key files in ``/var/run/dserver/cache/USER.authorized_keys``. + +It is recommended to execute [update_key_cache.sh](../samples/update_key_cache.sh.sample) periodically to update the key cache. In case you manage your public SSH keys via Puppet you could subscribe the script to corresponding module. Or alternatively just configure a cron job to run every once in a while. + +# Run DTail client + +Now you should be able to use DTail client like outlined in the [Quick Starting Guide](quickstart.md). Also have a look at the [Examples](examples.md). + +# Monitor it + +To verify that DTail server is up and running and functioning as expected you should configure the Nagios check [check_dserver.sh](../samples/check_dserver.sh.sample) in your monitoring system. The check has to be executed locally on the server (e.g. via NRPE). How to configure the monitoring system in detail is out of scope of this guide, as it depends on the monitoring infrastructure used. 
+ +```console +% ./check_dserver.sh +OK: DTail SSH Server seems fine +``` + diff --git a/doc/logo.png b/doc/logo.png Binary files differnew file mode 100644 index 0000000..38edb16 --- /dev/null +++ b/doc/logo.png diff --git a/doc/logo.webp b/doc/logo.webp Binary files differnew file mode 100644 index 0000000..5160d5b --- /dev/null +++ b/doc/logo.webp diff --git a/doc/quickstart.md b/doc/quickstart.md new file mode 100644 index 0000000..57432c0 --- /dev/null +++ b/doc/quickstart.md @@ -0,0 +1,99 @@ +Quick Starting Guide +==================== + +This is the quick starting guide. For a more sustainable setup, involving how to create a background service via ``systemd``, recommendations about automation via Jenkins and/or Puppet and health monitoring via Nagios please also follow the [Installation Guide](installation.md). + +This guide assumes that you know how to generate and configure a public/private SSH key pair for secure authorization and shell access. That is out of scope of this guide. For more information please have a look at the OpenSSH documentation of your distribution. + +This guide also assumes that you know how to install and use a Go compiler and GNU make. 
+ +# Compile it + +To produce all DTail binaries run ``make``: + +```console +workstation01 ~/git/dtail % make +go build +cp -pv ./dtail ./dcat +./dtail -> ./dcat +cp -pv ./dtail ./dgrep +./dtail -> ./dgrep +cp -pv ./dtail ./dmap +./dtail -> ./dmap +cp -pv ./dtail ./dserver +./dtail -> ./dserver +``` + +It produces the following executables: + +* ``dserver``: The DTail server +* ``dtail``: Client for tailing/following log files remotely (distributed tail) +* ``dcat``: Client for displaying whole files remotely (distributed cat) +* ``dgrep``: Client for searching whole files files remotely using a regex (distributed grep) +* ``dmap``: Client for executing distributed mapreduce queries (may will consume a lot of RAM and CPU) + +# Start DTail server + +Copy the ``dserver`` binary to the remote server machines of your choice (e.g. ``serv-001.lan.example.org`` and ``serv-002.lan.example.org``) and start it on each of the servers as follows: + +```console +serv-001 ~ % ./dserver +SERVER|serv-001|INFO|Launching server|server|DTail 1.0.0 +SERVER|serv-001|INFO|Creating server|DTail 1.0.0 +SERVER|serv-001|INFO|Generating private server RSA host key +SERVER|serv-001|INFO|Starting server +SERVER|serv-001|INFO|Binding server|0.0.0.0:2222 +``` + +``dserver`` is now listening on TCP port 2222 and waiting for incoming connections. All SSH keys listed in ``~/.ssh/authorized_keys`` are now respected by the DTail server for authorization. + +# Setup DTail client + +## Setup SSH + +Make sure that your public SSH key is listed in ``~/.ssh/authorized_keys`` on all server machines involved. The private SSH key counterpart should preferably stay on your Laptop or workstation in ``~/.ssh/id_rsa`` or ``~/.ssh/id_dsa``. + +DTail utilises the SSH Agent for SSH authentication. This is to avoid entering the passphrase of the private SSH key over and over again when a new SSH session is initiated from the DTail client to a new DTail server. 
For this the private SSH key has to be registered at the SSH Agent: + +```console +workstation01 ~ % ssh-add ~/.ssh/id_rsa +Enter passphrase for ~/.ssh/id_rsa: ********** +Identity added: ~/.ssh/id_rsa (~/.ssh/id_rsa) +``` + +The DTail client communicates with the SSH Agent through ``~/.ssh/ssh_auth_socket`` whenever a new session to a DTail server is established. + +To test whether SSH is setup correctly you should be able to SSH into the servers with the OpenSSH client and your private SSH key through the SSH Agent without entering the private keys passphrase. The following assumes to have an OpenSSH server running on ``serv-001.lan.example.org`` and an OpenSSH client installed on your laptop or workstation. Please notice that DTail does not require to have an OpenSSH infrastructure set up but DTail uses by default the same public/private key file paths as OpenSSH. OpenSSH can be of a great help to verify that the SSH keys are configured correctly: + +```console +workstation01 ~/git/dtail % ssh serv-001.lan.example.org +serv-001 ~ % +serv-001 ~ % exit +workstation01 ~/git/dtail % +``` + +## Run DTail client + +Now it is time to connect to the DTail servers through the DTail client: + +```console +workstation01 ~/git/dtail % ./bin/dtail --servers serv-001.lan.example.org,server-002.lan.example.org --files "/var/log/service/*.log" +CLIENT|workstation01|INFO|Launching client|tail|DTail 1.0.0 +CLIENT|workstation01|INFO|Initiating base client +CLIENT|workstation01|INFO|Added SSH Agent to list of auth methods +CLIENT|workstation01|INFO|Deduped server list|1|1 +CLIENT|workstation01|WARN|Encountered unknown host|{serv-002.lan.example.org:2222 0xc000146450 0xc00014a2f0 [serv-002.lan.example.org]:2222 ssh-rsa AAAA.... +CLIENT|workstation01|WARN|Encountered unknown host|{serv-001.lan.example.org:2222 0xc0001ff450 0xc00ee4a2f0 [serv-001.lan.example.org]:2222 ssh-rsa AAAA.... 
+Encountered 2 unknown hosts: 'serv-002.lan.example.org:2222 serv-001.lan.example.org:2222' +Do you want to trust these hosts?? (y=yes,a=all,n=no,d=details): y +CLIENT|workstation01|INFO|Added hosts to known hosts file|~/.ssh/known_hosts +CLIENT|workstation01|INFO|stats|connected=1/1(100%)|new=1|rate=0.20/s|throttle=0|cpus/goroutines=8/17 +CLIENT|workstation01|INFO|stats|connected=1/1(100%)|new=0|rate=0.00/s|throttle=0|cpus/goroutines=8/17 +CLIENT|workstation01|INFO|stats|connected=1/1(100%)|new=0|rate=0.00/s|throttle=0|cpus/goroutines=8/17 +CLIENT|workstation01|INFO|stats|connected=1/1(100%)|new=0|rate=0.00/s|throttle=0|cpus/goroutines=8/17 +. +. +. +``` + +Have a look [here](examples.md) for more usage examples. diff --git a/fs/catfile.go b/fs/catfile.go new file mode 100644 index 0000000..99f521f --- /dev/null +++ b/fs/catfile.go @@ -0,0 +1,27 @@ +package fs + +import "sync" + +// CatFile is for reading a whole file. +type CatFile struct { + readFile +} + +// NewCatFile returns a new file catter. +func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { + var mutex sync.Mutex + + return CatFile{ + readFile: readFile{ + filePath: filePath, + stop: make(chan struct{}), + globID: globID, + serverMessages: serverMessages, + retry: false, + canSkipLines: false, + seekEOF: false, + limiter: limiter, + mutex: &mutex, + }, + } +} diff --git a/fs/filereader.go b/fs/filereader.go new file mode 100644 index 0000000..5a08e27 --- /dev/null +++ b/fs/filereader.go @@ -0,0 +1,9 @@ +package fs + +// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. +type FileReader interface { + Start(lines chan<- LineRead, regex string) error + FilePath() string + Retry() bool + Stop() +} diff --git a/fs/lineread.go b/fs/lineread.go new file mode 100644 index 0000000..7ee558e --- /dev/null +++ b/fs/lineread.go @@ -0,0 +1,28 @@ +package fs + +import ( + "fmt" +) + +// LineRead represents a read log line. 
type LineRead struct {
	// The content of the log line.
	Content []byte
	// Until now, how many log lines were processed?
	Count uint64
	// Sometimes we produce too many log lines so that the client
	// is too slow to process all of them. The server will drop log
	// lines if that happens but it will signal to the client how
	// many log lines in % could be transmitted to the client.
	TransmittedPerc int
	// Identifier of the file glob this line belongs to (may be nil).
	GlobID *string
}

// String returns a human readable representation of the read line.
func (l LineRead) String() string {
	// Guard against a nil GlobID so that String never panics.
	globID := "<nil>"
	if l.GlobID != nil {
		globID = *l.GlobID
	}

	return fmt.Sprintf("LineRead(Content:%s,TransmittedPerc:%v,Count:%v,GlobID:%s)",
		string(l.Content),
		l.TransmittedPerc,
		l.Count,
		globID)
}
+func ToRead(user, filePath string) (bool, error) { + // Only implemented for Linux, always expect true + logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform") + return true, nil +} diff --git a/fs/permissions/permission_linux.c b/fs/permissions/permission_linux.c new file mode 100644 index 0000000..cd10525 --- /dev/null +++ b/fs/permissions/permission_linux.c @@ -0,0 +1,395 @@ +#include "permission_linux.h" + +#ifdef DEBUG +void debug_print_checker(struct permission_checker *pc) { + fprintf(stderr, "DEBUG: user_name:%s (%d)\n", + pc->user_name, pc->uid); + + fprintf(stderr, "DEBUG: ngids:%d\n", pc->ngids); + int j; + for (j = 0; j < pc->ngids; j++) { + fprintf(stderr, "DEBUG: %d", pc->gids[j]); + struct group *gr = getgrgid(pc->gids[j]); + if (gr != NULL) + fprintf(stderr, " (%s)", gr->gr_name); + fprintf(stderr, "\n"); + } + + fprintf(stderr, "DEBUG: file_path:%s (%d:%d)\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +} +#endif // DEBUG + +int stat_file(struct permission_checker *pc) { + if (stat(pc->file_path, &pc->file_stat) != 0) + return -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: File'%s' is owned by '%d:%d'\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +#endif + + return 0; +} + +int get_user_uid(struct permission_checker *pc) { + struct passwd *result = NULL; + + size_t bufsize = sysconf(_SC_GETPW_R_SIZE_MAX); + if (bufsize == -1) + bufsize = 16384; + + char *buf = malloc(bufsize); + if (buf == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unabel to allocate bufer while retrieving user '%s'\n", pc->user_name); +#endif + return -1; + } + + int rc = getpwnam_r(pc->user_name, &pc->pw, buf, bufsize, &result); + + if (result == NULL) { +#ifdef DEBUG + if (rc == 0) { + fprintf(stderr, "DEBUG: No user '%s' found\n", pc->user_name); + } else { + fprintf(stderr, "DEBUG: Unknown error while retrieving user '%s'\n", pc->user_name); + } +#endif + + free(buf); + return -1; + } + + pc->uid = 
pc->pw.pw_uid; + + free(buf); + return 0; +} + +int get_user_groups(struct permission_checker *pc) { + // First assume we are in 10 groups max + pc->ngids = 10; + pc->gids = malloc(pc->ngids * sizeof(gid_t)); + + if (pc->gids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + return -1; + } + + // Try so many times to load group list until it fits into group array. + while (getgrouplist(pc->user_name, pc->pw.pw_gid, pc->gids, &pc->ngids) == -1) { + // Too many groups, enlarge group array and try again + int newngids = pc->ngids + 100; + size_t newsize = newngids * sizeof(gid_t); + + if (SIZE_MAX / newngids < sizeof(gid_t)) { + // Overflow +#ifdef DEBUG + fprintf(stderr, "DEBUG: Overflow detected."); +#endif + return -1; + } + + gid_t *newgids = realloc(pc->gids, newsize); + if (newgids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + free(pc->gids); + return -1; + } + + pc->gids = newgids; + pc->ngids = newngids; + } + + return 0; +} + +int is_member_of_group(struct permission_checker *pc, gid_t gid) { + int j; + for (j = 0; j < pc->ngids; j++) + if (pc->gids[j] == gid) + return 1; + return 0; +} + +int check_acl_uid_matches(uid_t uid, acl_entry_t entry) { + int ret = -1; + uid_t *acl_uid = acl_get_qualifier(entry); + if (acl_uid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + ret = *acl_uid == uid ? 
0 : -1; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL user match?: %d <=> %d: %d\n", *acl_uid, uid, ret); +#endif + acl_free(acl_uid); + return ret; +} + +int check_acl_gid_matches(gid_t *gids, int ngids, acl_entry_t entry) { + int ret = -1; + gid_t *acl_gid = acl_get_qualifier(entry); + if (acl_gid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + int j; + for (j = 0; j < ngids; j++) { + if (*acl_gid == gids[j]) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User is in group %d", *acl_gid); +#endif + ret = 0; + break; + } + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL group match?: %d <=> ...: %d\n", *acl_gid, ret); +#endif + acl_free(acl_gid); + return ret; +} + +int check_acl(struct permission_checker *pc, const int flag) { + // By default user has no read perm. + int has_read_perm = 0; + + // By default mask tells that there are read perm. However in order to have + // read permissions both, has_read_perm and mask_allows_read_access must be 1! + int mask_allows_read_access = 1; + + acl_type_t type = ACL_TYPE_ACCESS; + acl_t acl = acl_get_file(pc->file_path, type); + + if (acl == NULL) + // Unable to retrieve ACL. + return -1; + + // Walk through each entry of this ACL. + int id; + for (id = ACL_FIRST_ENTRY; ; id = ACL_NEXT_ENTRY) { + acl_entry_t entry; + if (acl_get_entry(acl, id, &entry) != 1) + // No more ACL entries. + break; + + acl_tag_t tag; + if (acl_get_tag_type(entry, &tag) == -1) + // Unable to retrieve ACL tag. + return -1; + + switch (tag) { + case ACL_USER_OBJ: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER_OBJ\n"); +#endif + // Ignore this ACL entry if user is not owner of file. + if (pc->uid != pc->file_stat.st_uid) + continue; + break; + case ACL_USER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER\n"); +#endif + // Ignore this ACL entry if uid does not match. 
+ if (check_acl_uid_matches(pc->uid, entry) != 0) + continue; + break; + case ACL_GROUP_OBJ: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP_OBJ\n"); +#endif + // Ignore ACL entry if user is not in group of file. + if (!is_member_of_group(pc, pc->file_stat.st_gid)) + continue; + break; + case ACL_GROUP: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP\n"); +#endif + // Ignore ACL entry if user is not in group of entry. + if (check_acl_gid_matches(pc->gids, pc->ngids, entry) != 0) + continue; + break; + case ACL_OTHER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_OTHER\n"); +#endif + break; + case ACL_MASK: +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_MASK\n"); +#endif + break; + default: +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unknown ACL tag\n"); +#endif + return -1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: Retrieving permset\n"); +#endif + acl_permset_t permset; + int permission; + if (acl_get_permset(entry, &permset) == -1) + // Unable to retrieve permset. + return -1; + + if ((permission = acl_get_perm(permset, ACL_READ)) == -1) + // Unable to retrieve permset value. + return -1; + + if (permission == 1 && tag != ACL_MASK) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL says user has permission to read file.\n"); +#endif + has_read_perm = 1; + } else if (permission == 0 && tag == ACL_MASK) { + // Mask says that there are no permissions to read. 
+ mask_allows_read_access = 0; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL mask says no permission to read file.\n"); +#endif + } + } + + if (has_read_perm && mask_allows_read_access) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has permission to read file.\n"); +#endif + return 1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has no permission to read file.\n"); +#endif + return 0; +} + +int check_traditional(struct permission_checker *pc, const int flag) { + mode_t mode = pc->file_stat.st_mode; + uid_t uid = pc->file_stat.st_uid; + gid_t gid = pc->file_stat.st_gid; + + if (flag == USER_CHECK && (mode & S_IROTH)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Others can read file '%s'\n", + pc->file_path); +#endif + return 1; + + } else if (flag == USER_CHECK && (mode & S_IRUSR) && uid == pc->uid) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User '%s' can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + + } else if (flag == GROUP_CHECK && (mode & S_IRGRP) && is_member_of_group(pc, gid)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User's '%s' group can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + } + + return 0; +} + +int permission_to_read(char* user_name, char *file_path) { + int rc = -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: User check '%s' for file '%s'\n", user_name, file_path); +#endif + struct permission_checker pc = { + .user_name = user_name, + .gids = NULL, + .ngids = 0, + .file_path = file_path, + }; + + // Gather user's UID. + if ((rc = get_user_uid(&pc)) == -1) + // Could not retrieve UID. + goto cleanup; + + // Gather file owner (user and group). + if ((rc = stat_file(&pc)) == -1) + // Could not stat file. + goto cleanup; + + // Check whether there is an ACL entry which would allow the user + // to read the file. Don't check for any groups yet. 
The issue with + // groups is that it can be very slow to retrieve the list of groups + // of a specific user when done via a remote LDAP server! + if ((rc = check_acl(&pc, USER_CHECK)) == 1) + // Yes, has permissions. + goto cleanup; + + // Check whether ACLs of file could be retrieved. + if (rc == -1) { + if (errno != ENOTSUP) + // Unknown error. + goto cleanup; + + // File system does not support ACLs. + // Fallback to traditional permissions. + if ((rc = check_traditional(&pc, USER_CHECK)) == 1) + // Yes, has traditional permissions. + goto cleanup; + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve user's groups. + goto cleanup; + + rc = check_traditional(&pc, GROUP_CHECK); + goto cleanup; + } + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve use'r groups. + goto cleanup; + + // Check whether there is an ACL entry which would allow any of the + // user's groups to read the file. + rc = check_acl(&pc, GROUP_CHECK); + +cleanup: +#ifdef DEBUG + debug_print_checker(&pc); +#endif + + if (pc.ngids) + free(pc.gids); + + return rc; +} + +// vim: set tabstop=8 softtabstop=0 expandtab shiftwidth=4 smarttab diff --git a/fs/permissions/permission_linux.go b/fs/permissions/permission_linux.go new file mode 100644 index 0000000..feae729 --- /dev/null +++ b/fs/permissions/permission_linux.go @@ -0,0 +1,33 @@ +package permissions + +/* +#include "permission_linux.h" +#cgo LDFLAGS: -L. -lacl +*/ +import "C" + +import ( + "errors" + "unsafe" +) + +// To check whether user has Linux file system permissions to read a given file. 
+func ToRead(user, filePath string) (bool, error) { + cUser := C.CString(user) + cFilePath := C.CString(filePath) + + defer C.free(unsafe.Pointer(cUser)) + defer C.free(unsafe.Pointer(cFilePath)) + + cOk, err := C.permission_to_read(cUser, cFilePath) + if cOk == 1 { + return true, nil + } + + if err != nil { + // err contains errno message + return false, err + } + + return false, errors.New("User without permission to read file") +} diff --git a/fs/permissions/permission_linux.h b/fs/permissions/permission_linux.h new file mode 100644 index 0000000..a2c266e --- /dev/null +++ b/fs/permissions/permission_linux.h @@ -0,0 +1,60 @@ +#ifndef PERMISSION_LINUX_H +#define PERMISSION_LINUX_H + +#include <acl/libacl.h> +#include <errno.h> +#include <grp.h> +#include <pwd.h> +#include <stdio.h> +#include <stdint.h> +#include <stdlib.h> +#include <sys/acl.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +//#define DEBUG +#define USER_CHECK 0 +#define GROUP_CHECK 1 + +struct permission_checker { + char *user_name; + uid_t uid; + gid_t *gids; + int ngids; + char *file_path; + struct stat file_stat; + struct passwd pw; +}; + + +#ifdef DEBUG +// Print out permission_checker struct. +void debug_print_checker(struct permission_checker *pc); +#endif + +// Stat a given file to retrieve traditional UNIX permissions. +int stat_file(struct permission_checker *pc); + +// Retrieve UID of user. +int get_user_uid(struct permission_checker *pc); + +// Retrieve all groups of the user. +int get_user_groups(struct permission_checker *pc); + +// Check whether user is member of a group or not. +int is_member_of_group(struct permission_checker *pc, gid_t gid); + +// Check whether user can read file according Linux ACLs. +// As flag use either USER_CHECK or GROUP_CHECK. +int check_acl(struct permission_checker *pc, const int flag); + +// Check whether user has permissions to read file according traditional +// UNIX permissions. As flag use either USER_CHECK or GROUP_CHECK. 
+int check_traditional(struct permission_checker *pc, const int flag); + +// Returns 1 if user has permission to read file. +// Returns <0 on error and returns 0 if no permissions. +int permission_to_read(char* user, char *file_path); + +#endif // PERMISSION_LINUX_H diff --git a/fs/permissions/permission_test.go b/fs/permissions/permission_test.go new file mode 100644 index 0000000..d415ac2 --- /dev/null +++ b/fs/permissions/permission_test.go @@ -0,0 +1,112 @@ +// +build linux + +package permissions + +import ( + "os" + "os/exec" + "os/user" + "strings" + "testing" +) + +const ( + setfacl string = "/usr/bin/setfacl" + file string = "/tmp/acltest" +) + +func TestLinuxACL(t *testing.T) { + setfacl := "/usr/bin/setfacl" + file := "/tmp/acltest" + + // Delete file if it exists. + if _, err := os.Stat(file); err == nil { + os.Remove(file) + } + + f, err := os.Create(file) + if err != nil { + t.Errorf("%v", err) + } + defer func() { + f.Close() + //os.Remove(file) + }() + + user, err := user.Current() + if err != nil { + t.Errorf("Unable to retrieve current user: %v", err) + } + + // Test 1: Remove all permissions and perform a permission check + cmd := exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + + // Test 2: Add read permission to file owner + cmd = exec.Command(setfacl, "-b", "-m", "u::r--,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 3: Add read permission to file group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := 
ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 4: Add read permission to others + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 5: Remove read permission from mask + cmd = exec.Command(setfacl, "-m", "m::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + // Test 6: Add read permission to specific group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g:"+user.Username+":r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file for user %v: %v", user.Username, err) + } + + // Test 7: Remove all permissions but mask + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } +} diff --git a/fs/readfile.go b/fs/readfile.go new file mode 100644 index 0000000..375378b --- /dev/null +++ b/fs/readfile.go @@ -0,0 +1,318 @@ +package fs + +import ( + "bufio" + "compress/gzip" + "dtail/logger" + "errors" + "io" + "os" + "regexp" + "strings" + "sync" + 
"time" + + "github.com/DataDog/zstd" +) + +// Used to tail and filter a local log file. +type readFile struct { + // Various statistics (e.g. regex hit percentage, transfer percentage). + stats + // Path of log file to tail. + filePath string + // Only consider all log lines matching this regular expression. + re *regexp.Regexp + // The glob identifier of the file. + globID string + // Channel to send a server message to the dtail client + serverMessages chan<- string + // Signals to stop tailing the log file. + stop chan struct{} + // Periodically retry reading file. + retry bool + // Can I skip messages when there are too many? + canSkipLines bool + // Seek to the EOF before processing file? + seekEOF bool + // Mutex to control the stopping of the file + mutex *sync.Mutex + limiter chan struct{} +} + +// FilePath returns the full file path. +func (f readFile) FilePath() string { + return f.filePath +} + +// Retry reading the file on error? +func (f readFile) Retry() bool { + return f.retry +} + +// Start tailing a log file. +func (f readFile) Start(lines chan<- LineRead, regex string) error { + defer func() { + select { + case <-f.limiter: + default: + } + }() + + select { + case f.limiter <- struct{}{}: + default: + select { + case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. 
Queuing file..."): + case <-f.stop: + return nil + } + f.limiter <- struct{}{} + } + + fd, err := os.Open(f.filePath) + if err != nil { + return err + } + defer fd.Close() + + if f.seekEOF { + fd.Seek(0, io.SeekEnd) + } + + rawLines := make(chan []byte, 100) + truncate := make(chan struct{}) + + var wg sync.WaitGroup + wg.Add(1) + + go f.periodicTruncateCheck(truncate) + go f.filter(&wg, rawLines, lines, regex) + + err = f.read(fd, rawLines, truncate) + close(rawLines) + wg.Wait() + + return err +} + +func (f readFile) periodicTruncateCheck(truncate chan struct{}) { + for { + select { + case <-time.After(time.Second * 3): + select { + case truncate <- struct{}{}: + case <-f.stop: + } + case <-f.stop: + return + } + } +} + +// Stop reading file. +func (f readFile) Stop() { + f.mutex.Lock() + defer f.mutex.Unlock() + + select { + case <-f.stop: + return + default: + } + + close(f.stop) +} + +func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { + switch { + case strings.HasSuffix(f.FilePath(), ".gz"): + fallthrough + case strings.HasSuffix(f.FilePath(), ".gzip"): + logger.Info(f.FilePath(), "Detected gzip compression format") + var gzipReader *gzip.Reader + gzipReader, err = gzip.NewReader(fd) + if err != nil { + return + } + reader = bufio.NewReader(gzipReader) + case strings.HasSuffix(f.FilePath(), ".zst"): + logger.Info(f.FilePath(), "Detected zstd compression format") + reader = bufio.NewReader(zstd.NewReader(fd)) + default: + reader = bufio.NewReader(fd) + } + + return +} + +func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { + reader, err := f.makeReader(fd) + if err != nil { + return err + } + rawLine := make([]byte, 0, 512) + var offset uint64 + + lineLengthThreshold := 1024 * 1024 // 1mb + longLineWarning := false + + for { + select { + case <-truncate: + if isTruncated, err := f.truncated(fd); isTruncated { + return err + } + logger.Info(f.filePath, "Current offset", offset) + + case <-f.stop: 
+ return nil + default: + } + + // Read some bytes (max 4k at once as of go 1.12). isPrefix will + // be set if line does not fit into 4k buffer. + bytes, isPrefix, err := reader.ReadLine() + + if err != nil { + // If EOF, sleep a couple of ms and return with nil error. + // If other error, return with non-nil error. + if err != io.EOF { + return err + } + if !f.seekEOF { + logger.Debug(f.FilePath(), "End of file reached") + return nil + } + time.Sleep(time.Millisecond * 100) + continue + } + + rawLine = append(rawLine, bytes...) + offset += uint64(len(bytes)) + + if !isPrefix { + // last LineRead call returned contend until end of line. + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-f.stop: + return nil + } + rawLine = make([]byte, 0, 512) + if longLineWarning { + longLineWarning = false + } + continue + } + + // Last LineRead call could not read content until end of line, buffer + // was too small. Determine whether we exceed the max line length we + // want dtail to send to the client at once. Possibly split up log line + // into multiple log lines. + if len(rawLine) >= lineLengthThreshold { + if !longLineWarning { + f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines") + // Only print out one warning per long log line. + longLineWarning = true + } + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-f.stop: + return nil + } + rawLine = make([]byte, 0, 512) + } + } +} + +// Filter log lines matching a given regular expression. +func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- LineRead, regex string) { + defer wg.Done() + + if regex == "" { + regex = "." + } + + re, err := regexp.Compile(regex) + if err != nil { + logger.Error(regex, "Can't compile regex, using '.' 
instead", err)
+		re = regexp.MustCompile(".")
+	}
+	f.re = re
+
+	for {
+		select {
+		case line, ok := <-rawLines:
+			if !ok {
+				// Channel closed and drained, all lines processed.
+				return
+			}
+			// Only advance the stats position for lines actually read
+			// (previously it also advanced on the close notification).
+			f.updatePosition()
+			if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok {
+				select {
+				case lines <- filteredLine:
+				case <-f.stop:
+					return
+				}
+			}
+		}
+	}
+}
+
+// Decide whether a line is transmitted to the client. Pointer receiver
+// is required: the stats counters mutated below must persist in the
+// caller's (filter's) readFile copy; with a value receiver every
+// updateLine* call mutated a copy that was discarded on return, so the
+// match/transmit percentages were always lost.
+func (f *readFile) transmittable(line []byte, length, capacity int) (LineRead, bool) {
+	var read LineRead
+
+	if !f.re.Match(line) {
+		f.updateLineNotMatched()
+		f.updateLineNotTransmitted()
+		return read, false
+	}
+	f.updateLineMatched()
+
+	// Can we actually send more messages, channel capacity reached?
+	if f.canSkipLines && length >= capacity {
+		f.updateLineNotTransmitted()
+		return read, false
+	}
+	f.updateLineTransmitted()
+
+	read = LineRead{
+		Content:         line,
+		GlobID:          &f.globID,
+		Count:           f.totalLineCount(),
+		TransmittedPerc: f.transmittedPerc(),
+	}
+
+	return read, true
+}
+
+// Check whether log file is truncated. Returns nil error if not.
+func (f readFile) truncated(fd *os.File) (bool, error) {
+	logger.Debug(f.filePath, "File truncation check")
+
+	// Can not seek currently open FD. io.SeekCurrent replaces the
+	// deprecated os.SEEK_CUR (the rest of the file already uses the
+	// io seek constants, e.g. io.SeekEnd below).
+	curPos, err := fd.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return true, err
+	}
+
+	// Can not open file at original path.
+	pathFd, err := os.Open(f.filePath)
+	if err != nil {
+		return true, err
+	}
+	defer pathFd.Close()
+
+	// Can not seek file at original path.
+	pathPos, err := pathFd.Seek(0, io.SeekEnd)
+	if err != nil {
+		return true, err
+	}
+
+	if curPos > pathPos {
+		return true, errors.New("File got truncated")
+	}
+
+	return false, nil
+}
diff --git a/fs/stats.go b/fs/stats.go
new file mode 100644
index 0000000..4121ff7
--- /dev/null
+++ b/fs/stats.go
@@ -0,0 +1,69 @@
+package fs
+
+// Used to calculate how many log lines matched the regular expression
+// and how many log files could be transmitted from the server to the client.
+// Hit and transmit percentage takes only the last 100 log lines into calculation. 
+type stats struct { + pos int + lineCount uint64 + matched [100]bool + matchCount uint64 + transmitted [100]bool + transmitCount int +} + +// Return the total line count. +func (f *stats) totalLineCount() uint64 { + return f.lineCount +} + +// Calculate the percentage of log lines transmitted to the client. +func (f *stats) transmittedPerc() int { + return int(percentOf(float64(f.matchCount), float64(f.transmitCount))) +} + +// Update bucket position. We only take into consideration the last 100 +// lines for stats. +func (f *stats) updatePosition() { + f.pos = (f.pos + 1) % 100 + f.lineCount++ +} + +// Increment match counter. +func (f *stats) updateLineMatched() { + if !f.matched[f.pos] { + f.matchCount++ + f.matched[f.pos] = true + } +} + +// Increment transmitted counter. +func (f *stats) updateLineTransmitted() { + if !f.transmitted[f.pos] { + f.transmitCount++ + f.transmitted[f.pos] = true + } +} + +// Decrement match counter. +func (f *stats) updateLineNotMatched() { + if f.matched[f.pos] { + f.matchCount-- + f.matched[f.pos] = false + } +} + +// Decrement transmitted counter. +func (f *stats) updateLineNotTransmitted() { + if f.transmitted[f.pos] { + f.transmitCount-- + f.transmitted[f.pos] = false + } +} + +func percentOf(total float64, value float64) float64 { + if total == 0 || total == value { + return 100 + } + return value / (total / 100.0) +} diff --git a/fs/tailfile.go b/fs/tailfile.go new file mode 100644 index 0000000..a19d4e6 --- /dev/null +++ b/fs/tailfile.go @@ -0,0 +1,27 @@ +package fs + +import "sync" + +// TailFile is to tail and filter a log file. +type TailFile struct { + readFile +} + +// NewTailFile returns a new file tailer. 
+func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { + var mutex sync.Mutex + + return TailFile{ + readFile: readFile{ + filePath: filePath, + stop: make(chan struct{}), + globID: globID, + serverMessages: serverMessages, + retry: true, + canSkipLines: true, + seekEOF: true, + limiter: limiter, + mutex: &mutex, + }, + } +} @@ -0,0 +1,8 @@ +module dtail + +go 1.13 + +require ( + github.com/DataDog/zstd v1.4.4 + golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 +) @@ -0,0 +1,10 @@ +github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE= +github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= +golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..33ca911 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,427 @@ +package logger + +import ( + "bufio" + "dtail/color" + "dtail/config" + "fmt" + "os" + "os/signal" + "runtime" + "strings" + "sync" + "syscall" + "time" +) + +const ( + clientStr string = "CLIENT" + serverStr string = "SERVER" + infoStr string = "INFO" + warnStr string = "WARN" + errorStr string = 
"ERROR" + fatalStr string = "FATAL" + debugStr string = "DEBUG" + traceStr string = "TRACE" +) + +// Synchronise access to logging. +var mutex sync.Mutex + +// File descriptor of log file when logToFile enabled. +var fd *os.File + +// File write buffer of log file when logToFile enabled. +var writer *bufio.Writer + +// File write buffer of stdout when logToStdout enabled. +var stdoutWriter *bufio.Writer + +// Current hostname. +var hostname string + +// Used to detect change of day (create one log file per day0 +var lastDateStr string + +// True if log in server mode, false if log in client mode. +var serverEnable bool + +// Used to make logging non-blocking. +var logBufCh chan buf +var stdoutBufCh chan string + +// Stdout channel, required to pause output +var pauseCh chan struct{} +var resumeCh chan struct{} + +// Tell the logger that we are done, program shuts down +var stop chan struct{} +var stdoutFlushed chan struct{} + +// Tell the logger about logrotation +var rotateCh chan os.Signal + +// LogMode allows to specify the verbosity of logging. +type LogMode int + +// Possible log modes. +const ( + NormalMode LogMode = iota + DebugMode LogMode = iota + SilentMode LogMode = iota + TraceMode LogMode = iota + NothingMode LogMode = iota +) + +// Mode is the current log mode in use. +var Mode LogMode + +// LogStrategy allows to specify a log rotation strategy. +type LogStrategy int + +// Possible log strategies. +const ( + NormalStrategy LogStrategy = iota + DailyStrategy LogStrategy = iota + StdoutStrategy LogStrategy = iota +) + +// Strategy is the current log strattegy used. +var Strategy LogStrategy + +// Enables logging to stdout. +var logToStdout bool + +// Enables logging to file. +var logToFile bool + +// Helper type to make logging non-blocking. +type buf struct { + time time.Time + message string +} + +// Init logging. 
+func Init(myServerEnable bool, mode LogMode, strategy LogStrategy) { + stdoutWriter = bufio.NewWriter(os.Stdout) + + serverEnable = myServerEnable + Mode = mode + Strategy = strategy + + if Mode == NothingMode { + return + } + + switch Strategy { + case DailyStrategy: + _, err := os.Stat(config.Common.LogDir) + logToFile = !os.IsNotExist(err) + logToStdout = !serverEnable || Mode == DebugMode || Mode == TraceMode + case StdoutStrategy: + fallthrough + default: + logToFile = false + logToStdout = true + } + + fqdn, err := os.Hostname() + if err != nil { + panic(err) + } + s := strings.Split(fqdn, ".") + hostname = s[0] + + pauseCh = make(chan struct{}) + resumeCh = make(chan struct{}) + stop = make(chan struct{}) + stdoutFlushed = make(chan struct{}) + + // Setup logrotation + rotateCh = make(chan os.Signal, 1) + signal.Notify(rotateCh, syscall.SIGHUP) + + if logToStdout { + stdoutBufCh = make(chan string, runtime.NumCPU()*100) + go writeToStdout() + } + + if logToFile { + logBufCh = make(chan buf, runtime.NumCPU()*100) + go writeToFile() + } +} + +// Info message logging. +func Info(args ...interface{}) string { + if serverEnable { + return log(serverStr, infoStr, args) + } + + return log(clientStr, infoStr, args) +} + +// Warn message logging. +func Warn(args ...interface{}) string { + if serverEnable { + return log(serverStr, warnStr, args) + } + + return log(clientStr, warnStr, args) +} + +// Error message logging. +func Error(args ...interface{}) string { + if serverEnable { + return log(serverStr, errorStr, args) + } + + return log(clientStr, errorStr, args) +} + +// FatalExit logs an error and exists the process. +func FatalExit(args ...interface{}) { + what := clientStr + if serverEnable { + what = serverStr + } + log(what, fatalStr, args) + + time.Sleep(time.Second) + mutex.Lock() + defer mutex.Unlock() + + closeWriter() + os.Exit(3) +} + +// Debug message logging. 
+func Debug(args ...interface{}) string { + if Mode == DebugMode || Mode == TraceMode { + if serverEnable { + return log(serverStr, debugStr, args) + } + return log(clientStr, debugStr, args) + } + + return "" +} + +// Trace message logging. +func Trace(args ...interface{}) string { + if Mode == TraceMode { + if serverEnable { + return log(serverStr, traceStr, args) + } + return log(clientStr, traceStr, args) + } + + return "" +} + +// Write log line to buffer and/or log file. +func write(what, severity, message string) { + if logToStdout && (Mode != SilentMode || severity != warnStr) { + line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message) + + if color.Colored { + line = color.Colorfy(line) + } + + stdoutBufCh <- line + } + + if logToFile { + t := time.Now() + timeStr := t.Format("20060102-150405") + logBufCh <- buf{ + time: t, + message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), + } + } +} + +// Generig log message. +func log(what string, severity string, args []interface{}) string { + if Mode == NothingMode { + return "" + } + + var messages []string + + for _, arg := range args { + switch v := arg.(type) { + case string: + messages = append(messages, v) + case int: + messages = append(messages, fmt.Sprintf("%d", v)) + case error: + messages = append(messages, v.Error()) + default: + messages = append(messages, fmt.Sprintf("%v", v)) + } + } + + message := strings.Join(messages, "|") + write(what, severity, message) + + return fmt.Sprintf("%s|%s", severity, message) +} + +// Raw message logging. +func Raw(message string) { + if Mode == NothingMode { + return + } + + if logToStdout { + if color.Colored { + message = color.Colorfy(message) + } + stdoutBufCh <- message + } + + if logToFile { + logBufCh <- buf{time.Now(), message} + } +} + +// Close log writer (e.g. on change of day). 
+func closeWriter() {
+	if writer != nil {
+		writer.Flush()
+		fd.Close()
+	}
+}
+
+// Return the correct log file writer. Rolls over to a fresh file on
+// change of day or on a received logrotation signal.
+func fileWriter(dateStr string) *bufio.Writer {
+	if dateStr != lastDateStr {
+		return updateFileWriter(dateStr)
+	}
+
+	// Check for log rotation signal
+	select {
+	case <-rotateCh:
+		stdoutWriter.WriteString("Received signal for logrotation\n")
+		return updateFileWriter(dateStr)
+	default:
+	}
+
+	return writer
+}
+
+// Update log file writer
+func updateFileWriter(dateStr string) *bufio.Writer {
+	// Detected change of day. Close current writer and create a new one.
+	mutex.Lock()
+	defer mutex.Unlock()
+	closeWriter()
+
+	if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) {
+		if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil {
+			panic(err)
+		}
+	}
+
+	logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr)
+	newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
+	if err != nil {
+		panic(err)
+	}
+
+	fd = newFd
+	// Size 1 effectively disables buffering so every message reaches
+	// the log file immediately.
+	writer = bufio.NewWriterSize(fd, 1)
+	lastDateStr = dateStr
+
+	return writer
+}
+
+// Drain and flush all pending stdout messages, then signal completion.
+func flushStdout() {
+	defer close(stdoutFlushed)
+
+	for {
+		select {
+		case message := <-stdoutBufCh:
+			stdoutWriter.WriteString(message)
+		default:
+			stdoutWriter.Flush()
+			return
+		}
+	}
+}
+
+func writeToStdout() {
+	for {
+		select {
+		case message := <-stdoutBufCh:
+			stdoutWriter.WriteString(message)
+		case <-time.After(time.Millisecond * 100):
+			stdoutWriter.Flush()
+		case <-pauseCh:
+		PAUSE:
+			for {
+				select {
+				case <-stdoutBufCh:
+				case <-resumeCh:
+					break PAUSE
+				case <-stop:
+					return
+				}
+			}
+		case <-stop:
+			flushStdout()
+			return
+		}
+	}
+}
+
+func writeToFile() {
+	for {
+		select {
+		case buf := <-logBufCh:
+			dateStr := buf.time.Format("20060102")
+			w := fileWriter(dateStr)
+			w.WriteString(buf.message)
+		case <-pauseCh:
+		PAUSE:
+			for {
+				select {
+				// Drain our own queue while paused. This previously
+				// drained stdoutBufCh by mistake, letting logBufCh fill
+				// up (blocking producers) and stealing messages from
+				// the stdout writer goroutine.
+				case <-logBufCh:
+				case <-resumeCh:
+					break PAUSE
+				case <-stop:
+					return
+				}
+			}
+		case <-stop:
+			return
+		}
+	}
+}
+
+// Pause 
logging. +func Pause() { + if logToStdout { + pauseCh <- struct{}{} + } + if logToFile { + pauseCh <- struct{}{} + } +} + +// Resume logging (after pausing). +func Resume() { + if logToStdout { + resumeCh <- struct{}{} + } + if logToFile { + resumeCh <- struct{}{} + } +} + +// Stop logging. +func Stop() { + close(stop) + <-stdoutFlushed +} @@ -0,0 +1,250 @@ +package main + +import ( + "dtail/clients" + "dtail/color" + "dtail/config" + "dtail/logger" + "dtail/omode" + "dtail/server" + "dtail/version" + "flag" + "fmt" + "net/http" + _ "net/http" + _ "net/http/pprof" + "os" + "os/user" + "runtime" + "sync" + "time" +) + +// The evil begins here. +func main() { + var cfgFile, modeStr string + var checkHealth bool + var clientServerEnable bool + var connectionsPerCPU int + var debugEnable bool + var discovery string + var displayVersion bool + var files string + var grep, regex string + var maxInitConnections int + var noColor bool + var pprofEnable bool + var queryStr string + var serversStr string + var shutdownAfter int + var silent bool + var sshPort int + var trustAllHosts bool + var userName string + + user, err := user.Current() + if err != nil { + panic(err) + } + + if user.Uid == "0" { + panic("Not allowed to run as UID 0") + } + + if user.Gid == "0" { + panic("Not allowed to run as GID 0") + } + + defaultMode := omode.Default() + serverEnable := defaultMode == omode.Server + clientEnable := !serverEnable + + // Based on the mode we have different default timeouts + var pingTimeoutS int + switch defaultMode { + case omode.CatClient: + fallthrough + case omode.GrepClient: + pingTimeoutS = 60 + case omode.MapClient: + pingTimeoutS = 900 + default: + pingTimeoutS = 5 + } + + flag.BoolVar(&checkHealth, "checkHealth", false, "Only check for server health") + flag.BoolVar(&clientServerEnable, "clientServer", false, "Enable client and server (dev purposes)") + flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") + flag.BoolVar(&displayVersion, 
"version", false, "Display version") + flag.BoolVar(&noColor, "noColor", false, "Disable ANSII terminal colors") + flag.BoolVar(&pprofEnable, "pprofEnable", false, "Enable pprof server") + flag.BoolVar(&serverEnable, "server", serverEnable, "Start as a DTail server") + flag.BoolVar(&silent, "silent", false, "Reduce output") + flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") + flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") + flag.IntVar(&maxInitConnections, "mic", 20, "Max cpc") + flag.IntVar(&shutdownAfter, "shutdownAfter", 0, "Automatically shutdown after so many seconds") + flag.IntVar(&sshPort, "port", 2222, "SSH server port") + flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") + flag.StringVar(&cfgFile, "cfg", "", "Config file path") + flag.StringVar(&discovery, "discovery", "", "Server discovery method") + flag.StringVar(&files, "files", "", "File(s) to read") + flag.StringVar(&grep, "grep", "", "Regular expression (deprecated)") + flag.StringVar(&modeStr, "mode", defaultMode.String(), "Operating mode (tail, grep, cat, map, server)") + flag.StringVar(&queryStr, "query", "", "Map reduce query") + flag.StringVar(®ex, "regex", "", "Regular expression") + flag.StringVar(&serversStr, "servers", "", "Remote servers to connect") + flag.StringVar(&userName, "user", user.Username, "Your system user name") + + mode := omode.New(modeStr) + + flag.Parse() + + config.Init(cfgFile) + color.Init(!noColor) + + if displayVersion { + fmt.Println(version.PaintedString()) + os.Exit(0) + } + + // Figure out how many SSH sessions can be established concurrently. + if connectionsPerCPU*runtime.NumCPU() < maxInitConnections { + maxInitConnections = connectionsPerCPU * runtime.NumCPU() + } + + // Figure out in which mode I am? Server or client or both (the latter for dev purposes)? 
+ if serverEnable { + clientEnable = false + } + if clientServerEnable { + clientEnable = true + serverEnable = true + } + + // If non-standard port specified, overwrite config + if sshPort != 2222 { + config.Common.SSHPort = sshPort + } + + // Figure out the log level. + var logMode logger.LogMode + switch { + case debugEnable: + logMode = logger.DebugMode + case checkHealth: + logMode = logger.NothingMode + case config.Common.TraceEnable: + logMode = logger.TraceMode + case config.Common.DebugEnable: + logMode = logger.DebugMode + case silent: + logMode = logger.SilentMode + default: + logMode = logger.NormalMode + } + + // Figure out the log strategy. + var logStrategy logger.LogStrategy + switch config.Common.LogStrategy { + case "daily": + logStrategy = logger.DailyStrategy + case "stdout": + fallthrough + default: + logStrategy = logger.StdoutStrategy + } + + logger.Init(serverEnable, logMode, logStrategy) + + // Wait group for shutting down logger. + var wg sync.WaitGroup + if serverEnable { + wg.Add(1) + } + if clientEnable { + wg.Add(1) + } + + logger.Debug("Common config", config.Common) + logger.Debug("Client config", config.Client) + logger.Debug("Server config", config.Server) + + if grep != "" { + logger.Warn("Flag 'grep' is deprecated and may be removed in the future, please use 'regex' instead") + if regex == "" { + regex = grep + } + } + + if checkHealth { + healthClient, _ := clients.NewHealthClient(omode.HealthClient) + os.Exit(healthClient.Start(&wg)) + } + + if shutdownAfter > 0 { + go func() { + defer os.Exit(1) + + logger.Info("Enabling auto shutdown timer", shutdownAfter) + time.Sleep(time.Duration(shutdownAfter) * time.Second) + logger.Info("Auto shutdown timer reached, shutting down now") + }() + } + + if pprofEnable || config.Common.PProfEnable { + bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort) + logger.Info("Starting PProf server", bindAddr) + go http.ListenAndServe(bindAddr, nil) + } + + if 
serverEnable { + logger.Info("Launching server", mode, version.String()) + sshServer := server.New() + go sshServer.Start(&wg) + } + + if clientEnable { + var client clients.Client + var err error + + logger.Info("Launching client", mode, version.String()) + + args := clients.Args{ + Mode: mode, + ServersStr: serversStr, + Discovery: discovery, + UserName: userName, + Files: files, + Regex: regex, + TrustAllHosts: trustAllHosts, + MaxInitConnections: maxInitConnections, + PingTimeout: pingTimeoutS, + } + + switch mode { + case omode.TailClient: + switch queryStr { + case "": + client, err = clients.NewTailClient(args) + default: + client, err = clients.NewMaprClient(args, queryStr) + } + case omode.GrepClient: + client, err = clients.NewGrepClient(args) + case omode.CatClient: + client, err = clients.NewCatClient(args) + case omode.MapClient: + client, err = clients.NewMaprClient(args, queryStr) + } + + if err != nil { + panic(err) + } + + go client.Start(&wg) + } + + wg.Wait() + logger.Stop() +} diff --git a/mapr/aggregateset.go b/mapr/aggregateset.go new file mode 100644 index 0000000..2096c3c --- /dev/null +++ b/mapr/aggregateset.go @@ -0,0 +1,185 @@ +package mapr + +import ( + "fmt" + "strconv" + "strings" +) + +// AggregateSet represents aggregated key/value pairs from the +// MAPREDUCE log lines. These could be either string values or float +// values. +type AggregateSet struct { + Samples int + FValues map[string]float64 + SValues map[string]string +} + +// NewAggregateSet creates a new empty aggregate set. +func NewAggregateSet() *AggregateSet { + return &AggregateSet{ + FValues: make(map[string]float64), + SValues: make(map[string]string), + } +} + +// String representation of aggregate set. +func (s *AggregateSet) String() string { + return fmt.Sprintf("AggregateSet(Samples:%d,FValues:%v,SValues:%v)", + s.Samples, s.FValues, s.SValues) +} + +// Merge one aggregate set into this one. 
+func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error { + s.Samples += set.Samples + //logger.Trace("Merge", set) + + for _, sc := range query.Select { + storage := sc.FieldStorage + switch sc.Operation { + case Count: + fallthrough + case Sum: + fallthrough + case Avg: + value := set.FValues[storage] + s.addFloat(storage, value) + case Min: + value := set.FValues[storage] + s.addFloatMin(storage, value) + case Max: + value := set.FValues[storage] + s.addFloatMax(storage, value) + case Last: + value := set.SValues[storage] + s.setString(storage, value) + case Len: + s.setString(storage, set.SValues[storage]) + s.setFloat(storage, set.FValues[storage]) + default: + return fmt.Errorf("Unknown aggregation method '%v'", sc.Operation) + } + } + return nil +} + +// Serialize the aggregate set so it can be sent over the wire. +func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan struct{}) { + //logger.Trace("Serialising mapr.AggregateSet", s) + var sb strings.Builder + + sb.WriteString(groupKey) + sb.WriteString("|") + sb.WriteString(fmt.Sprintf("%d|", s.Samples)) + + for k, v := range s.FValues { + sb.WriteString(k) + sb.WriteString("=") + sb.WriteString(fmt.Sprintf("%v|", v)) + } + + for k, v := range s.SValues { + sb.WriteString(k) + sb.WriteString("=") + sb.WriteString(v) + sb.WriteString("|") + } + + select { + case ch <- sb.String(): + case <-stop: + } +} + +// Add a float value. +func (s *AggregateSet) addFloat(key string, value float64) { + if _, ok := s.FValues[key]; !ok { + s.FValues[key] = value + return + } + s.FValues[key] += value +} + +// Add a float minimum value. +func (s *AggregateSet) addFloatMin(key string, value float64) { + f, ok := s.FValues[key] + if !ok { + s.FValues[key] = value + return + } + + if f > value { + s.FValues[key] = value + } +} + +// Add a float maximum value. 
+func (s *AggregateSet) addFloatMax(key string, value float64) { + f, ok := s.FValues[key] + if !ok { + s.FValues[key] = value + return + } + + if f < value { + s.FValues[key] = value + } +} + +// Set a string. +func (s *AggregateSet) setString(key, value string) { + s.SValues[key] = value +} + +// Set a float. +func (s *AggregateSet) setFloat(key string, value float64) { + s.FValues[key] = value +} + +// Aggregate data to the aggregate set. +func (s *AggregateSet) Aggregate(key string, agg AggregateOperation, value string, clientAggregation bool) (err error) { + var f float64 + + // First check if we can aggregate anything without converting value to float. + switch agg { + case Count: + if clientAggregation { + f, err = strconv.ParseFloat(value, 64) + if err != nil { + return + } + s.addFloat(key, f) + return + } + s.addFloat(key, 1) + return + case Last: + s.setString(key, value) + return + case Len: + s.setString(key, value) + s.setFloat(key, float64(len(value))) + return + default: + } + + // No, we have to convert to float. + f, err = strconv.ParseFloat(value, 64) + if err != nil { + return + } + + switch agg { + case Sum: + fallthrough + case Avg: + s.addFloat(key, f) + case Min: + s.addFloatMin(key, f) + case Max: + s.addFloatMax(key, f) + default: + err = fmt.Errorf("Unknown aggregation method '%v'", agg) + } + return +} diff --git a/mapr/client/aggregate.go b/mapr/client/aggregate.go new file mode 100644 index 0000000..b9443bc --- /dev/null +++ b/mapr/client/aggregate.go @@ -0,0 +1,100 @@ +package client + +import ( + "dtail/logger" + "dtail/mapr" + "strconv" + "strings" +) + +// Aggregate mapreduce data on the DTail client side. +type Aggregate struct { + // This is the mapr query specified on the command line. + query *mapr.Query + // This represents aggregated data of a single remote server. + group *mapr.GroupSet + // This represents the merged aggregated data of all servers. 
+ globalGroup *mapr.GlobalGroupSet + stop chan struct{} + // The server we aggregate the data for (logging and debugging purposes only) + server string +} + +// NewAggregate create new client aggregator. +func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *Aggregate { + return &Aggregate{ + query: query, + group: mapr.NewGroupSet(), + globalGroup: globalGroup, + stop: make(chan struct{}), + server: server, + } +} + +// Aggregate data from mapr log line into local (and global) group sets. +func (a *Aggregate) Aggregate(parts []string) { + select { + case <-a.stop: + logger.Error("Client aggregator stopped for server, not processing new data", a.server) + return + default: + } + + groupKey := parts[0] + samples, err := strconv.Atoi(parts[1]) + if err != nil { + logger.FatalExit(parts, err) + } + fields := a.makeFields(parts[2:]) + set := a.group.GetSet(groupKey) + + var addedSamples bool + for _, sc := range a.query.Select { + if val, ok := fields[sc.FieldStorage]; ok { + if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, true); err != nil { + logger.Error(err) + continue + } + addedSamples = true + } + } + if addedSamples { + set.Samples += samples + } + + // Merge data from group into global group. + isMerged, err := a.globalGroup.MergeNoblock(a.query, a.group) + if err != nil { + panic(err) + } + if isMerged { + // Re-init local group (make it empty again). + a.group.InitSet() + } +} + +// Create a map of key-value pairs from a part list such as ["foo=bar", "bar=baz"]. +func (a *Aggregate) makeFields(parts []string) map[string]string { + fields := make(map[string]string, len(parts)) + + for _, part := range parts { + kv := strings.Split(part, "=") + if len(kv) != 2 { + continue + } + fields[kv[0]] = kv[1] + } + + return fields +} + +// Stop the client side mapreduce aggregator. 
+func (a *Aggregate) Stop() { + logger.Debug("Stopping client mapreduce aggregator") + close(a.stop) + + err := a.globalGroup.Merge(a.query, a.group) + if err != nil { + panic(err) + } +} diff --git a/mapr/globalgroupset.go b/mapr/globalgroupset.go new file mode 100644 index 0000000..cfab506 --- /dev/null +++ b/mapr/globalgroupset.go @@ -0,0 +1,100 @@ +package mapr + +import ( + "fmt" +) + +// GlobalGroupSet is used on the dtail client to merge multiple group sets +// (one group set per remote server) to one single global group set. +type GlobalGroupSet struct { + GroupSet + semaphore chan struct{} +} + +// NewGlobalGroupSet creates a new empty global group set. +func NewGlobalGroupSet() *GlobalGroupSet { + g := GlobalGroupSet{ + semaphore: make(chan struct{}, 1), + } + g.InitSet() + + return &g +} + +// String representation of the global group set. +func (g *GlobalGroupSet) String() string { + return fmt.Sprintf("GlobalGroupSet(%s)", g.GroupSet.String()) +} + +// Merge (blocking) a group set into the global group set. +func (g *GlobalGroupSet) Merge(query *Query, group *GroupSet) error { + g.semaphore <- struct{}{} + defer func() { <-g.semaphore }() + + return g.merge(query, group) +} + +// MergeNoblock merges (non-blocking) a group set into the global group set. +func (g *GlobalGroupSet) MergeNoblock(query *Query, group *GroupSet) (bool, error) { + select { + case g.semaphore <- struct{}{}: + err := g.merge(query, group) + <-g.semaphore + return true, err + default: + return false, nil + } +} + +// Merge a group set into the global group set. +func (g *GlobalGroupSet) merge(query *Query, group *GroupSet) error { + for groupKey, set := range group.sets { + s := g.GetSet(groupKey) + if err := s.Merge(query, set); err != nil { + return err + } + } + + return nil +} + +// IsEmpty determines whether the global group set has any data in it. +func (g *GlobalGroupSet) IsEmpty() bool { + return g.NumSets() == 0 +} + +// NumSets determines the number of sets. 
+func (g *GlobalGroupSet) NumSets() int {
+	// Acquire the semaphore (used as a mutex) for a consistent read.
+	g.semaphore <- struct{}{}
+	defer func() { <-g.semaphore }()
+
+	return len(g.sets)
+}
+
+// SwapOut returns the underlying group set and creates a new empty one, so
+// that the global group set is empty again and can aggregate new data.
+func (g *GlobalGroupSet) SwapOut() *GroupSet {
+	g.semaphore <- struct{}{}
+	defer func() { <-g.semaphore }()
+
+	set := &GroupSet{sets: g.sets}
+	g.InitSet()
+
+	return set
+}
+
+// WriteResult writes the result of a mapreduce aggregation to an outfile.
+func (g *GlobalGroupSet) WriteResult(query *Query) error {
+	g.semaphore <- struct{}{}
+	defer func() { <-g.semaphore }()
+
+	return g.GroupSet.WriteResult(query)
+}
+
+// Result returns the result of the mapreduce aggregation as a string.
+func (g *GlobalGroupSet) Result(query *Query) (string, int, error) {
+	g.semaphore <- struct{}{}
+	defer func() { <-g.semaphore }()
+
+	return g.GroupSet.Result(query)
+}
diff --git a/mapr/groupset.go b/mapr/groupset.go
new file mode 100644
index 0000000..d8f9379
--- /dev/null
+++ b/mapr/groupset.go
@@ -0,0 +1,178 @@
+package mapr
+
+import (
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"sort"
+	"strconv"
+	"strings"
+)
+
+// GroupSet represents a map of aggregate sets. The group sets
+// are required by the "group by" mapr clause, whereas the
+// group set map keys are the values of the "group by" arguments.
+// E.g. "group by $cid" would create one aggregate set and one map
+// entry per customer id.
+type GroupSet struct {
+	// One aggregate set per distinct "group by" key.
+	sets map[string]*AggregateSet
+}
+
+// NewGroupSet returns a new empty group set.
+func NewGroupSet() *GroupSet {
+	g := GroupSet{}
+	g.InitSet()
+	return &g
+}
+
+// String representation of the group set.
+func (g *GroupSet) String() string {
+	return fmt.Sprintf("GroupSet(%v)", g.sets)
+}
+
+// InitSet makes the group set empty (initialize).
+func (g *GroupSet) InitSet() {
+	g.sets = make(map[string]*AggregateSet)
+}
+
+// GetSet gets a specific aggregate set from the group set.
+func (g *GroupSet) GetSet(groupKey string) *AggregateSet {
+	set, ok := g.sets[groupKey]
+	if !ok {
+		// Lazily create an empty aggregate set on first use of this group key.
+		set = NewAggregateSet()
+		g.sets[groupKey] = set
+	}
+	return set
+}
+
+// Serialize the group set (e.g. to send it over the wire).
+func (g *GroupSet) Serialize(ch chan<- string, stop chan struct{}) {
+	for groupKey, set := range g.sets {
+		set.Serialize(groupKey, ch, stop)
+	}
+}
+
+// Result returns a nicely formatted result of the query from the group set.
+func (g *GroupSet) Result(query *Query) (string, int, error) {
+	return g.limitedResult(query, query.Limit, "\t", " ", false)
+}
+
+// WriteResult writes the result to an outfile.
+func (g *GroupSet) WriteResult(query *Query) error {
+	if query.Outfile == "" {
+		// Error strings are lowercase and unpunctuated by Go convention.
+		return errors.New("no outfile specified")
+	}
+
+	// -1: Don't limit the result, include all data sets
+	result, _, err := g.limitedResult(query, -1, "", ",", true)
+	if err != nil {
+		return err
+	}
+
+	return ioutil.WriteFile(query.Outfile, []byte(result), 0644)
+}
+
+// Return a nicely formatted result of the query from the group set.
+func (g *GroupSet) limitedResult(query *Query, limit int, lineStarter, fieldSeparator string, addHeader bool) (string, int, error) { + type result struct { + groupKey string + resultStr string + orderBy float64 + } + + var resultSlice []result + + for groupKey, set := range g.sets { + var sb strings.Builder + r := result{groupKey: groupKey} + + lastIndex := len(query.Select) - 1 + for i, sc := range query.Select { + storage := sc.FieldStorage + orderByThis := storage == query.OrderBy + + switch sc.Operation { + case Count: + value := set.FValues[storage] + sb.WriteString(fmt.Sprintf("%d", int(value))) + if orderByThis { + r.orderBy = value + } + case Len: + fallthrough + case Sum: + fallthrough + case Min: + fallthrough + case Max: + value := set.FValues[storage] + sb.WriteString(fmt.Sprintf("%f", value)) + if orderByThis { + r.orderBy = value + } + case Last: + value := set.SValues[storage] + if orderByThis { + f, err := strconv.ParseFloat(value, 64) + if err == nil { + r.orderBy = f + } + } + sb.WriteString(value) + case Avg: + value := set.FValues[storage] / float64(set.Samples) + sb.WriteString(fmt.Sprintf("%f", value)) + if orderByThis { + r.orderBy = value + } + default: + return "", 0, fmt.Errorf("Unknown aggregation method '%v'", sc.Operation) + } + if i != lastIndex { + sb.WriteString(fieldSeparator) + } + } + + r.resultStr = sb.String() + resultSlice = append(resultSlice, r) + } + + if query.OrderBy != "" { + if query.ReverseOrder { + sort.SliceStable(resultSlice, func(i, j int) bool { + return resultSlice[i].orderBy < resultSlice[j].orderBy + }) + } else { + sort.SliceStable(resultSlice, func(i, j int) bool { + return resultSlice[i].orderBy > resultSlice[j].orderBy + }) + } + } + + var sb strings.Builder + + // Write header first + if addHeader { + lastIndex := len(query.Select) - 1 + sb.WriteString(lineStarter) + for i, sc := range query.Select { + sb.WriteString(sc.FieldStorage) + if i != lastIndex { + sb.WriteString(fieldSeparator) + } + } + 
sb.WriteString("\n") + } + + // And now write the data + for i, r := range resultSlice { + if i == limit { + break + } + sb.WriteString(lineStarter) + sb.WriteString(r.resultStr) + sb.WriteString("\n") + } + + return sb.String(), len(resultSlice), nil +} diff --git a/mapr/logformat/default.go b/mapr/logformat/default.go new file mode 100644 index 0000000..f0df5bc --- /dev/null +++ b/mapr/logformat/default.go @@ -0,0 +1,23 @@ +package logformat + +import ( + "errors" + "strings" +) + +// MakeFieldsDEFAULT is the default log file mapreduce parser. +func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) { + fields := make(map[string]string, 20) + splitted := strings.Split(maprLine, "|") + + fields["$hostname"] = p.hostname + + for _, kv := range splitted { + keyAndValue := strings.SplitN(kv, "=", 2) + if len(keyAndValue) != 2 { + return fields, errors.New("Error parsing mapr token: " + kv) + } + fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1] + } + return fields, nil +} diff --git a/mapr/logformat/default_test.go b/mapr/logformat/default_test.go new file mode 100644 index 0000000..a3c47fb --- /dev/null +++ b/mapr/logformat/default_test.go @@ -0,0 +1,35 @@ +package logformat + +import ( + "testing" +) + +func TestDefaultLogFormat(t *testing.T) { + parser, err := NewParser("default") + if err != nil { + t.Errorf("Unable to create parser: %s", err.Error()) + } + + fields, err := parser.MakeFields("foo=bar|baz=bay") + + if err != nil { + t.Errorf("Unable to parse: %s", err.Error()) + } + + if bar, ok := fields["foo"]; !ok { + t.Errorf("Expected field 'foo', but no such field there\n") + } else if bar != "bar" { + t.Errorf("Expected 'bar' stored in field 'foo', but got '%s'\n", bar) + } + + if bay, ok := fields["baz"]; !ok { + t.Errorf("Expected field 'baz', but no such field there\n") + } else if bay != "bay" { + t.Errorf("Expected 'bay' stored in field 'baz', but got '%s'\n", bay) + } + + fields, err = parser.MakeFields("foo=bar|bazbay") 
+	if err == nil {
+		// NOTE: calling err.Error() in this branch would panic, because the
+		// branch only runs when err is nil. Just report the missing error.
+		t.Errorf("Expected a parse error for input without key=value pair, but got none")
+	}
+}
diff --git a/mapr/logformat/parser.go b/mapr/logformat/parser.go
new file mode 100644
index 0000000..b7c8c5c
--- /dev/null
+++ b/mapr/logformat/parser.go
@@ -0,0 +1,75 @@
+package logformat
+
+import (
+	"dtail/logger"
+	"errors"
+	"fmt"
+	"os"
+	"reflect"
+	"strings"
+)
+
+// Parser is used to parse the mapreduce information from the server log files.
+type Parser struct {
+	hostname           string
+	logFormatName      string
+	makeFieldsFunc     reflect.Value
+	makeFieldsReceiver reflect.Value
+}
+
+// NewParser returns a new log parser.
+func NewParser(logFormatName string) (*Parser, error) {
+	hostname, err := os.Hostname()
+
+	if err != nil {
+		return nil, err
+	}
+
+	p := Parser{
+		hostname: hostname,
+	}
+
+	err = p.reflectLogFormat(logFormatName)
+	if err != nil {
+		return nil, err
+	}
+
+	return &p, nil
+}
+
+// The aim of this is that everyone can plug in their own mapr log format
+// parsing method to DTail. Just add a method MakeFieldsMODULENAME to type
+// Parser. Whereas MODULENAME must be an uppercase string.
+func (p *Parser) reflectLogFormat(logFormatName string) error {
+	methodName := fmt.Sprintf("MakeFields%s", strings.ToUpper(logFormatName))
+
+	rt := reflect.TypeOf(p)
+	method, ok := rt.MethodByName(methodName)
+	if !ok {
+		return errors.New("No such mapr log format module: " + methodName)
+	}
+
+	p.makeFieldsFunc = method.Func
+	p.makeFieldsReceiver = reflect.ValueOf(p)
+
+	return nil
+}
+
+// MakeFields is for returning the fields from a given log line.
+func (p *Parser) MakeFields(maprLine string) (fields map[string]string, err error) { + inputValues := []reflect.Value{p.makeFieldsReceiver, reflect.ValueOf(maprLine)} + returnValues := p.makeFieldsFunc.Call(inputValues) + + errInterface := returnValues[1].Interface() + + if errInterface == nil { + fields, err = returnValues[0].Interface().(map[string]string), nil + logger.Trace("parser.MakeFields", fields, err) + return + } + + fields, err = returnValues[0].Interface().(map[string]string), errInterface.(error) + logger.Trace("parser.MakeFields", fields, err) + + return +} diff --git a/mapr/query.go b/mapr/query.go new file mode 100644 index 0000000..8ed3c67 --- /dev/null +++ b/mapr/query.go @@ -0,0 +1,245 @@ +package mapr + +import ( + "dtail/logger" + "errors" + "fmt" + "strconv" + "strings" + "time" +) + +const ( + invalidQuery string = "Invalid query: " + unexpectedEnd string = "Unexpected end of query" +) + +// Query represents a parsed mapr query. +type Query struct { + Select []selectCondition + Table string + Where []whereCondition + GroupBy []string + OrderBy string + ReverseOrder bool + GroupKey string + Interval time.Duration + Limit int + Outfile string + RawQuery string + tokens []token +} + +func (q Query) String() string { + return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,GroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v)", + q.Select, + q.Table, + q.Where, + q.GroupBy, + q.GroupKey, + q.OrderBy, + q.ReverseOrder, + q.Interval, + q.Limit, + q.Outfile, + q.RawQuery, + q.tokens) +} + +// NewQuery returns a new mapreduce query. 
+func NewQuery(queryStr string) (*Query, error) { + if queryStr == "" { + return nil, nil + } + + tokens := tokenize(queryStr) + + q := Query{ + RawQuery: queryStr, + tokens: tokens, + Interval: time.Second * 5, + Limit: -1, + } + + err := q.parse(tokens) + + logger.Debug(q) + return &q, err +} + +func (q *Query) parse(tokens []token) error { + var found []token + var err error + + for tokens != nil && len(tokens) > 0 { + switch strings.ToLower(tokens[0].str) { + case "select": + tokens, found = tokensConsume(tokens[1:]) + q.Select, err = makeSelectConditions(found) + if err != nil { + return err + } + case "from": + tokens, found = tokensConsume(tokens[1:]) + if len(found) > 0 { + q.Table = strings.ToUpper(found[0].str) + } + case "where": + tokens, found = tokensConsume(tokens[1:]) + if q.Where, err = makeWhereConditions(found); err != nil { + return err + } + case "group": + tokens = tokensConsumeOptional(tokens[1:], "by") + if tokens == nil || len(tokens) < 1 { + return errors.New(invalidQuery + unexpectedEnd) + } + tokens, q.GroupBy = tokensConsumeStr(tokens) + q.GroupKey = strings.Join(q.GroupBy, ",") + case "rorder": + tokens = tokensConsumeOptional(tokens[1:], "by") + if tokens == nil || len(tokens) < 1 { + return errors.New(invalidQuery + unexpectedEnd) + } + tokens, found = tokensConsume(tokens) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + q.OrderBy = found[0].str + q.ReverseOrder = true + case "order": + tokens = tokensConsumeOptional(tokens[1:], "by") + if tokens == nil || len(tokens) < 1 { + return errors.New(invalidQuery + unexpectedEnd) + } + tokens, found = tokensConsume(tokens) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + q.OrderBy = found[0].str + case "interval": + tokens, found = tokensConsume(tokens[1:]) + if len(found) > 0 { + i, err := strconv.Atoi(found[0].str) + if err != nil { + return errors.New(invalidQuery + err.Error()) + } + q.Interval = time.Second * time.Duration(i) 
+ } + case "limit": + tokens, found = tokensConsume(tokens[1:]) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + i, err := strconv.Atoi(found[0].str) + if err != nil { + return errors.New(invalidQuery + err.Error()) + } + q.Limit = i + case "outfile": + tokens, found = tokensConsume(tokens[1:]) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + q.Outfile = found[0].str + default: + return errors.New(invalidQuery + "Unexpected keyword " + tokens[0].str) + } + } + + if q.Table == "" { + return errors.New(invalidQuery + "Empty table specified in 'from' clause") + } + if len(q.Select) < 1 { + return errors.New(invalidQuery + "Expected at least one field in 'select' clause but got none") + } + if len(q.GroupBy) == 0 { + field := q.Select[0].Field + q.GroupBy = append(q.GroupBy, field) + } + + if q.OrderBy != "" { + var orderFieldIsValid bool + for _, sc := range q.Select { + if q.OrderBy == sc.FieldStorage { + orderFieldIsValid = true + break + } + } + if !orderFieldIsValid { + return errors.New(invalidQuery + fmt.Sprintf("Can not '(r)order by' '%s', must be present in 'select' clause", q.OrderBy)) + } + } + + return nil +} + +// WhereClause interprets the where clause of the mapreduce query. 
+func (q *Query) WhereClause(fields map[string]string) bool { + floatValue := func(str string, float float64, t whereType) (float64, bool) { + switch t { + case Float: + return float, true + case Field: + value, ok := fields[str] + if !ok { + return 0, false + } + f, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, false + } + return f, true + default: + logger.Error("Unexpected argument in 'where' clause", str, float, t) + return 0, false + } + } + + stringValue := func(str string, t whereType) (string, bool) { + switch t { + case Field: + value, ok := fields[str] + if !ok { + return str, false + } + return value, true + case String: + return str, true + default: + logger.Error("Unexpected argument in 'where' clause", str, t) + return str, false + } + } + + for _, wc := range q.Where { + var ok bool + + if wc.Operation > FloatOperation { + var lValue, rValue float64 + if lValue, ok = floatValue(wc.lString, wc.lFloat, wc.lType); !ok { + return false + } + if rValue, ok = floatValue(wc.rString, wc.rFloat, wc.rType); !ok { + return false + } + if ok = wc.floatClause(lValue, rValue); !ok { + return false + } + continue + } + + var lValue, rValue string + if lValue, ok = stringValue(wc.lString, wc.lType); !ok { + return false + } + if rValue, ok = stringValue(wc.rString, wc.rType); !ok { + return false + } + if ok = wc.stringClause(lValue, rValue); !ok { + return false + } + } + + return true +} diff --git a/mapr/query_test.go b/mapr/query_test.go new file mode 100644 index 0000000..6176461 --- /dev/null +++ b/mapr/query_test.go @@ -0,0 +1,149 @@ +package mapr + +import ( + "testing" + "time" +) + +func TestParseQuerySimple(t *testing.T) { + errorQueries := []string{ + "select", + "select foo", + "select foo from", + "select foo from bar where baz", + "select foo from bar where baz <", + "select foo from bar where baz < 100 bay eq 12 group", + "select foo from bar where baz < 100 bay eq 12 group by foo order by", + "select foo from bar where baz < 100 
bay eq 12 group by foo, bar, baz order by foo limit", + } + okQueries := []string{"select foo from bar", + "select foo from bar where", + "select foo from bar where baz < 100 bay eq 12", + "select foo from bar where baz < 100, bay eq 12", + "select foo from bar where baz < 100 and bay eq 12", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\"", + } + + for _, queryStr := range errorQueries { + q, err := NewQuery(queryStr) + if err == nil { + t.Errorf("Expected a parse error: %s\n%v", queryStr, q) + continue + } + } + + for _, queryStr := range okQueries { + _, err := NewQuery(queryStr) + if err != nil { + t.Errorf("%s: %s", err.Error(), queryStr) + continue + } + } +} + +func TestParseQueryDeep(t *testing.T) { + dialects := []string{ + "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23", + "SELECT s1, `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23", + "select s1, `from` count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23", + "sElEct s1, `from` coUnt(s3) from taBle where w1 == 2 aNd w2 eq \"free beer\" Group By g1, g2 order bY count(s3) intervaL 10 LiMiT 23", + "SELECT s1 `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP BY g1 g2 ORDER BY count(s3) INTERVAL 10 LIMIT 23", + "select s1 `from` count(s3) from table where w1 == 2 w2 eq \"free beer\" group g1 g2 order count(s3) interval 10 limit 23", + "limit 23 interval 10 order count(s3) group g1 g2 where w1 == 2 w2 eq \"free beer\" from table select s1 `from` count(s3)", + } + + for _, queryStr := range dialects { + q, err := 
NewQuery(queryStr) + if err != nil { + t.Errorf("%s: %s", err.Error(), queryStr) + } + + // 'select' clause + if len(q.Select) != 3 { + t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", q.Select, queryStr, q) + } + + if q.Select[0].Field != "s1" { + t.Errorf("Expected 's1' as first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Field, queryStr, q) + } + if q.Select[0].Operation != Last { + t.Errorf("Expected 'last' as aggregation function of first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Operation, queryStr, q) + } + + if q.Select[1].Field != "from" { + t.Errorf("Expected 'from' as second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Field, queryStr, q) + } + if q.Select[1].Operation != Last { + t.Errorf("Expected 'last' as aggregation function of second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Operation, queryStr, q) + } + + if q.Select[2].Field != "s3" { + t.Errorf("Expected 's3' as third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Field, queryStr, q) + } + if q.Select[2].Operation != Count { + t.Errorf("Expected 'count' as aggregation function of third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Operation, queryStr, q) + } + if q.Select[2].FieldStorage != "count(s3)" { + t.Errorf("Expected 'count(s3)' as third element's storage in 'select' clause but got '%v': %s\n%v", q.Select[2].FieldStorage, queryStr, q) + } + + // 'from' clause + if q.Table != "TABLE" { + t.Errorf("Expected 'TABLE' in 'from' clause but got '%v': %s\n%v", q.Table, queryStr, q) + } + + // 'where' clause + if len(q.Where) != 2 { + t.Errorf("Expected two elements in 'where' clause but got '%v': %s\n%v", q.Where, queryStr, q) + } + if q.Where[0].lString != "w1" { + t.Errorf("Expected w1 as first element in 'where' clause but got '%v': %s\n%v", q.Where[0].lString, queryStr, q) + } + if q.Where[0].Operation != FloatEq { + t.Errorf("Expected FloatEq operation in 
first 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q) + } + if q.Where[0].rFloat != 2 { + t.Errorf("Expected '2' as float argument in first 'where' condition but got '%v': %s\n%v", q.Where[0].rFloat, queryStr, q) + } + if q.Where[1].lString != "w2" { + t.Errorf("Expected w2 as second element in 'where' clause but got '%v': %s\n%v", q.Where[1].lString, queryStr, q) + } + if q.Where[1].Operation != StringEq { + t.Errorf("Expected StringEq operation in second 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q) + } + if q.Where[1].rString != "free beer" { + t.Errorf("Expected 'free beer' as string argument in second 'where' condition but got '%v': %s\n%v", q.Where[0].rString, queryStr, q) + } + + // 'group by' clause + if len(q.GroupBy) != 2 { + t.Errorf("Expected two elements in 'group by' clause but got '%v': %s\n%v", q.GroupBy, queryStr, q) + } + if q.GroupBy[0] != "g1" { + t.Errorf("Expected 'g1' as first element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[0], queryStr, q) + } + if q.GroupBy[1] != "g2" { + t.Errorf("Expected 'g2' as second element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[1], queryStr, q) + } + if q.GroupKey != "g1,g2" { + t.Errorf("Expected 'g1,g2' as group key in 'group by' clause but got '%v': %s\n%v", q.GroupKey, queryStr, q) + } + + // 'order by' clause + if q.OrderBy != "count(s3)" { + t.Errorf("Expected 'count(s3)' as element in 'order by' clause but got '%v': %s\n%v", q.OrderBy, queryStr, q) + } + + // 'interval' clause + if q.Interval != time.Second*time.Duration(10) { + t.Errorf("Expected '10s' as duration 'interval' clause but got '%v': %s\n%v", q.Interval, queryStr, q) + } + + // 'limit' clause + if q.Limit != 23 { + t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", q.Limit, queryStr, q) + } + } +} diff --git a/mapr/selectcondition.go b/mapr/selectcondition.go new file mode 100644 index 0000000..1882b7e --- /dev/null +++ 
b/mapr/selectcondition.go @@ -0,0 +1,96 @@ +package mapr + +import ( + "errors" + "fmt" + "strings" +) + +// AggregateOperation is to specify the aggregate operation type. +type AggregateOperation int + +// Aggregate operation types +const ( + UndefAggregateOperation AggregateOperation = iota + Count AggregateOperation = iota + Sum AggregateOperation = iota + Min AggregateOperation = iota + Max AggregateOperation = iota + Last AggregateOperation = iota + Avg AggregateOperation = iota + Len AggregateOperation = iota +) + +// Represents a parsed "select" clause, used by mapr.Query. +type selectCondition struct { + Field string + FieldStorage string + Operation AggregateOperation +} + +func (sc selectCondition) String() string { + return fmt.Sprintf("selectCondition(Field:%s,FieldStorage:%s,Operation:%v)", + sc.Field, + sc.FieldStorage, + sc.Operation) +} + +func makeSelectConditions(tokens []token) ([]selectCondition, error) { + var sel []selectCondition + + // Parse select aggregation, e.g. sum(foo) + parse := func(token token) (selectCondition, error) { + var sc selectCondition + tokenStr := strings.ToLower(token.str) + + if !strings.Contains(tokenStr, "(") && !strings.Contains(tokenStr, ")") { + sc.Field = tokenStr + sc.FieldStorage = tokenStr + sc.Operation = Last + return sc, nil + } + + a := strings.Split(tokenStr, "(") + if len(a) != 2 { + return sc, errors.New(invalidQuery + "Can't parse 'select' aggregation: " + token.str) + } + agg := a[0] // Aggregation, e.g. 'sum' + + b := strings.Split(a[1], ")") + if len(b) != 2 { + return sc, errors.New(invalidQuery + "Can't parse 'select' field name from aggregation: " + token.str) + } + sc.Field = b[0] // Field name, e.g. 'foo' + sc.FieldStorage = tokenStr // e.g. 
'sum(foo)' + + switch agg { + case "count": + sc.Operation = Count + case "sum": + sc.Operation = Sum + case "min": + sc.Operation = Min + case "max": + sc.Operation = Max + case "last": + sc.Operation = Last + case "avg": + sc.Operation = Avg + case "len": + sc.Operation = Len + default: + return sc, errors.New(invalidQuery + "Unknown aggregation in 'select' clause: " + agg) + } + + return sc, nil + } + + for _, token := range tokens { + sc, err := parse(token) + if err != nil { + return nil, err + } + sel = append(sel, sc) + } + return sel, nil +} diff --git a/mapr/server/aggregate.go b/mapr/server/aggregate.go new file mode 100644 index 0000000..316da67 --- /dev/null +++ b/mapr/server/aggregate.go @@ -0,0 +1,170 @@ +package server + +import ( + "dtail/config" + "dtail/fs" + "dtail/logger" + "dtail/mapr" + "dtail/mapr/logformat" + "os" + "strings" + "time" +) + +// Aggregate is for aggregating mapreduce data on the DTail server side. +type Aggregate struct { + // Log lines to process (parsing MAPREDUCE lines). + Lines chan fs.LineRead + // Hostname of the current server (used to populate $hostname field). + hostname string + // Signals to exit goroutine. + stop chan struct{} + // Signals to serialize data. + serialize chan struct{} + // The mapr query + query *mapr.Query + // The mapr log format parser + parser *logformat.Parser +} + +// NewAggregate return a new server side aggregator. 
+func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) { + query, err := mapr.NewQuery(queryStr) + if err != nil { + return nil, err + } + + fqdn, err := os.Hostname() + if err != nil { + logger.Error(err) + } + s := strings.Split(fqdn, ".") + + logger.Info("Creating mapr log format parser", config.Server.MapreduceLogFormat) + logParser, err := logformat.NewParser(config.Server.MapreduceLogFormat) + if err != nil { + logger.FatalExit("Could not create mapr log format parser", err) + } + + a := Aggregate{ + Lines: make(chan fs.LineRead, 100), + stop: make(chan struct{}), + serialize: make(chan struct{}), + hostname: s[0], + query: query, + parser: logParser, + } + + go a.periodicAggregateTimer() + + fieldsCh := make(chan map[string]string) + go a.readFields(fieldsCh, maprLines) + go a.readLines(fieldsCh) + + return &a, nil +} + +func (a *Aggregate) periodicAggregateTimer() { + for { + select { + case <-time.After(a.query.Interval): + a.Serialize() + case <-a.stop: + return + } + } +} + +func (a *Aggregate) readFields(fieldsCh <-chan map[string]string, maprLines chan<- string) { + group := mapr.NewGroupSet() + + for { + select { + case fields := <-fieldsCh: + a.aggregate(group, fields) + case <-a.serialize: + logger.Info("Serializing mapreduce result") + group.Serialize(maprLines, a.stop) + logger.Info("Done serializing mapreduce result") + group = mapr.NewGroupSet() + case <-a.stop: + return + } + } +} + +func (a *Aggregate) readLines(fieldsCh chan<- map[string]string) { + for { + select { + case line, ok := <-a.Lines: + if !ok { + return + } + + maprLine := strings.TrimSpace(string(line.Content)) + fields, err := a.parser.MakeFields(maprLine) + + if err != nil { + logger.Error(err) + continue + } + if !a.query.WhereClause(fields) { + continue + } + + select { + case fieldsCh <- fields: + case <-a.stop: + } + case <-a.stop: + return + } + } +} + +func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { + 
//logger.Trace("Aggregating", group, fields) + var sb strings.Builder + + for i, field := range a.query.GroupBy { + if i > 0 { + sb.WriteString(" ") + } + if val, ok := fields[field]; ok { + sb.WriteString(val) + } + } + groupKey := sb.String() + set := group.GetSet(groupKey) + + var addedSample bool + for _, sc := range a.query.Select { + if val, ok := fields[sc.Field]; ok { + if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil { + logger.Error(err) + continue + } + addedSample = true + } + } + + if addedSample { + set.Samples++ + return + } + + logger.Trace("Aggregated data locally without adding new samples") +} + +// Serialize all the aggregated data. +func (a *Aggregate) Serialize() { + select { + case a.serialize <- struct{}{}: + case <-a.stop: + } +} + +// Close the aggregator. +func (a *Aggregate) Close() { + close(a.stop) +} diff --git a/mapr/token.go b/mapr/token.go new file mode 100644 index 0000000..b8be4da --- /dev/null +++ b/mapr/token.go @@ -0,0 +1,108 @@ +package mapr + +import ( + "strings" +) + +var keywords = [...]string{"select", "from", "where", "group", "rorder", "order", "interval", "limit", "outfile"} + +// Represents a parsed token, used to parse the mapr query. 
+type token struct { + str string + isBareword bool +} + +func (t token) isKeyword() bool { + if !t.isBareword { + return false + } + + for _, keyword := range keywords { + if strings.ToLower(t.str) == keyword { + return true + } + } + return false +} + +func (t token) String() string { + return t.str +} + +func tokenize(queryStr string) []token { + var tokens []token + + for i, part := range strings.Split(queryStr, "\"") { + // Even i, means that it is not a quoted string + if i%2 == 0 { + commasStripped := strings.Replace(part, ",", " ", -1) + for _, tokenStr := range strings.Fields(commasStripped) { + token := token{ + str: tokenStr, + isBareword: true, + } + tokens = append(tokens, token) + } + continue + } + // Add whole quoted string as a token + token := token{ + str: part, + isBareword: false, + } + tokens = append(tokens, token) + } + + return tokens +} + +func tokensConsume(tokens []token) ([]token, []token) { + //logger.Trace("=====================") + var consumed []token + + for i, t := range tokens { + if t.isKeyword() { + //logger.Trace("keyword", t) + return tokens[i:], consumed + } + // strip escapes, such as ` from `foo`, this allows to use keywords as field names + length := len(t.str) + if length == 0 { + continue + } + if t.str[0] == '`' && t.str[length-1] == '`' { + stripped := t.str[1 : length-1] + //logger.Trace("stripped", stripped) + t := token{ + str: stripped, + isBareword: t.isBareword, + } + consumed = append(consumed, t) + continue + } + //logger.Trace("bare", token) + consumed = append(consumed, t) + } + + //logger.Trace("result", consumed) + return nil, consumed +} + +func tokensConsumeStr(tokens []token) ([]token, []string) { + var strings []string + tokens, found := tokensConsume(tokens) + for _, token := range found { + strings = append(strings, token.str) + } + return tokens, strings +} + +func tokensConsumeOptional(tokens []token, optional string) []token { + if len(tokens) < 1 { + return tokens + } + if 
strings.EqualFold(tokens[0].str, optional) {
+		return tokens[1:]
+	}
+	return tokens
+}
diff --git a/mapr/wherecondition.go b/mapr/wherecondition.go
new file mode 100644
index 0000000..515c8ad
--- /dev/null
+++ b/mapr/wherecondition.go
@@ -0,0 +1,193 @@
+package mapr
+
+import (
+	"dtail/logger"
+	"errors"
+	"fmt"
+	"strconv"
+	"strings"
+)
+
+// QueryOperation determines the mapreduce operation.
+type QueryOperation int
+
+// The possible mapreduce operations.
+const (
+	UndefQueryOperation QueryOperation = iota
+	StringEq            QueryOperation = iota
+	StringNe            QueryOperation = iota
+	StringContains      QueryOperation = iota
+	FloatOperation      QueryOperation = iota
+	FloatEq             QueryOperation = iota
+	FloatNe             QueryOperation = iota
+	FloatLt             QueryOperation = iota
+	FloatLe             QueryOperation = iota
+	FloatGt             QueryOperation = iota
+	FloatGe             QueryOperation = iota
+)
+
+type whereType int
+
+// The possible field types.
+const (
+	UndefWhereType whereType = iota
+	Field          whereType = iota
+	String         whereType = iota
+	Float          whereType = iota
+)
+
+func (w whereType) String() string {
+	// Plain string returns: fmt.Sprintf with a constant format is redundant.
+	switch w {
+	case Field:
+		return "Field"
+	case String:
+		return "String"
+	case Float:
+		return "Float"
+	default:
+		return "UndefWhereType"
+	}
+}
+
+// Represents a parsed "where" clause, used by mapr.Query
+type whereCondition struct {
+	lString string
+	lFloat  float64
+	lType   whereType
+
+	Operation QueryOperation
+
+	rString string
+	rFloat  float64
+	rType   whereType
+}
+
+func (wc *whereCondition) String() string {
+	return fmt.Sprintf("whereCondition(Operation:%v,lString:%s,lFloat:%v,lType:%s,rString:%s,rFloat:%v,rType:%s)",
+		wc.Operation, wc.lString, wc.lFloat, wc.lType.String(), wc.rString, wc.rFloat, wc.rType.String())
+}
+
+func makeWhereConditions(tokens []token) (where []whereCondition, err error) {
+	parse := func(tokens []token) (whereCondition, []token, error) {
+		var wc whereCondition
+		if len(tokens) < 3 {
+			return wc, nil, 
errors.New(invalidQuery + "Not enough arguments in 'where' clause") + } + + whereOp := strings.ToLower(tokens[1].str) + switch whereOp { + case "==": + wc.Operation = FloatEq + case "!=": + wc.Operation = FloatNe + case "<": + wc.Operation = FloatLt + case "<=": + wc.Operation = FloatLe + case "=<": + wc.Operation = FloatLe + case ">": + wc.Operation = FloatGt + case ">=": + wc.Operation = FloatGe + case "=>": + wc.Operation = FloatGe + case "eq": + wc.Operation = StringEq + case "ne": + wc.Operation = StringNe + case "contains": + wc.Operation = StringContains + default: + return wc, nil, errors.New(invalidQuery + "Unknown operation in 'where' clause: " + whereOp) + } + + wc.lString = tokens[0].str + wc.rString = tokens[2].str + + if wc.Operation > FloatOperation { + if !tokens[0].isBareword { + return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's lValue: " + tokens[0].str) + } + if f, err := strconv.ParseFloat(wc.lString, 64); err == nil { + wc.lFloat = f + wc.lType = Float + } else { + wc.lType = Field + } + + if !tokens[2].isBareword { + return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's rValue: " + tokens[2].str) + } + if f, err := strconv.ParseFloat(wc.rString, 64); err == nil { + wc.rFloat = f + wc.rType = Float + } else { + wc.rType = Field + } + return wc, tokens[3:], nil + } + + if tokens[0].isBareword { + wc.lType = Field + } else { + wc.lType = String + } + if tokens[2].isBareword { + wc.rType = Field + } else { + wc.rType = String + } + + return wc, tokens[3:], nil + } + + for len(tokens) > 0 { + var wc whereCondition + var err error + + wc, tokens, err = parse(tokens) + if err != nil { + return nil, err + } + + where = append(where, wc) + tokens = tokensConsumeOptional(tokens, "and") + } + + return +} + +func (wc *whereCondition) floatClause(lValue float64, rValue float64) bool { + switch wc.Operation { + case FloatEq: + return lValue == rValue + case FloatNe: + return lValue != rValue + case 
// Mode used.
type Mode int

// Possible modes.
const (
	Unknown Mode = iota
	Server
	TailClient
	CatClient
	GrepClient
	MapClient
	HealthClient
)

// New returns the mode based on the mode string (typically the name of
// the invoked binary). It panics on an unrecognized mode string.
func New(modeStr string) Mode {
	switch modeStr {
	case "dserver", "server":
		return Server
	case "dtail", "tail":
		return TailClient
	case "dgrep", "grep":
		return GrepClient
	case "dcat", "cat":
		return CatClient
	case "dmap", "map":
		return MapClient
	case "health":
		return HealthClient
	default:
		panic(fmt.Sprintf("Unknown mode: '%s'", modeStr))
	}
}
// Answer is a user input of a prompt question.
type Answer struct {
	// Long version of the expected user input.
	Long string
	// Short version of the expected user input.
	Short string
	// Callback runs when the user input matches this answer.
	Callback func()
	// EndCallback runs after Callback and after logging resumes.
	EndCallback func()
	// AskAgain causes the question to be posed again even after a match.
	AskAgain bool
}

// Prompt used for interactive user input.
type Prompt struct {
	question string
	answers  []Answer
}

// New returns a new prompt posing the given question.
func New(question string) *Prompt {
	return &Prompt{question: question}
}

// Add registers an accepted answer.
func (p *Prompt) Add(answer Answer) {
	p.answers = append(p.answers, answer)
}

// askString renders the question together with all registered answers,
// e.g. "Continue? (y=yes,n=no): ".
func (p *Prompt) askString() string {
	choices := make([]string, 0, len(p.answers))
	for _, a := range p.answers {
		choices = append(choices, fmt.Sprintf("%s=%s", a.Short, a.Long))
	}
	return p.question + "? (" + strings.Join(choices, ",") + "): "
}
+func (p *Prompt) Ask() { + reader := bufio.NewReader(os.Stdin) + logger.Pause() + + for { + fmt.Print(p.askString()) + answerStr, _ := reader.ReadString('\n') + + if a, ok := p.answer(strings.TrimSpace(answerStr)); ok { + if a.Callback != nil { + a.Callback() + } + + if !a.AskAgain { + logger.Resume() + if a.EndCallback != nil { + a.EndCallback() + } + return + } + } + } +} + +func (p *Prompt) answer(answerStr string) (*Answer, bool) { + for _, a := range p.answers { + switch answerStr { + case a.Long: + return &a, true + case a.Short: + return &a, true + default: + } + } + + return nil, false +} diff --git a/samples/check_dserver.sh.sample b/samples/check_dserver.sh.sample new file mode 100755 index 0000000..96c96de --- /dev/null +++ b/samples/check_dserver.sh.sample @@ -0,0 +1,4 @@ +#!/bin/bash + +declare -r CONFIG_FILE=/etc/dserver/dtail.json +exec /usr/local/bin/dtail --cfg $CONFIG_FILE --checkHealth diff --git a/samples/dserver.service.sample b/samples/dserver.service.sample new file mode 100644 index 0000000..c5e5e59 --- /dev/null +++ b/samples/dserver.service.sample @@ -0,0 +1,19 @@ +[Unit] +Description=DTail server +After=network.target + +[Service] +Slice=dserver.slice +User=dserver +Group=dserver +ExecStart=/usr/local/bin/dserver -cfg /etc/dserver/dtail.json +WorkingDirectory=/var/run/dserver +NoNewPrivileges=true +PrivateDevices=true +PrivateTmp=true +CPUAccounting=true +MemoryAccounting=true +BlockIOAccounting=true + +[Install] +WantedBy=multi-user.target diff --git a/samples/dtail.json.sample b/samples/dtail.json.sample new file mode 100644 index 0000000..99c0a73 --- /dev/null +++ b/samples/dtail.json.sample @@ -0,0 +1,38 @@ +{ + "Client": {}, + "Server": { + "SSHBindAddress": "0.0.0.0", + "MaxConcurrentCats": 2, + "MaxConcurrentTails": 50, + "MaxConnections": 50, + "MapreduceLogFormat" : "default", + "HostKeyFile" : "cache/ssh_host_key", + "HostKeyBits" : 2048, + "Permissions": { + "Default": [ + "^/.*$" + ], + "Users": { + "pbuetow": [ + "^/.*$" + 
], + "jblake": [ + "^/tmp/foo.log$", + "^/.*$", + "!^/tmp/bar.log$" + ] + } + } + }, + "Common": { + "LogDir" : "log", + "CacheDir" : "cache", + "LogStrategy": "daily", + "SSHPort": 2222, + "DebugEnable": false, + "PPerfEnable": false, + "PPerfPort": 6060, + "PPerfBindAddress": "0.0.0.0", + "ExperimentalFeaturesEnable": false + } +} diff --git a/samples/update_key_cache.sh.sample b/samples/update_key_cache.sh.sample new file mode 100644 index 0000000..9817f04 --- /dev/null +++ b/samples/update_key_cache.sh.sample @@ -0,0 +1,33 @@ +#!/bin/bash + +declare -r CACHEDIR=/var/run/dserver/cache +declare -r DSERVER_USER=dserver + +echo "Updating SSH key cache" + +ls /home/ | while read remoteuser; do + keysfile=/home/$remoteuser/.ssh/authorized_keys + + if [ -f $keysfile ]; then + cachefile=$CACHEDIR/$remoteuser.authorized_keys + echo "Caching $keysfile -> $cachefile" + + cp $keysfile $cachefile + chown $DSERVER_USER $cachefile + chmod 600 $cachefile + fi +done + +# Cleanup obsolete public SSH keys +find $CACHEDIR -name \*.authorized_keys -type f | +while read cachefile; do + remoteuser=$(basename $cachefile | cut -d. -f1) + keysfile=/home/$remoteuser/.ssh/authorized_keys + + if [ ! -f $keysfile ]; then + echo "Deleting obsolete cache file $cachefile" + rm $cachefile + fi +done + +echo "All set..." diff --git a/server/handlers/controlhandler.go b/server/handlers/controlhandler.go new file mode 100644 index 0000000..c09eb52 --- /dev/null +++ b/server/handlers/controlhandler.go @@ -0,0 +1,105 @@ +package handlers + +import ( + "dtail/logger" + "dtail/server/user" + "fmt" + "io" + "os" + "strings" +) + +// ControlHandler is used for control functions and health monitoring. +type ControlHandler struct { + serverMessages chan string + pong chan struct{} + stop chan struct{} + payload []byte + hostname string + user *user.User +} + +// NewControlHandler returns a new control handler. 
+func NewControlHandler(user *user.User) *ControlHandler { + logger.Debug(user, "Creating control handler") + + h := ControlHandler{ + serverMessages: make(chan string, 10), + pong: make(chan struct{}, 10), + stop: make(chan struct{}), + user: user, + } + + fqdn, err := os.Hostname() + if err != nil { + logger.FatalExit(err) + } + + s := strings.Split(fqdn, ".") + h.hostname = s[0] + return &h +} + +// Read is to send data to the client via the Reader interface. +func (h *ControlHandler) Read(p []byte) (n int, err error) { + for { + select { + case message := <-h.serverMessages: + wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) + n = copy(p, wholePayload) + return + case <-h.pong: + logger.Info(h.user, "Sending pong") + n = copy(p, []byte(".pong\n")) + return + case <-h.stop: + return 0, io.EOF + } + } +} + +// Write is to read data to the client via the Writer interface. +func (h *ControlHandler) Write(p []byte) (n int, err error) { + for _, c := range p { + switch c { + case ';': + wholePayload := strings.TrimSpace(string(h.payload)) + h.handleCommand(wholePayload) + h.payload = nil + + default: + h.payload = append(h.payload, c) + } + } + + n = len(p) + return +} + +// Close the control handler. +func (h *ControlHandler) Close() { + close(h.stop) +} + +// Wait returns the handler stop channel. 
// Wait returns the handler stop channel; it is closed by Close.
func (h *ControlHandler) Wait() <-chan struct{} {
	return h.stop
}

// handleCommand dispatches a single ';'-terminated client command.
// Replies are queued on the serverMessages (or pong) channel and picked
// up by Read.
func (h *ControlHandler) handleCommand(command string) {
	logger.Info(h.user, command)
	s := strings.Split(command, " ")
	logger.Debug(h.user, "Receiving command", command, s)

	switch s[0] {
	case "health":
		h.serverMessages <- "OK: DTail SSH Server seems fine"
		// "done;" signals the health client that the check is complete.
		h.serverMessages <- "done;"
	case "ping":
		h.pong <- struct{}{}
	case "debug":
		h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s)
	default:
		h.serverMessages <- logger.Warn(h.user, "Received unknown command", command, s)
	}
}
// NewServerHandler returns the server handler for the given user. The
// cat and tail limiter channels are shared server-wide and bound the
// number of concurrent file readers.
func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler {
	logger.Debug(user, "Creating tail handler")
	h := ServerHandler{
		fileReadersMtx:     &sync.Mutex{},
		lines:              make(chan fs.LineRead, 100),
		serverMessages:     make(chan string, 10),
		aggregatedMessages: make(chan string, 10),
		hiddenMessages:     make(chan string, 10),
		ackStopReceived:    make(chan struct{}),
		stopTimeout:        make(chan struct{}),
		stop:               make(chan struct{}),
		catLimiter:         catLimiter,
		tailLimiter:        tailLimiter,
		regex:              ".", // Default regex matches every line.
		user:               user,
	}

	fqdn, err := os.Hostname()
	if err != nil {
		logger.FatalExit(err)
	}

	// Use the short host name (first label of the FQDN) in output headers.
	s := strings.Split(fqdn, ".")
	h.hostname = s[0]

	return &h
}
+func (h *ServerHandler) Read(p []byte) (n int, err error) { + for { + select { + case message := <-h.serverMessages: + wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) + n = copy(p, wholePayload) + return + case message := <-h.aggregatedMessages: + data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message) + //logger.Debug("Sending aggregation data", data) + wholePayload := []byte(data) + n = copy(p, wholePayload) + return + case message := <-h.hiddenMessages: + //logger.Debug(h.user, "Sending hidden message", message) + wholePayload := []byte(fmt.Sprintf(".%s\n", message)) + n = copy(p, wholePayload) + return + case line := <-h.lines: + serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", + h.hostname, line.TransmittedPerc, line.Count, *line.GlobID)) + wholePayload := append(serverInfo, line.Content[:]...) + n = copy(p, wholePayload) + return + case <-time.After(time.Second): + select { + case <-h.stop: + return 0, io.EOF + default: + } + } + } +} + +// Write is to receive data from the dtail client via Writer interface. +func (h *ServerHandler) Write(p []byte) (n int, err error) { + for _, c := range p { + switch c { + case ';': + commandStr := strings.TrimSpace(string(h.payload)) + h.handleCommand(commandStr) + h.payload = nil + default: + h.payload = append(h.payload, c) + } + } + + n = len(p) + return +} + +// Close the server handler. 
+func (h *ServerHandler) Close() { + h.fileReadersMtx.Lock() + defer h.fileReadersMtx.Unlock() + + for _, reader := range h.fileReaders { + reader.Stop() + } + if h.aggregate != nil { + h.aggregate.Close() + } + + close(h.stop) +} + +func (h *ServerHandler) makeGlobID(path, glob string) string { + var idParts []string + pathParts := strings.Split(path, "/") + + for i, globPart := range strings.Split(glob, "/") { + if strings.Contains(globPart, "*") { + idParts = append(idParts, pathParts[i]) + } + } + + if len(idParts) > 0 { + return strings.Join(idParts, "/") + } + + if len(pathParts) > 0 { + return pathParts[len(pathParts)-1] + } + + h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob)) + return "" +} + +func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) { + retryInterval := time.Second * 5 + glob = filepath.Clean(glob) + + errors := make(chan struct{}) + stop := make(chan struct{}) + defer close(stop) + + go func() { + for { + select { + case <-errors: + h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs")) + case <-stop: + return + case <-h.stop: + return + } + } + }() + + maxRetries := 10 + for { + maxRetries-- + if maxRetries < 0 { + h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)")) + h.internalClose() + return + } + + paths, err := filepath.Glob(glob) + if err != nil { + logger.Warn(h.user, glob, err) + time.Sleep(retryInterval) + continue + } + + if numPaths := len(paths); numPaths == 0 { + logger.Error(h.user, "No such file(s) to read", glob) + select { + case errors <- struct{}{}: + case <-h.stop: + return + default: + } + time.Sleep(retryInterval) + continue + } + + h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors) + break + } +} + +func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) { + var wg sync.WaitGroup + 
wg.Add(len(paths)) + + read := func(path string, wg *sync.WaitGroup) { + defer wg.Done() + globID := h.makeGlobID(path, glob) + + if !h.user.HasFilePermission(path) { + logger.Error(h.user, "No permission to read file", path, globID) + select { + case errors <- struct{}{}: + default: + } + return + } + + h.startReadingFile(mode, path, globID, regex) + } + + for _, path := range paths { + go read(path, &wg) + } + + wg.Wait() +} + +func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) { + defer h.stopReadingFile(path) + logger.Info(h.user, "Start reading file", path, globID) + + var reader fs.FileReader + switch mode { + case omode.TailClient: + reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + case omode.GrepClient: + fallthrough + case omode.CatClient: + reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter) + default: + reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + } + + h.fileReadersMtx.Lock() + h.fileReaders = append(h.fileReaders, reader) + h.fileReadersMtx.Unlock() + + lines := h.lines + // Plugin mappreduce engine + if h.aggregate != nil { + lines = h.aggregate.Lines + } + + for { + if err := reader.Start(lines, regex); err != nil { + logger.Error(h.user, path, globID, err) + } + + select { + case <-h.stop: + return + default: + if !reader.Retry() { + return + } + } + + time.Sleep(time.Second * 2) + logger.Info(path, globID, "Reading file again") + } +} + +func (h *ServerHandler) stopReadingFile(path string) { + logger.Info(h.user, "Stop reading file", path) + + h.fileReadersMtx.Lock() + defer h.fileReadersMtx.Unlock() + + path = filepath.Clean(path) + var fileReaders []fs.FileReader + + for _, reader := range h.fileReaders { + if reader.FilePath() == path { + reader.Stop() + continue + } + fileReaders = append(fileReaders, reader) + } + + if len(fileReaders) == len(h.fileReaders) { + logger.Warn(h.user, "Didn't read file path", path) + return + } + + h.fileReaders 
= fileReaders + + if len(fileReaders) == 0 { + if h.aggregate != nil { + h.aggregate.Serialize() + } + h.allLinesSent() + } +} + +func (h *ServerHandler) numUnsentMessages() int { + return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages) +} + +func (h *ServerHandler) allLinesSent() { + defer h.internalClose() + + for i := 0; i < 3; i++ { + if h.numUnsentMessages() == 0 { + logger.Debug(h.user, "All lines sent") + return + } + logger.Debug(h.user, "Still lines to be sent") + time.Sleep(time.Second) + } + + logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages()) +} + +// Handler decides to shutdown the connection, not the server itself. +func (h *ServerHandler) internalClose() { + select { + case h.hiddenMessages <- "syn close connection": + case <-time.After(time.Second * 5): + logger.Debug(h.user, "Not waiting for ack close connection") + close(h.stopTimeout) + return + } + + select { + case <-h.Wait(): + case <-time.After(time.Second * 5): + logger.Debug(h.user, "Not waiting for ack close connection") + close(h.stopTimeout) + } +} + +func (h *ServerHandler) handleCommand(commandStr string) { + logger.Info(h.user, commandStr) + + args := strings.Split(commandStr, " ") + argc := len(args) + + logger.Debug(h.user, "Received command", commandStr, argc, args) + + if h.user.Name == config.ControlUser { + h.handleControlCommand(argc, args) + return + } + + h.handleUserCommand(argc, args) +} + +// Special (restricted) set of commands for anonymous ControlUser access. +func (h *ServerHandler) handleControlCommand(argc int, args []string) { + switch args[0] { + case "ping": + h.send(h.hiddenMessages, "pong") + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) + default: + logger.Warn(h.user, "Received unknown command", argc, args) + } +} + +// Commands for authed users. 
+func (h *ServerHandler) handleUserCommand(argc int, args []string) { + switch args[0] { + case "grep": + fallthrough + case "cat": + h.handleReadCommand(argc, args, omode.CatClient) + case "tail": + h.handleReadCommand(argc, args, omode.TailClient) + case "map": + h.handleMapCommand(argc, args) + case "ack": + h.handleAckCommand(argc, args) + case "ping": + h.send(h.hiddenMessages, "pong") + case "version": + h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String())) + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args)) + default: + h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args)) + } +} + +func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) { + regex := "." + if argc >= 4 { + regex = args[3] + } + if argc < 3 { + h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) + return + } + go h.processFileGlob(mode, args[1], regex) +} + +func (h *ServerHandler) handleMapCommand(argc int, args []string) { + if argc < 2 { + h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) + return + } + + queryStr := strings.Join(args[1:], " ") + logger.Info(h.user, "Creating new mapr aggregator", queryStr) + aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr) + + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return + } + + h.aggregate = aggregate +} + +func (h *ServerHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) + return + } + if args[1] == "close" && args[2] == "connection" { + close(h.ackStopReceived) + } +} + +func (h *ServerHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.stop: + } +} + +// Wait (block) until server handler is closed or a timeout has exceeded. 
// Wait returns a channel that is closed once the client acknowledged the
// connection stop or the stop timeout fired.
//
// NOTE(review): the h.stop case only logs and does NOT close the wait
// channel — callers selecting solely on Wait() would block; confirm this
// is intended.
func (h *ServerHandler) Wait() <-chan struct{} {
	wait := make(chan struct{})

	go func() {
		select {
		case <-h.ackStopReceived:
			logger.Debug(h.user, "Closing wait channel due to ACK stop received")
			close(wait)
		case <-h.stopTimeout:
			logger.Debug(h.user, "Closing wait channel due to wait timeout")
			close(wait)
		case <-h.stop:
			logger.Debug(h.user, "Closing wait channel due to stop")
		}
	}()

	return wait
}

// Server is the main server data structure.
type Server struct {
	// Various server statistics counters.
	stats stats
	// SSH server configuration.
	sshServerConfig *gossh.ServerConfig
	// To control the max amount of concurrent cats (which can cause a
	// lot of I/O on the server).
	catLimiterCh chan struct{}
	// To control the max amount of concurrent tails.
	tailLimiterCh chan struct{}
	// Ask to shutdown the server.
	stop chan struct{}
}

// New returns a new server with the SSH host key loaded and both the
// control-user password callback and the public key callback installed.
func New() *Server {
	logger.Info("Creating server", version.String())

	s := Server{
		sshServerConfig: &gossh.ServerConfig{},
		catLimiterCh:    make(chan struct{}, config.Server.MaxConcurrentCats),
		tailLimiterCh:   make(chan struct{}, config.Server.MaxConcurrentTails),
		stop:            make(chan struct{}),
	}

	s.sshServerConfig.PasswordCallback = s.controlUserCallback
	s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback

	private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
	if err != nil {
		logger.FatalExit(err)
	}
	s.sshServerConfig.AddHostKey(private)

	return &s
}
+func (s *Server) Start(wg *sync.WaitGroup) int { + defer wg.Done() + logger.Info("Starting server") + + bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) + logger.Info("Binding server", bindAt) + listener, err := net.Listen("tcp", bindAt) + if err != nil { + logger.FatalExit("Failed to open listening TCP socket", err) + } + + go s.stats.periodicLogServerStats(s.stop) + + for { + conn, err := listener.Accept() // Blocking + if err != nil { + logger.Error("Failed to accept incoming connection", err) + continue + } + + if err := s.stats.serverLimitExceeded(); err != nil { + logger.Error(err) + conn.Close() + continue + } + + go s.handleConnection(conn) + } +} + +func (s *Server) handleConnection(conn net.Conn) { + logger.Info("Handling connection") + + sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) + if err != nil { + logger.Error("Something just happened", err) + return + } + + s.stats.incrementConnections() + + go gossh.DiscardRequests(reqs) + for newChannel := range chans { + go s.handleChannel(sshConn, newChannel) + } +} + +func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) { + user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) + logger.Info(user, "Invoking channel handler") + + if newChannel.ChannelType() != "session" { + err := errors.New("Don'w allow other channel types than session") + logger.Error(user, err) + newChannel.Reject(gossh.Prohibited, err.Error()) + return + } + + channel, requests, err := newChannel.Accept() + if err != nil { + logger.Error(user, "Could not accept channel", err) + return + } + + if err := s.handleRequests(sshConn, requests, channel, user); err != nil { + logger.Error(user, err) + sshConn.Close() + } +} + +func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { + logger.Info(user, "Invoking request handler") + + for req := range in { + var payload = struct{ 
// controlUserCallback authenticates the anonymous control user: access
// is granted only when both the SSH user name and the password payload
// equal config.ControlUser. All other users must authenticate via
// public key (see PublicKeyCallback).
func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
	user := user.New(c.User(), c.RemoteAddr().String())

	if user.Name == config.ControlUser && string(authPayload) == config.ControlUser {
		logger.Debug(user, "Initiating master control program")
		return nil, nil
	}

	return nil, fmt.Errorf("Not authorized")
}
+func (s *Server) Stop() { + close(s.stop) + s.stats.waitForConnections() +} diff --git a/server/stats.go b/server/stats.go new file mode 100644 index 0000000..01aa121 --- /dev/null +++ b/server/stats.go @@ -0,0 +1,88 @@ +package server + +import ( + "dtail/config" + "dtail/logger" + "fmt" + "runtime" + "sync" + "time" +) + +// Used to collect and display various server stats. +type stats struct { + mutex sync.Mutex + currentConnections int + lifetimeConnections uint64 +} + +func (s *stats) incrementConnections() { + defer s.logServerStats() + + s.mutex.Lock() + s.currentConnections++ + s.lifetimeConnections++ + s.mutex.Unlock() +} + +func (s *stats) decrementConnections() { + defer s.logServerStats() + + s.mutex.Lock() + s.currentConnections-- + s.mutex.Unlock() +} + +func (s *stats) hasConnections() bool { + s.mutex.Lock() + currentConnections := s.currentConnections + s.mutex.Unlock() + + has := currentConnections > 0 + logger.Info("stats", "Server with open connections?", has, currentConnections) + + return has +} + +func (s *stats) logServerStats() { + s.mutex.Lock() + defer s.mutex.Unlock() + + currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections) + lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections) + goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine()) + logger.Info("stats", currentConnections, lifetimeConnections, goroutines) +} + +func (s *stats) serverLimitExceeded() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.currentConnections >= config.Server.MaxConnections { + return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections) + } + + return nil +} + +func (s *stats) periodicLogServerStats(stop <-chan struct{}) { + for { + select { + case <-time.NewTimer(time.Second * 10).C: + s.logServerStats() + case <-stop: + return + } + } +} + +func (s *stats) waitForConnections() { + for { + select { + case <-time.NewTimer(time.Second).C: + 
// User represents an end-user which connected to the server via the
// DTail client.
type User struct {
	// Name is the user name.
	Name string
	// remoteAddress is the remote address the user connected from.
	remoteAddress string
	// permissions caches the user's configured file access patterns.
	permissions []string
}

// New returns a new user.
func New(name, remoteAddress string) *User {
	return &User{Name: name, remoteAddress: remoteAddress}
}

// String returns "name@remoteAddress".
func (u *User) String() string {
	return fmt.Sprintf("%s@%s", u.Name, u.remoteAddress)
}
+ if _, err := permissions.ToRead(u.Name, cleanPath); err != nil { + return false, fmt.Errorf("User without OS file system permissions to read file: '%v'", err) + } + logger.Info(u, cleanPath, "User has OS file system permissions to read file") + + // If file system permission is given, also check permissions + // as configured in DTail config file. + if len(u.permissions) == 0 { + p, err := config.ServerUserPermissions(u.Name) + if err != nil { + return false, err + } + u.permissions = p + } + + var hasPermission bool + var err error + + if hasPermission, err = u.iteratePaths(cleanPath); err != nil { + return false, err + } + + // Only allow to follow regular files or symlinks. + info, err := os.Lstat(cleanPath) + if err != nil { + return false, fmt.Errorf("Unable to determine file type: '%v'", err) + } + + if !info.Mode().IsRegular() { + return false, fmt.Errorf("Can only open regular files or follow symlinks") + } + + return hasPermission, nil +} + +func (u *User) iteratePaths(cleanPath string) (bool, error) { + for _, permission := range u.permissions { + var regexStr string + var negate bool + + if strings.HasPrefix(permission, "!") { + regexStr = permission[1:] + negate = true + } + regexStr = permission + negate = false + + re, err := regexp.Compile(regexStr) + if err != nil { + return false, fmt.Errorf("Permission test failed, can't compile regex '%s': '%v'", regexStr, err) + } + + if negate && re.MatchString(cleanPath) { + return false, fmt.Errorf("Permission test failed, matching negative pattern '%s'", permission) + } + + if !negate && re.MatchString(cleanPath) { + logger.Info(u, cleanPath, "Permission test passed partially, matching positive pattern", permission) + } + } + + return true, nil +} diff --git a/ssh/client/authmethods.go b/ssh/client/authmethods.go new file mode 100644 index 0000000..84b7ce3 --- /dev/null +++ b/ssh/client/authmethods.go @@ -0,0 +1,45 @@ +package client + +import ( + "dtail/config" + "dtail/logger" + "dtail/ssh" + "os" + + 
gossh "golang.org/x/crypto/ssh" +) + +// InitSSHAuthMethods initialises all known SSH auth methods on othe client side. +func InitSSHAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) { + var sshAuthMethods []gossh.AuthMethod + + if config.Common.ExperimentalFeaturesEnable { + sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) + logger.Info("Added experimental method to list of auth methods") + } + + keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" + if authMethod, err := ssh.PrivateKey(keyPath); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added path to list of auth methods", keyPath) + } + + keyPath = os.Getenv("HOME") + "/.ssh/id_dsa" + if authMethod, err := ssh.PrivateKey(keyPath); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added path to list of auth methods", keyPath) + } + + if authMethod, err := ssh.Agent(); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added SSH Agent to list of auth methods") + } + + knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts" + hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh) + if err != nil { + logger.FatalExit(knownHostsPath, err) + } + + return sshAuthMethods, hostKeyCallback +} diff --git a/ssh/client/hostkeycallback.go b/ssh/client/hostkeycallback.go new file mode 100644 index 0000000..7279f5e --- /dev/null +++ b/ssh/client/hostkeycallback.go @@ -0,0 +1,285 @@ +package client + +import ( + "bufio" + "dtail/logger" + "dtail/prompt" + "fmt" + "net" + "os" + "strings" + "sync" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +type response int + +const ( + trustHost response = iota + dontTrustHost response = iota +) + +// Represents an unknown host. 
+type unknownHost struct { + server string + remote net.Addr + key ssh.PublicKey + hostLine string + ipLine string + responseCh chan response +} + +// HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all +// unknown hosts in a single batch to the known_hosts file. +type HostKeyCallback struct { + knownHostsPath string + unknownCh chan unknownHost + throttleCh chan struct{} + trustAllHostsCh chan struct{} + untrustedHosts map[string]bool + mutex sync.Mutex +} + +// NewHostKeyCallback returns a new wrapper. +func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) { + // Ensure file exists + os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + + h := HostKeyCallback{ + knownHostsPath: knownHostsPath, + unknownCh: make(chan unknownHost), + trustAllHostsCh: make(chan struct{}), + throttleCh: throttleCh, + untrustedHosts: make(map[string]bool), + } + + if trustAllHosts { + close(h.trustAllHostsCh) + } + + return &h, nil +} + +// Wrap the host key callback. +func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback { + return func(server string, remote net.Addr, key ssh.PublicKey) error { + // Parse known_hosts file + knownHostsCb, err := knownhosts.New(h.knownHostsPath) + if err != nil { + // Problem parsing it + return err + } + + // Check for valid entry in known_hosts file + err = knownHostsCb(server, remote, key) + if err == nil { + // OK + return nil + } + + // Make sure that interactive user callback does not interfere with + // SSH connection throttler. 
// PromptAddHosts prompts a question to the user whether unknown hosts should
// be added to the known hosts or not.
//
// It runs as a long-lived goroutine, batching unknown hosts arriving on
// h.unknownCh and prompting the user either once 50 hosts have accumulated
// or after 2 seconds of inactivity. It returns when stop is closed.
//
// NOTE(review): hosts still queued in the batch when stop fires are dropped
// without ever receiving a response on their responseCh — confirm their
// waiting connection goroutines cannot block forever.
func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) {
	var hosts []unknownHost

	for {
		// Check whether there is a unknown host
		select {
		case unknown := <-h.unknownCh:
			hosts = append(hosts, unknown)
			// Ask every 50 unknown hosts
			if len(hosts) >= 50 {
				h.promptAddHosts(hosts)
				hosts = []unknownHost{}
			}
		case <-time.After(2 * time.Second):
			// Or ask when after 2 seconds no new unknown hosts were added.
			// (A fresh timer is created per loop iteration, so the 2s window
			// restarts after every received host.)
			if len(hosts) > 0 {
				h.promptAddHosts(hosts)
				hosts = []unknownHost{}
			}
		case <-stop:
			logger.Debug("Stopping goroutine prompting new hosts...")
			return
		}
	}
}
write to new known hosts file, and keep track of addresses + for _, unknown := range hosts { + unknown.responseCh <- trustHost + + // Add once as [HOSTNAME]:PORT + addresses[knownhosts.Normalize(unknown.server)] = struct{}{} + // And once as [IP]:PORT + addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{} + + newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)) + newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)) + } + + // Read old known hosts file, to see which are old and new entries + os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + oldFd, err := os.Open(h.knownHostsPath) + if err != nil { + panic(err) + } + defer oldFd.Close() + + scanner := bufio.NewScanner(oldFd) + + // Now, append all still valid old entries to the new host file + for scanner.Scan() { + line := scanner.Text() + address := strings.SplitN(line, " ", 2)[0] + + if _, ok := addresses[address]; !ok { + newFd.WriteString(fmt.Sprintf("%s\n", line)) + } + } + + // Now, replace old known hosts file + if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil { + panic(err) + } +} + +func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) { + for _, unknown := range hosts { + unknown.responseCh <- dontTrustHost + } +} + +// Untrusted returns true if the host is not trusted. False otherwise. +func (h *HostKeyCallback) Untrusted(server string) bool { + h.mutex.Lock() + defer h.mutex.Unlock() + _, ok := h.untrustedHosts[server] + + return ok +} diff --git a/ssh/server/hostkey.go b/ssh/server/hostkey.go new file mode 100644 index 0000000..ff1eb82 --- /dev/null +++ b/ssh/server/hostkey.go @@ -0,0 +1,37 @@ +package server + +import ( + "dtail/config" + "dtail/logger" + "dtail/ssh" + "io/ioutil" + "os" +) + +// PrivateHostKey retrieves the private server RSA host key. 
+func PrivateHostKey() []byte { + hostKeyFile := config.Server.HostKeyFile + _, err := os.Stat(hostKeyFile) + + if os.IsNotExist(err) { + logger.Info("Generating private server RSA host key") + privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits) + + if err != nil { + logger.FatalExit("Failed to generate private server RSA host key", err) + } + + pem := ssh.EncodePrivateKeyToPEM(privateKey) + if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil { + logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err) + } + return pem + } + + logger.Info("Reading private server RSA host key from file", hostKeyFile) + pem, err := ioutil.ReadFile(hostKeyFile) + if err != nil { + logger.FatalExit("Failed to load private server RSA host key", err) + } + return pem +} diff --git a/ssh/server/publickeycallback.go b/ssh/server/publickeycallback.go new file mode 100644 index 0000000..867f639 --- /dev/null +++ b/ssh/server/publickeycallback.go @@ -0,0 +1,61 @@ +package server + +import ( + "dtail/config" + "dtail/logger" + "dtail/server/user" + "fmt" + "io/ioutil" + "os" + osUser "os/user" + + gossh "golang.org/x/crypto/ssh" +) + +// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not. 
+func PublicKeyCallback(c gossh.ConnMetadata, pubKey gossh.PublicKey) (*gossh.Permissions, error) { + user := user.New(c.User(), c.RemoteAddr().String()) + logger.Info(user, "Incoming authorization") + + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error()) + } + + authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name) + if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) { + user, err := osUser.Lookup(user.Name) + if err != nil { + return nil, fmt.Errorf("Unable to authorize|%s|%s|", user, err.Error()) + } + // Fallback to ~ + authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys" + } + + logger.Info(user, "Reading", authorizedKeysFile) + authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile) + if err != nil { + return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error()) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := gossh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error()) + } + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + if authorizedKeysMap[string(pubKey.Marshal())] { + logger.Debug("Public key fingerprint", gossh.FingerprintSHA256(pubKey), user) + return &gossh.Permissions{ + Extensions: map[string]string{ + "pubkey-fp": gossh.FingerprintSHA256(pubKey), + }, + }, nil + } + + return nil, fmt.Errorf("Unknown public key|%s", user) +} diff --git a/ssh/ssh.go b/ssh/ssh.go new file mode 100644 index 0000000..6cd28a2 --- /dev/null +++ b/ssh/ssh.go @@ -0,0 +1,112 @@ +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "dtail/logger" + "encoding/pem" + "fmt" + "io/ioutil" + "net" + "os" + "syscall" + + gossh "golang.org/x/crypto/ssh" + 
"golang.org/x/crypto/ssh/agent" + "golang.org/x/crypto/ssh/terminal" +) + +// GeneratePrivateRSAKey is used by the server to generate its key. +func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, size) + if err != nil { + return nil, err + } + + err = privateKey.Validate() + if err != nil { + return nil, err + } + + return privateKey, nil +} + +// EncodePrivateKeyToPEM is a helper function for converting a key to PEM format. +func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { + derFormat := x509.MarshalPKCS1PrivateKey(privateKey) + + block := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: derFormat, + } + + return pem.EncodeToMemory(&block) +} + +// Agent used for SSH auth. +func Agent() (gossh.AuthMethod, error) { + sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + if err != nil { + return nil, err + } + agentClient := agent.NewClient(sshAgent) + keys, err := agentClient.List() + if err != nil { + return nil, err + } + for i, key := range keys { + logger.Debug("Public key", i, key) + } + return gossh.PublicKeysCallback(agentClient.Signers), nil +} + +// EnterKeyPhrase is required to read phrase protected private keys. +func EnterKeyPhrase(keyFile string) []byte { + fmt.Printf("Enter phrase for key %s: ", keyFile) + phrase, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + panic(err) + } + fmt.Printf("%s\n", string(phrase)) + return phrase +} + +// KeyFile returns the key as a SSH auth method. +func KeyFile(keyFile string) (gossh.AuthMethod, error) { + buffer, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, err + } + + key, err := gossh.ParsePrivateKey(buffer) + if err != nil { + return nil, err + } + + // Key phrase support disabled as password will be printed to stdout! 
// Name of DTail.
const Name = "DTail"

// Version of DTail.
const Version = "1.0.0"

// Additional information.
const Additional = ""

// String returns the human readable DTail version string, e.g.
// "DTail v1.0.0 ".
func String() string {
	return Name + " v" + Version + " " + Additional
}
