summaryrefslogtreecommitdiff
path: root/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/test')
-rw-r--r--src/test/java/protocols/implementations/VSRaftProtocolTest.java137
1 files changed, 109 insertions, 28 deletions
diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
index 49f87f4..c9440b1 100644
--- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java
+++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
@@ -24,6 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -524,8 +525,8 @@ class VSRaftProtocolTest {
protocol.onClientRecv(appendEntry);
verify(mockProcess).sendMessage(messageCaptor.capture());
- verify(mockTaskManager, times(2)).removeAllTasks(any());
- verify(mockTaskManager).addTask(taskCaptor.capture());
+ verify(mockTaskManager, times(3)).removeAllTasks(any());
+ verify(mockTaskManager, times(2)).addTask(taskCaptor.capture());
VSMessage appendAck = messageCaptor.getValue();
assertEquals("appendAck", appendAck.getString("type"));
@@ -548,6 +549,9 @@ class VSRaftProtocolTest {
protocol.onClientInit();
clearInvocations(mockProcess, mockTaskManager);
setIntField("currentTerm", 2);
+ setIntField("leaderId", 5);
+ setBooleanField("isCandidate", true);
+ long electionDeadline = getLongField("electionDeadline");
when(mockProcess.getTime()).thenReturn(600L);
VSMessage appendEntry = new VSMessage();
@@ -560,49 +564,63 @@ class VSRaftProtocolTest {
protocol.onClientRecv(appendEntry);
verify(mockProcess, never()).sendMessage(any());
- verify(mockTaskManager).removeAllTasks(any());
- verify(mockTaskManager).addTask(any());
+ verify(mockTaskManager, never()).removeAllTasks(any());
+ verify(mockTaskManager, never()).addTask(any());
assertEquals(0, getIntField("logIndex"));
assertEquals(2, getIntField("currentTerm"));
- assertEquals(11, getIntField("leaderId"));
+ assertEquals(5, getIntField("leaderId"));
+ assertTrue(getBooleanField("isCandidate"));
+ assertEquals(electionDeadline, getLongField("electionDeadline"));
}
@Test
- void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerAcks()
+ void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerRoundTrips()
throws Exception {
- protocol.onStart();
- clearInvocations(mockProcess, mockTaskManager);
+ LeaderHarness leaderHarness = createLeaderHarness(11, 300L);
+ leaderHarness.protocol.onStart();
- ArrayList<Integer> ackPids = getAckPids();
+ ArrayList<VSMessage> sentMessages = leaderHarness.protocol.getSentMessages();
+ assertEquals(2, sentMessages.size());
+ VSMessage appendEntry = sentMessages.get(1);
+ ArrayList<Integer> ackPids = getAckPids(leaderHarness.protocol);
assertEquals(2, ackPids.size());
- assertEquals(1, getIntField("logIndex"));
+ assertEquals(1, getIntField(leaderHarness.protocol, "logIndex"));
- VSMessage firstAck = new VSMessage();
- firstAck.setString("type", "appendAck");
- firstAck.setInteger("term", 0);
- firstAck.setInteger("pid", 2);
- firstAck.setInteger("logIndex", 1);
- firstAck.setInteger("targetPid", 7);
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ setIntField("currentTerm", -1);
+ clearInvocations(mockProcess, mockTaskManager);
+ when(mockProcess.getProcessID()).thenReturn(2);
+ when(mockProcess.getTime()).thenReturn(700L, 700L);
- protocol.onServerRecv(firstAck);
+ ArgumentCaptor<VSMessage> followerAckCaptor =
+ ArgumentCaptor.forClass(VSMessage.class);
+ protocol.onClientRecv(appendEntry);
+ verify(mockProcess).sendMessage(followerAckCaptor.capture());
- verify(mockProcess, never()).log(anyString());
+ leaderHarness.protocol.onServerRecv(followerAckCaptor.getValue());
+
+ verify(leaderHarness.process, never()).log(anyString());
assertEquals(1, ackPids.size());
assertTrue(ackPids.contains(3));
- assertEquals(0, getIntField("commitIndex"));
+ assertEquals(0, getIntField(leaderHarness.protocol, "commitIndex"));
+
+ protocol.onClientReset();
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ setIntField("currentTerm", -1);
+ clearInvocations(mockProcess, mockTaskManager);
+ when(mockProcess.getProcessID()).thenReturn(3);
+ when(mockProcess.getTime()).thenReturn(800L, 800L);
- VSMessage secondAck = new VSMessage();
- secondAck.setString("type", "appendAck");
- secondAck.setInteger("term", 0);
- secondAck.setInteger("pid", 3);
- secondAck.setInteger("logIndex", 1);
- secondAck.setInteger("targetPid", 7);
+ protocol.onClientRecv(appendEntry);
+ verify(mockProcess).sendMessage(followerAckCaptor.capture());
- protocol.onServerRecv(secondAck);
+ leaderHarness.protocol.onServerRecv(followerAckCaptor.getAllValues().get(1));
- verify(mockProcess).log("Committed log index 1");
+ verify(leaderHarness.process).log("Committed log index 1");
assertTrue(ackPids.isEmpty());
- assertEquals(1, getIntField("commitIndex"));
+ assertEquals(1, getIntField(leaderHarness.protocol, "commitIndex"));
}
@Test
@@ -686,4 +704,67 @@ class VSRaftProtocolTest {
field.setAccessible(true);
return field.getLong(protocol);
}
+
+ private LeaderHarness createLeaderHarness(int pid, long time) {
+ CapturingRaftProtocol peerProtocol = new CapturingRaftProtocol();
+ VSInternalProcess peerProcess = mock(VSInternalProcess.class);
+ VSSimulatorVisualization peerCanvas = mock(VSSimulatorVisualization.class);
+ VSTaskManager peerTaskManager = mock(VSTaskManager.class);
+ VSPrefs peerPrefs = mock(VSPrefs.class);
+ VSVectorTime peerVectorTime = mock(VSVectorTime.class);
+
+ peerProtocol.process = peerProcess;
+ peerProtocol.prefs = peerPrefs;
+ peerProtocol.isServer(true);
+ peerProtocol.currentContextIsServer(true);
+
+ when(peerProcess.getSimulatorCanvas()).thenReturn(peerCanvas);
+ when(peerCanvas.getTaskManager()).thenReturn(peerTaskManager);
+ when(peerCanvas.getNumProcesses()).thenReturn(3);
+ when(peerProcess.getPrefs()).thenReturn(peerPrefs);
+ when(peerProcess.getVectorTime()).thenReturn(peerVectorTime);
+ when(peerVectorTime.getCopy()).thenReturn(peerVectorTime);
+ when(peerPrefs.getString(anyString())).thenReturn("TestString");
+ when(peerProcess.getTime()).thenReturn(time);
+ when(peerProcess.getProcessID()).thenReturn(pid);
+ when(peerProcess.getRandomPercentage()).thenReturn(25);
+
+ peerProtocol.onServerInit();
+ clearInvocations(peerProcess, peerTaskManager);
+ return new LeaderHarness(peerProtocol, peerProcess);
+ }
+
+ private static final class CapturingRaftProtocol extends VSRaftProtocol {
+ private ArrayList<VSMessage> sentMessages = new ArrayList<VSMessage>();
+
+ @Override
+ public void sendMessage(VSMessage message) {
+ sentMessages.add(message);
+ }
+
+ private ArrayList<VSMessage> getSentMessages() {
+ return sentMessages;
+ }
+ }
+
+ private int getIntField(VSRaftProtocol raftProtocol, String fieldName)
+ throws Exception {
+ Field field = VSRaftProtocol.class.getDeclaredField(fieldName);
+ field.setAccessible(true);
+ return field.getInt(raftProtocol);
+ }
+
+ @SuppressWarnings("unchecked")
+ private ArrayList<Integer> getAckPids(VSRaftProtocol raftProtocol)
+ throws Exception {
+ Field field = VSRaftProtocol.class.getDeclaredField("ackPids");
+ field.setAccessible(true);
+ return (ArrayList<Integer>) field.get(raftProtocol);
+ }
+
+ private record LeaderHarness(
+ CapturingRaftProtocol protocol,
+ VSInternalProcess process
+ ) {
+ }
}