summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2023-05-19 01:42:02 +0300
committerPaul Buetow <paul@buetow.org>2023-05-19 01:42:02 +0300
commitca585bde1c27777917b9454fe1d9c3d736de143d (patch)
tree702c4241ede146e07bd67cc536df132ba94a7043 /internal
parent5dc583de26cc7fc9816ce6b44ca4e80e723b8311 (diff)
excluding votes for non-participants
Diffstat (limited to 'internal')
-rw-r--r--internal/config.go14
-rw-r--r--internal/config_test.go14
-rw-r--r--internal/tcpserver.go8
-rw-r--r--internal/vote.go14
-rw-r--r--internal/vote_test.go19
5 files changed, 47 insertions, 22 deletions
diff --git a/internal/config.go b/internal/config.go
index 93549e2..17ec25b 100644
--- a/internal/config.go
+++ b/internal/config.go
@@ -38,11 +38,19 @@ func newConfig(configFile string) (config, error) {
}
func (c config) isParticipant(remoteAddr string) bool {
- return c.isParticipantWithLookup(remoteAddr, net.LookupIP)
+ remoteAddr = stripPort(remoteAddr)
+
+ for _, participant := range c.Participants {
+ if remoteAddr == stripPort(participant) {
+ return true
+ }
+ }
+
+ return false
}
func (c config) isParticipantWithLookup(remoteAddr string, lookupIP func(string) ([]net.IP, error)) bool {
- remoteIP := stripPort(remoteAddr)
+ remoteAddr = stripPort(remoteAddr)
for _, participant := range c.Participants {
ips, err := lookupIP(stripPort(participant))
@@ -52,7 +60,7 @@ func (c config) isParticipantWithLookup(remoteAddr string, lookupIP func(string)
}
for _, ip := range ips {
- if remoteIP == ip.String() {
+ if remoteAddr == ip.String() {
return true
}
}
diff --git a/internal/config_test.go b/internal/config_test.go
index 1a9ed69..67a9fff 100644
--- a/internal/config_test.go
+++ b/internal/config_test.go
@@ -15,6 +15,20 @@ func TestStripPort(t *testing.T) {
func TestIsParticipant(t *testing.T) {
config := config{Participants: []string{"localhost:1234", "hamburger:4321"}}
+ remoteAddr := "localhost:323232"
+ if !config.isParticipant(remoteAddr) {
+ t.Errorf("%s should be participant of %v", remoteAddr, config.Participants)
+ }
+
+ remoteAddr = "foo.zone:2345"
+ if config.isParticipant(remoteAddr) {
+ t.Errorf("%s should not be participant of %v", remoteAddr, config.Participants)
+ }
+}
+
+func TestIsParticipantWithLookup(t *testing.T) {
+ config := config{Participants: []string{"localhost:1234", "hamburger:4321"}}
+
lookupIP := func(addr string) ([]net.IP, error) {
switch addr {
case "localhost":
diff --git a/internal/tcpserver.go b/internal/tcpserver.go
index b689419..024f563 100644
--- a/internal/tcpserver.go
+++ b/internal/tcpserver.go
@@ -24,18 +24,18 @@ func startTcpServer(ctx context.Context, config config, ch chan<- vote) error {
continue
}
- if !config.isParticipant(conn.RemoteAddr().String()) {
+ if !config.isParticipantWithLookup(conn.RemoteAddr().String(), net.LookupIP) {
log.Printf("Denying connection, peer not a participant: %v\n", conn.RemoteAddr().String())
conn.Close()
continue
}
log.Printf("Client connected: %s\n", conn.RemoteAddr().String())
- go handleConnection(ctx, conn, ch)
+ go handleConnection(ctx, config, conn, ch)
}
}
-func handleConnection(ctx context.Context, conn net.Conn, ch chan<- vote) {
+func handleConnection(ctx context.Context, config config, conn net.Conn, ch chan<- vote) {
defer conn.Close()
remoteAddr := conn.RemoteAddr().String()
@@ -53,7 +53,7 @@ func handleConnection(ctx context.Context, conn net.Conn, ch chan<- vote) {
}
log.Printf("Received message from %s: %s", remoteAddr, message)
- ch <- newVote(remoteAddr, message)
+ ch <- newVote(config, remoteAddr, message)
conn.Write([]byte(message))
}
diff --git a/internal/vote.go b/internal/vote.go
index 80fdb9f..3aadd53 100644
--- a/internal/vote.go
+++ b/internal/vote.go
@@ -1,6 +1,7 @@
package internal
import (
+ "log"
"strings"
"time"
)
@@ -11,6 +12,15 @@ type vote struct {
time time.Time
}
-func newVote(from, message string) vote {
- return vote{stripPort(from), strings.Split(strings.TrimSpace(message), " "), time.Now()}
+func newVote(config config, from, message string) vote {
+ var ids []string
+ for _, id := range strings.Split(strings.TrimSpace(message), " ") {
+ if !config.isParticipant(id) {
+ log.Printf("%s is not a participant, excluding from the vote", id)
+ continue
+ }
+ ids = append(ids, id)
+ }
+
+ return vote{stripPort(from), ids, time.Now()}
}
diff --git a/internal/vote_test.go b/internal/vote_test.go
index d5afe34..50f8a2f 100644
--- a/internal/vote_test.go
+++ b/internal/vote_test.go
@@ -5,29 +5,22 @@ import (
)
func TestVote(t *testing.T) {
- v := newVote("earth:334234", " foo bar baz bay\n")
+ config := config{Participants: []string{"foo:1234", "bay:4321"}}
+ v := newVote(config, "earth:334234", " foo bar baz bay\n")
if v.from != "earth" {
t.Errorf("Expected vote to come from earth but came from %s", v.from)
}
- if len(v.ids) != 4 {
- t.Errorf("Expected vote length to be 4 but is %d", len(v.ids))
+ if len(v.ids) != 2 {
+ t.Errorf("Expected vote length to be 2 but is %d", len(v.ids))
}
if v.ids[0] != "foo" {
t.Errorf("Expected vote 1 to be foo but is %s", v.ids[0])
}
- if v.ids[1] != "bar" {
- t.Errorf("Expected vote 2 to be bar but is %s", v.ids[1])
- }
-
- if v.ids[2] != "baz" {
- t.Errorf("Expected vote 3 to be baz but is %s", v.ids[2])
- }
-
- if v.ids[3] != "bay" {
- t.Errorf("Expected vote 3 to be bay but is %s", v.ids[3])
+ if v.ids[1] != "bay" {
+ t.Errorf("Expected vote 2 to be bay but is %s", v.ids[1])
}
}