diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-26 23:32:34 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-26 23:32:34 +0200 |
| commit | 71182306390990e5ef9726e73924bc5a9f070282 (patch) | |
| tree | d576759a4f1e6cd9c11613105e9043094b0cd6f8 | |
| parent | 3820f2fe179995aa6aa12e1fd2ab9b07a7938620 (diff) | |
Implement Raft vote handling for 04c78b3c-2267-495b-9aca-84b544a1882f
| -rw-r--r-- | src/main/java/protocols/implementations/VSRaftProtocol.java | 120 | ||||
| -rw-r--r-- | src/test/java/protocols/implementations/VSRaftProtocolTest.java | 170 |
2 files changed, 289 insertions, 1 deletions
diff --git a/src/main/java/protocols/implementations/VSRaftProtocol.java b/src/main/java/protocols/implementations/VSRaftProtocol.java index 6257538..597c1e0 100644 --- a/src/main/java/protocols/implementations/VSRaftProtocol.java +++ b/src/main/java/protocols/implementations/VSRaftProtocol.java @@ -3,6 +3,7 @@ package protocols.implementations; import java.util.ArrayList; import java.util.Vector; +import core.VSInternalProcess; import core.VSMessage; import protocols.VSAbstractProtocol; @@ -87,12 +88,14 @@ public class VSRaftProtocol extends VSAbstractProtocol { * @see protocols.VSAbstractProtocol#onServerRecv(core.VSMessage) */ public void onServerRecv(VSMessage recvMessage) { + handleMessage(recvMessage); } /* (non-Javadoc) * @see protocols.VSAbstractProtocol#onClientRecv(core.VSMessage) */ public void onClientRecv(VSMessage recvMessage) { + handleMessage(recvMessage); } /* (non-Javadoc) @@ -157,10 +160,23 @@ public class VSRaftProtocol extends VSAbstractProtocol { private void becomeLeader() { isLeader = true; isCandidate = false; - removeSchedules(); + votesReceived = 0; leaderId = process.getProcessID(); lastHeartbeatTime = process.getTime(); + isServer(true); + + if (!getLongKeySet().contains("heartbeatInterval")) { + onServerInit(); + } + + boolean previousContextIsServer = currentContextIsServer(); + + currentContextIsServer(false); + removeSchedules(); + + currentContextIsServer(true); sendHeartbeat(); + currentContextIsServer(previousContextIsServer); } /** @@ -218,6 +234,7 @@ public class VSRaftProtocol extends VSAbstractProtocol { isCandidate = true; leaderId = -1; lastHeartbeatTime = process.getTime(); + isServer(true); VSMessage voteRequest = new VSMessage(); voteRequest.setString("type", "voteRequest"); @@ -241,4 +258,105 @@ public class VSRaftProtocol extends VSAbstractProtocol { lastHeartbeatTime = process.getTime(); scheduleAt(process.getTime() + getLong("heartbeatInterval")); } + + /** + * Dispatches Raft messages to the relevant handlers. + * + * @param recvMessage the received message + */ + private void handleMessage(VSMessage recvMessage) { + String messageType = recvMessage.getString("type"); + + if ("voteRequest".equals(messageType)) { + handleVoteRequest(recvMessage); + } else if ("voteResponse".equals(messageType)) { + handleVoteResponse(recvMessage); + } + } + + /** + * Handles an incoming vote request from a candidate. + * + * @param recvMessage the vote request + */ + private void handleVoteRequest(VSMessage recvMessage) { + int messageTerm = recvMessage.getInteger("term"); + int candidateId = recvMessage.getInteger("candidateId"); + boolean voteGranted = false; + + if (messageTerm >= currentTerm && + (votedFor == -1 || votedFor == candidateId)) { + becomeFollower(messageTerm, -1); + votedFor = candidateId; + voteGranted = true; + } + + VSMessage voteResponse = new VSMessage(); + voteResponse.setString("type", "voteResponse"); + voteResponse.setInteger("term", currentTerm); + voteResponse.setInteger("pid", process.getProcessID()); + voteResponse.setBoolean("voteGranted", voteGranted); + voteResponse.setInteger("targetPid", candidateId); + sendMessage(voteResponse); + } + + /** + * Handles an incoming vote response for an active election. + * + * @param recvMessage the vote response + */ + private void handleVoteResponse(VSMessage recvMessage) { + int messageTerm = recvMessage.getInteger("term"); + + if (messageTerm > currentTerm) { + becomeFollower(messageTerm, -1); + return; + } + + if (!isCandidate || !isForMe(recvMessage) || + !recvMessage.getBoolean("voteGranted") || + messageTerm != currentTerm) { + return; + } + + votesReceived++; + + if (votesReceived > getClusterSize() / 2) { + becomeLeader(); + } + } + + /** + * Checks whether a directed response is meant for this process. + * + * @param recvMessage the received message + * @return true if the message targets this process or has no target field + */ + private boolean isForMe(VSMessage recvMessage) { + if (!recvMessage.getIntegerKeySet().contains("targetPid")) { + return true; + } + + return recvMessage.getInteger("targetPid") == process.getProcessID(); + } + + /** + * Determines the cluster size used for majority calculations. + * + * @return the number of processes participating in the election + */ + private int getClusterSize() { + VSInternalProcess internalProcess = (VSInternalProcess) process; + int numProcesses = internalProcess.getSimulatorCanvas().getNumProcesses(); + + if (numProcesses > 0) { + return numProcesses; + } + + if (getVectorKeySet().contains("pids")) { + return getVector("pids").size() + 1; + } + + return 1; + } } diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java index 8028711..5f0fced 100644 --- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java +++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java @@ -60,6 +60,7 @@ class VSRaftProtocolTest { when(mockProcess.getSimulatorCanvas()).thenReturn(mockCanvas); when(mockCanvas.getTaskManager()).thenReturn(mockTaskManager); + when(mockCanvas.getNumProcesses()).thenReturn(3); when(mockProcess.getPrefs()).thenReturn(mockPrefs); when(mockProcess.getVectorTime()).thenReturn(mockVectorTime); when(mockVectorTime.getCopy()).thenReturn(mockVectorTime); @@ -263,6 +264,162 @@ class VSRaftProtocolTest { .isServerSchedule()); } + @Test + void testClientReceiveVoteRequestGrantsEligibleCandidate() throws Exception { + protocol.currentContextIsServer(false); + protocol.onClientInit(); + clearInvocations(mockProcess, mockTaskManager); + when(mockProcess.getTime()).thenReturn(200L, 200L); + + VSMessage voteRequest = new VSMessage(); + voteRequest.setString("type", "voteRequest"); + voteRequest.setInteger("term", 2); + voteRequest.setInteger("candidateId", 11); + + ArgumentCaptor<VSMessage> messageCaptor = + ArgumentCaptor.forClass(VSMessage.class); + ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class); + + protocol.onClientRecv(voteRequest); + + verify(mockProcess).sendMessage(messageCaptor.capture()); + verify(mockTaskManager, times(2)).removeAllTasks(any()); + verify(mockTaskManager).addTask(taskCaptor.capture()); + + VSMessage voteResponse = messageCaptor.getValue(); + assertEquals("voteResponse", voteResponse.getString("type")); + assertEquals(2, voteResponse.getInteger("term")); + assertEquals(7, voteResponse.getInteger("pid")); + assertTrue(voteResponse.getBoolean("voteGranted")); + assertEquals(11, voteResponse.getInteger("targetPid")); + assertEquals(2, getIntField("currentTerm")); + assertEquals(11, getIntField("votedFor")); + assertFalse(getBooleanField("isCandidate")); + assertFalse(getBooleanField("isLeader")); + assertEquals(4700L, taskCaptor.getValue().getTaskTime()); + } + + @Test + void testClientReceiveVoteRequestDeniesWhenAlreadyVotedForOtherCandidate() + throws Exception { + setIntField("currentTerm", 3); + setIntField("votedFor", 9); + protocol.currentContextIsServer(false); + + VSMessage voteRequest = new VSMessage(); + voteRequest.setString("type", "voteRequest"); + voteRequest.setInteger("term", 3); + voteRequest.setInteger("candidateId", 11); + + ArgumentCaptor<VSMessage> messageCaptor = + ArgumentCaptor.forClass(VSMessage.class); + + protocol.onClientRecv(voteRequest); + + verify(mockProcess).sendMessage(messageCaptor.capture()); + verify(mockTaskManager, never()).removeAllTasks(any()); + verify(mockTaskManager, never()).addTask(any()); + + VSMessage voteResponse = messageCaptor.getValue(); + assertEquals("voteResponse", voteResponse.getString("type")); + assertEquals(3, voteResponse.getInteger("term")); + assertFalse(voteResponse.getBoolean("voteGranted")); + assertEquals(9, getIntField("votedFor")); + assertEquals(3, getIntField("currentTerm")); + } + + @Test + void testVoteResponseMajorityPromotesCandidateToLeader() throws Exception { + protocol.currentContextIsServer(false); + setIntField("currentTerm", 3); + setIntField("votesReceived", 1); + setBooleanField("isCandidate", true); + when(mockProcess.getTime()).thenReturn(300L, 300L, 300L); + + VSMessage voteResponse = new VSMessage(); + voteResponse.setString("type", "voteResponse"); + voteResponse.setInteger("term", 3); + voteResponse.setInteger("pid", 2); + voteResponse.setBoolean("voteGranted", true); + voteResponse.setInteger("targetPid", 7); + + ArgumentCaptor<VSMessage> messageCaptor = + ArgumentCaptor.forClass(VSMessage.class); + ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class); + + protocol.onServerRecv(voteResponse); + + verify(mockProcess).sendMessage(messageCaptor.capture()); + verify(mockTaskManager).removeAllTasks(any()); + verify(mockTaskManager).addTask(taskCaptor.capture()); + + VSMessage heartbeat = messageCaptor.getValue(); + assertEquals("heartbeat", heartbeat.getString("type")); + assertEquals(3, heartbeat.getInteger("term")); + assertEquals(7, heartbeat.getInteger("leaderId")); + assertTrue(getBooleanField("isLeader")); + assertFalse(getBooleanField("isCandidate")); + assertEquals(7, getIntField("leaderId")); + assertTrue(protocol.isServer()); + assertEquals(1800L, taskCaptor.getValue().getTaskTime()); + } + + @Test + void testVoteResponseForDifferentTargetDoesNotCount() throws Exception { + protocol.currentContextIsServer(false); + setIntField("currentTerm", 3); + setIntField("votesReceived", 1); + setBooleanField("isCandidate", true); + + VSMessage voteResponse = new VSMessage(); + voteResponse.setString("type", "voteResponse"); + voteResponse.setInteger("term", 3); + voteResponse.setInteger("pid", 2); + voteResponse.setBoolean("voteGranted", true); + voteResponse.setInteger("targetPid", 99); + + protocol.onServerRecv(voteResponse); + + verify(mockProcess, never()).sendMessage(any()); + verify(mockTaskManager, never()).removeAllTasks(any()); + verify(mockTaskManager, never()).addTask(any()); + assertEquals(1, getIntField("votesReceived")); + assertTrue(getBooleanField("isCandidate")); + assertFalse(getBooleanField("isLeader")); + } + + @Test + void testHigherTermVoteResponseDemotesCandidateToFollower() throws Exception { + protocol.currentContextIsServer(false); + protocol.onClientInit(); + clearInvocations(mockProcess, mockTaskManager); + setIntField("currentTerm", 3); + setIntField("votesReceived", 2); + setBooleanField("isCandidate", true); + when(mockProcess.getTime()).thenReturn(500L, 500L); + + VSMessage voteResponse = new VSMessage(); + voteResponse.setString("type", "voteResponse"); + voteResponse.setInteger("term", 4); + voteResponse.setInteger("pid", 2); + voteResponse.setBoolean("voteGranted", false); + voteResponse.setInteger("targetPid", 7); + + ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class); + + protocol.onServerRecv(voteResponse); + + verify(mockProcess, never()).sendMessage(any()); + verify(mockTaskManager, times(2)).removeAllTasks(any()); + verify(mockTaskManager).addTask(taskCaptor.capture()); + assertEquals(4, getIntField("currentTerm")); + assertEquals(-1, getIntField("votedFor")); + assertEquals(0, getIntField("votesReceived")); + assertFalse(getBooleanField("isCandidate")); + assertFalse(getBooleanField("isLeader")); + assertEquals(5000L, taskCaptor.getValue().getTaskTime()); + } + private void invokeBecomeFollower(int term, int leaderId) throws Exception { Method method = VSRaftProtocol.class.getDeclaredMethod( "becomeFollower", int.class, int.class); @@ -270,6 +427,19 @@ class VSRaftProtocolTest { method.invoke(protocol, term, leaderId); } + private void setIntField(String fieldName, int value) throws Exception { + Field field = VSRaftProtocol.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.setInt(protocol, value); + } + + private void setBooleanField(String fieldName, boolean value) + throws Exception { + Field field = VSRaftProtocol.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.setBoolean(protocol, value); + } + private int getIntField(String fieldName) throws Exception { Field field = VSRaftProtocol.class.getDeclaredField(fieldName); field.setAccessible(true); |
