summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/config/config.go86
-rw-r--r--internal/config/config_test.go6
-rw-r--r--internal/quorum/quorum_test.go31
-rw-r--r--internal/vote/vote_test.go16
4 files changed, 72 insertions, 67 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index 31b6644..12f1adc 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -21,66 +21,78 @@ type Config struct {
nodeNumberCache map[string]int
}
-// TODO Refactor all unit tests to use a constructor here
-// and to initialize the nodeNumberCache inside of the constructor!
-func New(configFile string) (Config, error) {
- var c Config
+func New(arg any) (Config, error) {
+ var conf *Config
+
+ switch arg := arg.(type) {
+ case string:
+ // Used to read a config from a file.
+ conf = &Config{}
+ if err := conf.readConfigFile(arg); err != nil {
+ return *conf, err
+ }
+ case Config:
+ // Used to initialize a custom config from unit tests.
+ conf = &arg
+ default:
+ log.Fatal("unable to initialize config")
+ }
+
+ if conf.LoopIntervalS == 0 {
+ conf.LoopIntervalS = 10
+ }
+
+ if conf.MyID == "" {
+ hostname, err := os.Hostname()
+ if err != nil {
+ return *conf, err
+ }
+ conf.MyID = hostname
+ }
+
+ conf.nodeNumberCache = make(map[string]int, len(conf.Nodes))
+ for i, node := range conf.Nodes {
+ conf.nodeNumberCache[utils.StripPort(node)] = i
+ }
+
+ return *conf, nil
+}
+func (conf *Config) readConfigFile(configFile string) error {
file, err := os.Open(configFile)
if err != nil {
- return c, err
+ return err
}
defer file.Close()
bytes, err := io.ReadAll(file)
if err != nil {
- return c, err
+ return err
}
- err = json.Unmarshal(bytes, &c)
+ err = json.Unmarshal(bytes, conf)
if err != nil {
- return c, err
- }
-
- if c.LoopIntervalS == 0 {
- c.LoopIntervalS = 10
- }
-
- if c.MyID == "" {
- hostname, err := os.Hostname()
- if err != nil {
- return c, err
- }
- c.MyID = hostname
+ return nil
}
- return c, nil
+ return nil
}
-func (c *Config) NodeNumber(node string) int {
- if c.nodeNumberCache == nil {
- c.nodeNumberCache = make(map[string]int, len(c.Nodes))
- }
- nodeNumber, ok := c.nodeNumberCache[node]
+func (conf Config) NodeNumber(node string) int {
+ node = utils.StripPort(node)
+ nodeNumber, ok := conf.nodeNumberCache[node]
if ok {
return nodeNumber
}
- for i, node_ := range c.Nodes {
- if node == utils.StripPort(node_) {
- c.nodeNumberCache[node] = i
- return i
- }
- }
-
log.Println("config:", fmt.Errorf("node %s not found - it will affect it's score!", node))
return 0
}
-func (c Config) IsNode(remoteAddr string) bool {
+func (conf Config) IsNode(remoteAddr string) bool {
remoteAddr = utils.StripPort(remoteAddr)
- for _, node := range c.Nodes {
+ for _, node := range conf.Nodes {
if remoteAddr == utils.StripPort(node) {
return true
}
@@ -89,10 +101,10 @@ func (c Config) IsNode(remoteAddr string) bool {
return false
}
-func (c Config) IsNodeWithLookup(remoteAddr string, lookupIP func(string) ([]net.IP, error)) bool {
+func (conf Config) IsNodeWithLookup(remoteAddr string, lookupIP func(string) ([]net.IP, error)) bool {
remoteAddr = utils.StripPort(remoteAddr)
- for _, node := range c.Nodes {
+ for _, node := range conf.Nodes {
ips, err := lookupIP(utils.StripPort(node))
if err != nil {
log.Println("config:", err)
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index 7927664..af403ec 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -8,7 +8,7 @@ import (
func TestNodeNumber(t *testing.T) {
t.Parallel()
- conf := Config{Nodes: []string{"localhost:1234", "hamburger:4321"}}
+ conf, _ := New(Config{Nodes: []string{"localhost:1234", "hamburger:4321"}})
num := conf.NodeNumber("localhost")
if num != 0 {
@@ -23,7 +23,7 @@ func TestNodeNumber(t *testing.T) {
func TestIsNode(t *testing.T) {
t.Parallel()
- conf := Config{Nodes: []string{"localhost:1234", "hamburger:4321"}}
+ conf, _ := New(Config{Nodes: []string{"localhost:1234", "hamburger:4321"}})
remoteAddr := "localhost:323232"
if !conf.IsNode(remoteAddr) {
@@ -38,7 +38,7 @@ func TestIsNode(t *testing.T) {
func TestIsNodeWithLookup(t *testing.T) {
t.Parallel()
- conf := Config{Nodes: []string{"localhost:1234", "hamburger:4321"}}
+ conf, _ := New(Config{Nodes: []string{"localhost:1234", "hamburger:4321"}})
lookupIP := func(addr string) ([]net.IP, error) {
switch addr {
diff --git a/internal/quorum/quorum_test.go b/internal/quorum/quorum_test.go
index a5fcbbf..b88b4fb 100644
--- a/internal/quorum/quorum_test.go
+++ b/internal/quorum/quorum_test.go
@@ -12,10 +12,11 @@ var inOneHour = time.Now().Add(1 * time.Hour)
func TestScore(t *testing.T) {
t.Parallel()
- var (
- conf = config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}}
- quo = New(conf)
+
+ conf, _ := config.New(
+ config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}},
)
+ quo := New(conf)
vote1, _ := vote.New(conf, []string{"foo", "bar"})
vote1.FromID = "foo"
@@ -68,10 +69,10 @@ func TestTieScore(t *testing.T) {
t.Run("First tie score test", func(t *testing.T) {
// If it is a tie, the first particpant (here: "foo") will win.
- var (
- conf = config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}}
- quo = New(conf)
+ conf, _ := config.New(
+ config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}},
)
+ quo := New(conf)
addVotes(conf, quo)
scores := quo.scores()
@@ -91,10 +92,10 @@ func TestTieScore(t *testing.T) {
t.Run("Second tie score test", func(t *testing.T) {
// If it is a tie, the first particpant (here: "bar") will win.
- var (
- conf = config.Config{Nodes: []string{"bar:1234", "foo:4321", "baz:3444"}}
- quo = New(conf)
+ conf, _ := config.New(
+ config.Config{Nodes: []string{"bar:1234", "foo:4321", "baz:3444"}},
)
+ quo := New(conf)
addVotes(conf, quo)
scores := quo.scores()
@@ -114,10 +115,10 @@ func TestTieScore(t *testing.T) {
}
func TestExpire(t *testing.T) {
- var (
- conf = config.Config{Nodes: []string{"foo:1234", "bar:4321", "bay:2212"}}
- quo = New(conf)
+ conf, _ := config.New(
+ config.Config{Nodes: []string{"foo:1234", "bar:4321", "bay:2212"}},
)
+ quo := New(conf)
vote1, _ := vote.New(conf, []string{"bar", "baz", "bay"})
vote1.FromID = "foo"
@@ -148,10 +149,10 @@ func TestExpire(t *testing.T) {
}
func TestLiveNodes(t *testing.T) {
- var (
- conf = config.Config{Nodes: []string{"foo:1234", "bay:4321"}}
- quo = New(conf)
+ conf, _ := config.New(
+ config.Config{Nodes: []string{"foo:1234", "bay:4321"}},
)
+ quo := New(conf)
vote1, _ := vote.New(conf, []string{"bar", "baz", "bay"})
vote1.ExpiresAt = inOneHour
diff --git a/internal/vote/vote_test.go b/internal/vote/vote_test.go
index b5480b9..25860c7 100644
--- a/internal/vote/vote_test.go
+++ b/internal/vote/vote_test.go
@@ -10,10 +10,7 @@ import (
func TestVote(t *testing.T) {
t.Parallel()
- conf := config.Config{
- MyID: "foo.zone",
- }
-
+ conf, _ := config.New(config.Config{MyID: "foo.zone"})
v, _ := New(conf, []string{"foo", "bar", "baz", "bay"})
if v.FromID != "foo.zone" {
@@ -36,10 +33,7 @@ func TestVote(t *testing.T) {
func TestVoteExpiry(t *testing.T) {
t.Parallel()
- conf := config.Config{
- MyID: "foo.zone",
- }
-
+ conf, _ := config.New(config.Config{MyID: "foo.zone"})
v, _ := New(conf, []string{"foo", "bar", "baz", "bay"})
// Set expiry 1h into the future
@@ -58,11 +52,9 @@ func TestVoteExpiry(t *testing.T) {
func TestMarshalling(t *testing.T) {
t.Parallel()
- conf := config.Config{
- MyID: "foo.zone",
- }
-
+ conf, _ := config.New(config.Config{MyID: "foo.zone"})
v, _ := New(conf, []string{"foo", "bar", "baz", "bay"})
+
jsonStr, err := v.ToJSON()
if err != nil {
t.Errorf("unable to serialize vote to json: %v", err)