summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-27 06:19:38 +0200
committerPaul Buetow <paul@buetow.org>2026-03-27 06:19:38 +0200
commitf12114c1d9ec50f20f3df3e9c6e335e00f186c10 (patch)
treea097757d3c6eda3961da4407c5c87c00587ff6be /src
parentc5e06e480d01f4f87d02b5f04e873f44a679c741 (diff)
Fix final Raft append review issues b85586a4-4eb9-4686-93c7-0ab14173baa5
Diffstat (limited to 'src')
-rw-r--r--src/main/java/protocols/implementations/VSRaftProtocol.java12
-rw-r--r--src/test/java/protocols/implementations/VSRaftProtocolTest.java137
2 files changed, 116 insertions, 33 deletions
diff --git a/src/main/java/protocols/implementations/VSRaftProtocol.java b/src/main/java/protocols/implementations/VSRaftProtocol.java
index 8c36b68..d0066e0 100644
--- a/src/main/java/protocols/implementations/VSRaftProtocol.java
+++ b/src/main/java/protocols/implementations/VSRaftProtocol.java
@@ -387,11 +387,6 @@ public class VSRaftProtocol extends VSAbstractProtocol {
if (messageTerm > currentTerm) {
becomeFollower(messageTerm, messageLeaderId);
- } else if (messageTerm == currentTerm) {
- leaderId = messageLeaderId;
- isLeader = false;
- isCandidate = false;
- resetElectionTimeout();
} else {
return;
}
@@ -400,6 +395,13 @@ public class VSRaftProtocol extends VSAbstractProtocol {
return;
}
+ if (messageTerm == currentTerm) {
+ leaderId = messageLeaderId;
+ isLeader = false;
+ isCandidate = false;
+ resetElectionTimeout();
+ }
+
logIndex = messageLogIndex;
VSMessage appendAck = new VSMessage();
diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
index 49f87f4..c9440b1 100644
--- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java
+++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
@@ -24,6 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -524,8 +525,8 @@ class VSRaftProtocolTest {
protocol.onClientRecv(appendEntry);
verify(mockProcess).sendMessage(messageCaptor.capture());
- verify(mockTaskManager, times(2)).removeAllTasks(any());
- verify(mockTaskManager).addTask(taskCaptor.capture());
+ verify(mockTaskManager, times(3)).removeAllTasks(any());
+ verify(mockTaskManager, times(2)).addTask(taskCaptor.capture());
VSMessage appendAck = messageCaptor.getValue();
assertEquals("appendAck", appendAck.getString("type"));
@@ -548,6 +549,9 @@ class VSRaftProtocolTest {
protocol.onClientInit();
clearInvocations(mockProcess, mockTaskManager);
setIntField("currentTerm", 2);
+ setIntField("leaderId", 5);
+ setBooleanField("isCandidate", true);
+ long electionDeadline = getLongField("electionDeadline");
when(mockProcess.getTime()).thenReturn(600L);
VSMessage appendEntry = new VSMessage();
@@ -560,49 +564,63 @@ class VSRaftProtocolTest {
protocol.onClientRecv(appendEntry);
verify(mockProcess, never()).sendMessage(any());
- verify(mockTaskManager).removeAllTasks(any());
- verify(mockTaskManager).addTask(any());
+ verify(mockTaskManager, never()).removeAllTasks(any());
+ verify(mockTaskManager, never()).addTask(any());
assertEquals(0, getIntField("logIndex"));
assertEquals(2, getIntField("currentTerm"));
- assertEquals(11, getIntField("leaderId"));
+ assertEquals(5, getIntField("leaderId"));
+ assertTrue(getBooleanField("isCandidate"));
+ assertEquals(electionDeadline, getLongField("electionDeadline"));
}
@Test
- void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerAcks()
+ void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerRoundTrips()
throws Exception {
- protocol.onStart();
- clearInvocations(mockProcess, mockTaskManager);
+ LeaderHarness leaderHarness = createLeaderHarness(11, 300L);
+ leaderHarness.protocol.onStart();
- ArrayList<Integer> ackPids = getAckPids();
+ ArrayList<VSMessage> sentMessages = leaderHarness.protocol.getSentMessages();
+ assertEquals(2, sentMessages.size());
+ VSMessage appendEntry = sentMessages.get(1);
+ ArrayList<Integer> ackPids = getAckPids(leaderHarness.protocol);
assertEquals(2, ackPids.size());
- assertEquals(1, getIntField("logIndex"));
+ assertEquals(1, getIntField(leaderHarness.protocol, "logIndex"));
- VSMessage firstAck = new VSMessage();
- firstAck.setString("type", "appendAck");
- firstAck.setInteger("term", 0);
- firstAck.setInteger("pid", 2);
- firstAck.setInteger("logIndex", 1);
- firstAck.setInteger("targetPid", 7);
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ setIntField("currentTerm", -1);
+ clearInvocations(mockProcess, mockTaskManager);
+ when(mockProcess.getProcessID()).thenReturn(2);
+ when(mockProcess.getTime()).thenReturn(700L, 700L);
- protocol.onServerRecv(firstAck);
+ ArgumentCaptor<VSMessage> followerAckCaptor =
+ ArgumentCaptor.forClass(VSMessage.class);
+ protocol.onClientRecv(appendEntry);
+ verify(mockProcess).sendMessage(followerAckCaptor.capture());
- verify(mockProcess, never()).log(anyString());
+ leaderHarness.protocol.onServerRecv(followerAckCaptor.getValue());
+
+ verify(leaderHarness.process, never()).log(anyString());
assertEquals(1, ackPids.size());
assertTrue(ackPids.contains(3));
- assertEquals(0, getIntField("commitIndex"));
+ assertEquals(0, getIntField(leaderHarness.protocol, "commitIndex"));
+
+ protocol.onClientReset();
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ setIntField("currentTerm", -1);
+ clearInvocations(mockProcess, mockTaskManager);
+ when(mockProcess.getProcessID()).thenReturn(3);
+ when(mockProcess.getTime()).thenReturn(800L, 800L);
- VSMessage secondAck = new VSMessage();
- secondAck.setString("type", "appendAck");
- secondAck.setInteger("term", 0);
- secondAck.setInteger("pid", 3);
- secondAck.setInteger("logIndex", 1);
- secondAck.setInteger("targetPid", 7);
+ protocol.onClientRecv(appendEntry);
+ verify(mockProcess).sendMessage(followerAckCaptor.capture());
- protocol.onServerRecv(secondAck);
+ leaderHarness.protocol.onServerRecv(followerAckCaptor.getAllValues().get(1));
- verify(mockProcess).log("Committed log index 1");
+ verify(leaderHarness.process).log("Committed log index 1");
assertTrue(ackPids.isEmpty());
- assertEquals(1, getIntField("commitIndex"));
+ assertEquals(1, getIntField(leaderHarness.protocol, "commitIndex"));
}
@Test
@@ -686,4 +704,67 @@ class VSRaftProtocolTest {
field.setAccessible(true);
return field.getLong(protocol);
}
+
+ private LeaderHarness createLeaderHarness(int pid, long time) {
+ CapturingRaftProtocol peerProtocol = new CapturingRaftProtocol();
+ VSInternalProcess peerProcess = mock(VSInternalProcess.class);
+ VSSimulatorVisualization peerCanvas = mock(VSSimulatorVisualization.class);
+ VSTaskManager peerTaskManager = mock(VSTaskManager.class);
+ VSPrefs peerPrefs = mock(VSPrefs.class);
+ VSVectorTime peerVectorTime = mock(VSVectorTime.class);
+
+ peerProtocol.process = peerProcess;
+ peerProtocol.prefs = peerPrefs;
+ peerProtocol.isServer(true);
+ peerProtocol.currentContextIsServer(true);
+
+ when(peerProcess.getSimulatorCanvas()).thenReturn(peerCanvas);
+ when(peerCanvas.getTaskManager()).thenReturn(peerTaskManager);
+ when(peerCanvas.getNumProcesses()).thenReturn(3);
+ when(peerProcess.getPrefs()).thenReturn(peerPrefs);
+ when(peerProcess.getVectorTime()).thenReturn(peerVectorTime);
+ when(peerVectorTime.getCopy()).thenReturn(peerVectorTime);
+ when(peerPrefs.getString(anyString())).thenReturn("TestString");
+ when(peerProcess.getTime()).thenReturn(time);
+ when(peerProcess.getProcessID()).thenReturn(pid);
+ when(peerProcess.getRandomPercentage()).thenReturn(25);
+
+ peerProtocol.onServerInit();
+ clearInvocations(peerProcess, peerTaskManager);
+ return new LeaderHarness(peerProtocol, peerProcess);
+ }
+
+ private static final class CapturingRaftProtocol extends VSRaftProtocol {
+ private ArrayList<VSMessage> sentMessages = new ArrayList<VSMessage>();
+
+ @Override
+ public void sendMessage(VSMessage message) {
+ sentMessages.add(message);
+ }
+
+ private ArrayList<VSMessage> getSentMessages() {
+ return sentMessages;
+ }
+ }
+
+ private int getIntField(VSRaftProtocol raftProtocol, String fieldName)
+ throws Exception {
+ Field field = VSRaftProtocol.class.getDeclaredField(fieldName);
+ field.setAccessible(true);
+ return field.getInt(raftProtocol);
+ }
+
+ @SuppressWarnings("unchecked")
+ private ArrayList<Integer> getAckPids(VSRaftProtocol raftProtocol)
+ throws Exception {
+ Field field = VSRaftProtocol.class.getDeclaredField("ackPids");
+ field.setAccessible(true);
+ return (ArrayList<Integer>) field.get(raftProtocol);
+ }
+
+ private record LeaderHarness(
+ CapturingRaftProtocol protocol,
+ VSInternalProcess process
+ ) {
+ }
}