diff options
| -rw-r--r-- | internal/config/config.go | 86 | ||||
| -rw-r--r-- | internal/config/config_test.go | 6 | ||||
| -rw-r--r-- | internal/quorum/quorum_test.go | 31 | ||||
| -rw-r--r-- | internal/vote/vote_test.go | 16 |
4 files changed, 72 insertions, 67 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index 31b6644..12f1adc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,66 +21,78 @@ type Config struct { nodeNumberCache map[string]int } -// TODO Refactor all unit tests to use a constructor here -// and to initialize the nodeNumberCache inside of the constructor! -func New(configFile string) (Config, error) { - var c Config +func New(arg any) (Config, error) { + var conf *Config + + switch arg := arg.(type) { + case string: + // Used to read a config from a file. + conf = &Config{} + if err := conf.readConfigFile(arg); err != nil { + return *conf, err + } + case Config: + // Used to initialize a custom config from unit tests. + conf = &arg + default: + log.Fatal("unable to initialize config") + } + + if conf.LoopIntervalS == 0 { + conf.LoopIntervalS = 10 + } + + if conf.MyID == "" { + hostname, err := os.Hostname() + if err != nil { + return *conf, err + } + conf.MyID = hostname + } + + conf.nodeNumberCache = make(map[string]int, len(conf.Nodes)) + for i, node := range conf.Nodes { + conf.nodeNumberCache[utils.StripPort(node)] = i + } + + return *conf, nil +} +func (conf *Config) readConfigFile(configFile string) error { file, err := os.Open(configFile) if err != nil { - return c, err + return err } defer file.Close() bytes, err := io.ReadAll(file) if err != nil { - return c, err + return err } - err = json.Unmarshal(bytes, &c) + err = json.Unmarshal(bytes, conf) if err != nil { - return c, err - } - - if c.LoopIntervalS == 0 { - c.LoopIntervalS = 10 - } - - if c.MyID == "" { - hostname, err := os.Hostname() - if err != nil { - return c, err - } - c.MyID = hostname + return nil } - return c, nil + return nil } -func (c *Config) NodeNumber(node string) int { - if c.nodeNumberCache == nil { - c.nodeNumberCache = make(map[string]int, len(c.Nodes)) - } - nodeNumber, ok := c.nodeNumberCache[node] +func (conf Config) NodeNumber(node string) int { + node = utils.StripPort(node) + nodeNumber, ok := conf.nodeNumberCache[node] if ok { return nodeNumber } - for i, node_ := range c.Nodes { - if node == utils.StripPort(node_) { - c.nodeNumberCache[node] = i - return i - } - } - log.Println("config:", fmt.Errorf("node %s not found - it will affect it's score!", node)) return 0 } -func (c Config) IsNode(remoteAddr string) bool { +func (conf Config) IsNode(remoteAddr string) bool { remoteAddr = utils.StripPort(remoteAddr) - for _, node := range c.Nodes { + for _, node := range conf.Nodes { if remoteAddr == utils.StripPort(node) { return true } @@ -89,10 +101,10 @@ func (c Config) IsNode(remoteAddr string) bool { return false } -func (c Config) IsNodeWithLookup(remoteAddr string, lookupIP func(string) ([]net.IP, error)) bool { +func (conf Config) IsNodeWithLookup(remoteAddr string, lookupIP func(string) ([]net.IP, error)) bool { remoteAddr = utils.StripPort(remoteAddr) - for _, node := range c.Nodes { + for _, node := range conf.Nodes { ips, err := lookupIP(utils.StripPort(node)) if err != nil { log.Println("config:", err) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 7927664..af403ec 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -8,7 +8,7 @@ import ( func TestNodeNumber(t *testing.T) { t.Parallel() - conf := Config{Nodes: []string{"localhost:1234", "hamburger:4321"}} + conf, _ := New(Config{Nodes: []string{"localhost:1234", "hamburger:4321"}}) num := conf.NodeNumber("localhost") if num != 0 { @@ -23,7 +23,7 @@ func TestNodeNumber(t *testing.T) { func TestIsNode(t *testing.T) { t.Parallel() - conf := Config{Nodes: []string{"localhost:1234", "hamburger:4321"}} + conf, _ := New(Config{Nodes: []string{"localhost:1234", "hamburger:4321"}}) remoteAddr := "localhost:323232" if !conf.IsNode(remoteAddr) { @@ -38,7 +38,7 @@ func TestIsNode(t *testing.T) { func TestIsNodeWithLookup(t *testing.T) { t.Parallel() - conf := Config{Nodes: []string{"localhost:1234", "hamburger:4321"}} + conf, _ := New(Config{Nodes: []string{"localhost:1234", "hamburger:4321"}}) lookupIP := func(addr string) ([]net.IP, error) { switch addr { diff --git a/internal/quorum/quorum_test.go b/internal/quorum/quorum_test.go index a5fcbbf..b88b4fb 100644 --- a/internal/quorum/quorum_test.go +++ b/internal/quorum/quorum_test.go @@ -12,10 +12,11 @@ var inOneHour = time.Now().Add(1 * time.Hour) func TestScore(t *testing.T) { t.Parallel() - var ( - conf = config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}} - quo = New(conf) + + conf, _ := config.New( + config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}}, ) + quo := New(conf) vote1, _ := vote.New(conf, []string{"foo", "bar"}) vote1.FromID = "foo" @@ -68,10 +69,10 @@ func TestTieScore(t *testing.T) { t.Run("First tie score test", func(t *testing.T) { // If it is a tie, the first particpant (here: "foo") will win. - var ( - conf = config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}} - quo = New(conf) + conf, _ := config.New( + config.Config{Nodes: []string{"foo:1234", "bar:4321", "baz:3444"}}, ) + quo := New(conf) addVotes(conf, quo) scores := quo.scores() @@ -91,10 +92,10 @@ func TestTieScore(t *testing.T) { t.Run("Second tie score test", func(t *testing.T) { // If it is a tie, the first particpant (here: "bar") will win. - var ( - conf = config.Config{Nodes: []string{"bar:1234", "foo:4321", "baz:3444"}} - quo = New(conf) + conf, _ := config.New( + config.Config{Nodes: []string{"bar:1234", "foo:4321", "baz:3444"}}, ) + quo := New(conf) addVotes(conf, quo) scores := quo.scores() @@ -114,10 +115,10 @@ func TestTieScore(t *testing.T) { } func TestExpire(t *testing.T) { - var ( - conf = config.Config{Nodes: []string{"foo:1234", "bar:4321", "bay:2212"}} - quo = New(conf) + conf, _ := config.New( + config.Config{Nodes: []string{"foo:1234", "bar:4321", "bay:2212"}}, ) + quo := New(conf) vote1, _ := vote.New(conf, []string{"bar", "baz", "bay"}) vote1.FromID = "foo" @@ -148,10 +149,10 @@ func TestExpire(t *testing.T) { } func TestLiveNodes(t *testing.T) { - var ( - conf = config.Config{Nodes: []string{"foo:1234", "bay:4321"}} - quo = New(conf) + conf, _ := config.New( + config.Config{Nodes: []string{"foo:1234", "bay:4321"}}, ) + quo := New(conf) vote1, _ := vote.New(conf, []string{"bar", "baz", "bay"}) vote1.ExpiresAt = inOneHour diff --git a/internal/vote/vote_test.go b/internal/vote/vote_test.go index b5480b9..25860c7 100644 --- a/internal/vote/vote_test.go +++ b/internal/vote/vote_test.go @@ -10,10 +10,7 @@ import ( func TestVote(t *testing.T) { t.Parallel() - conf := config.Config{ - MyID: "foo.zone", - } - + conf, _ := config.New(config.Config{MyID: "foo.zone"}) v, _ := New(conf, []string{"foo", "bar", "baz", "bay"}) if v.FromID != "foo.zone" { @@ -36,10 +33,7 @@ func TestVote(t *testing.T) { func TestVoteExpiry(t *testing.T) { t.Parallel() - conf := config.Config{ - MyID: "foo.zone", - } - + conf, _ := config.New(config.Config{MyID: "foo.zone"}) v, _ := New(conf, []string{"foo", "bar", "baz", "bay"}) // Set expiry 1h into the future @@ -58,11 +52,9 @@ func TestVoteExpiry(t *testing.T) { func TestMarshalling(t *testing.T) { t.Parallel() - conf := config.Config{ - MyID: "foo.zone", - } - + conf, _ := config.New(config.Config{MyID: "foo.zone"}) v, _ := New(conf, []string{"foo", "bar", "baz", "bay"}) + jsonStr, err := v.ToJSON() if err != nil { t.Errorf("unable to serialize vote to json: %v", err) |
