summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/java/protocols/implementations/VSRaftProtocol.java23
-rw-r--r--src/test/java/protocols/implementations/VSRaftProtocolTest.java129
2 files changed, 113 insertions, 39 deletions
diff --git a/src/main/java/protocols/implementations/VSRaftProtocol.java b/src/main/java/protocols/implementations/VSRaftProtocol.java
index 983c8d3..8c36b68 100644
--- a/src/main/java/protocols/implementations/VSRaftProtocol.java
+++ b/src/main/java/protocols/implementations/VSRaftProtocol.java
@@ -277,6 +277,8 @@ public class VSRaftProtocol extends VSAbstractProtocol {
* Sends a simplified append-entry request for the configured log entry.
*/
private void sendAppendEntry() {
+ ackPids.clear();
+
if (getVectorKeySet().contains("pids")) {
ackPids.addAll(getVector("pids"));
}
@@ -381,6 +383,7 @@ public class VSRaftProtocol extends VSAbstractProtocol {
private void handleAppendEntry(VSMessage recvMessage) {
int messageTerm = recvMessage.getInteger("term");
int messageLeaderId = recvMessage.getInteger("leaderId");
+ int messageLogIndex = recvMessage.getInteger("logIndex");
if (messageTerm > currentTerm) {
becomeFollower(messageTerm, messageLeaderId);
@@ -393,13 +396,17 @@ public class VSRaftProtocol extends VSAbstractProtocol {
return;
}
- logIndex++;
+ if (messageLogIndex != logIndex + 1) {
+ return;
+ }
+
+ logIndex = messageLogIndex;
VSMessage appendAck = new VSMessage();
appendAck.setString("type", "appendAck");
appendAck.setInteger("term", currentTerm);
appendAck.setInteger("pid", process.getProcessID());
- appendAck.setInteger("logIndex", logIndex);
+ appendAck.setInteger("logIndex", messageLogIndex);
appendAck.setInteger("targetPid", messageLeaderId);
sendMessage(appendAck);
}
@@ -410,17 +417,25 @@ public class VSRaftProtocol extends VSAbstractProtocol {
* @param recvMessage the append acknowledgement
*/
private void handleAppendAck(VSMessage recvMessage) {
+ int messageTerm = recvMessage.getInteger("term");
Integer responderPid = recvMessage.getIntegerObj("pid");
+ int ackLogIndex = recvMessage.getInteger("logIndex");
+
+ if (messageTerm > currentTerm) {
+ becomeFollower(messageTerm, -1);
+ return;
+ }
if (!isLeader || !isForMe(recvMessage) || responderPid == null ||
+ messageTerm != currentTerm || ackLogIndex != logIndex ||
!ackPids.contains(responderPid)) {
return;
}
ackPids.remove(responderPid);
- if (ackPids.isEmpty()) {
- commitIndex++;
+ if (ackPids.isEmpty() && commitIndex < ackLogIndex) {
+ commitIndex = ackLogIndex;
log("Committed log index " + commitIndex);
}
}
diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
index f3dc0d6..49f87f4 100644
--- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java
+++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java
@@ -16,6 +16,7 @@ import simulator.VSSimulatorVisualization;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
+import java.util.ArrayList;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -71,7 +72,7 @@ class VSRaftProtocolTest {
}
@Test
- void testOnStartBecomesLeaderAndSendsHeartbeat() {
+ void testOnStartBecomesLeaderAndSendsHeartbeat() throws Exception {
ArgumentCaptor<VSMessage> messageCaptor =
ArgumentCaptor.forClass(VSMessage.class);
ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class);
@@ -91,6 +92,9 @@ class VSRaftProtocolTest {
assertEquals(7, appendEntry.getInteger("leaderId"));
assertEquals("cmd1", appendEntry.getString("entry"));
assertEquals(1, appendEntry.getInteger("logIndex"));
+ assertEquals(2, getAckPids().size());
+ assertTrue(getAckPids().contains(2));
+ assertTrue(getAckPids().contains(3));
assertEquals(1600L, taskCaptor.getValue().getTaskTime());
}
@@ -538,23 +542,63 @@ class VSRaftProtocolTest {
}
@Test
- void testAppendAckForLeaderCommitsOnceAllFollowersAck() throws Exception {
- setBooleanField("isLeader", true);
- setIntField("logIndex", 1);
- @SuppressWarnings("unchecked")
- java.util.ArrayList<Integer> ackPids =
- (java.util.ArrayList<Integer>) getObjectField("ackPids");
- ackPids.clear();
- ackPids.add(2);
-
- VSMessage appendAck = new VSMessage();
- appendAck.setString("type", "appendAck");
- appendAck.setInteger("term", 1);
- appendAck.setInteger("pid", 2);
- appendAck.setInteger("logIndex", 1);
- appendAck.setInteger("targetPid", 7);
-
- protocol.onServerRecv(appendAck);
+ void testAppendEntryOutOfSyncDoesNotAdvanceFollowerLogOrSendAck()
+ throws Exception {
+ protocol.currentContextIsServer(false);
+ protocol.onClientInit();
+ clearInvocations(mockProcess, mockTaskManager);
+ setIntField("currentTerm", 2);
+ when(mockProcess.getTime()).thenReturn(600L);
+
+ VSMessage appendEntry = new VSMessage();
+ appendEntry.setString("type", "appendEntry");
+ appendEntry.setInteger("term", 2);
+ appendEntry.setInteger("leaderId", 11);
+ appendEntry.setString("entry", "cmd2");
+ appendEntry.setInteger("logIndex", 2);
+
+ protocol.onClientRecv(appendEntry);
+
+ verify(mockProcess, never()).sendMessage(any());
+ verify(mockTaskManager).removeAllTasks(any());
+ verify(mockTaskManager).addTask(any());
+ assertEquals(0, getIntField("logIndex"));
+ assertEquals(2, getIntField("currentTerm"));
+ assertEquals(11, getIntField("leaderId"));
+ }
+
+ @Test
+ void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerAcks()
+ throws Exception {
+ protocol.onStart();
+ clearInvocations(mockProcess, mockTaskManager);
+
+ ArrayList<Integer> ackPids = getAckPids();
+ assertEquals(2, ackPids.size());
+ assertEquals(1, getIntField("logIndex"));
+
+ VSMessage firstAck = new VSMessage();
+ firstAck.setString("type", "appendAck");
+ firstAck.setInteger("term", 0);
+ firstAck.setInteger("pid", 2);
+ firstAck.setInteger("logIndex", 1);
+ firstAck.setInteger("targetPid", 7);
+
+ protocol.onServerRecv(firstAck);
+
+ verify(mockProcess, never()).log(anyString());
+ assertEquals(1, ackPids.size());
+ assertTrue(ackPids.contains(3));
+ assertEquals(0, getIntField("commitIndex"));
+
+ VSMessage secondAck = new VSMessage();
+ secondAck.setString("type", "appendAck");
+ secondAck.setInteger("term", 0);
+ secondAck.setInteger("pid", 3);
+ secondAck.setInteger("logIndex", 1);
+ secondAck.setInteger("targetPid", 7);
+
+ protocol.onServerRecv(secondAck);
verify(mockProcess).log("Committed log index 1");
assertTrue(ackPids.isEmpty());
@@ -562,25 +606,35 @@ class VSRaftProtocolTest {
}
@Test
- void testAppendAckForDifferentLeaderTargetDoesNothing() throws Exception {
- setBooleanField("isLeader", true);
- @SuppressWarnings("unchecked")
- java.util.ArrayList<Integer> ackPids =
- (java.util.ArrayList<Integer>) getObjectField("ackPids");
- ackPids.clear();
- ackPids.add(2);
-
- VSMessage appendAck = new VSMessage();
- appendAck.setString("type", "appendAck");
- appendAck.setInteger("term", 1);
- appendAck.setInteger("pid", 2);
- appendAck.setInteger("logIndex", 1);
- appendAck.setInteger("targetPid", 99);
-
- protocol.onServerRecv(appendAck);
+ void testAppendAckWithWrongTermOrLogIndexDoesNotDrainLeaderQuorum()
+ throws Exception {
+ protocol.onStart();
+ clearInvocations(mockProcess, mockTaskManager);
+
+ ArrayList<Integer> ackPids = getAckPids();
+
+ VSMessage wrongTermAck = new VSMessage();
+ wrongTermAck.setString("type", "appendAck");
+ wrongTermAck.setInteger("term", -1);
+ wrongTermAck.setInteger("pid", 2);
+ wrongTermAck.setInteger("logIndex", 1);
+ wrongTermAck.setInteger("targetPid", 7);
+
+ protocol.onServerRecv(wrongTermAck);
+
+ VSMessage wrongIndexAck = new VSMessage();
+ wrongIndexAck.setString("type", "appendAck");
+ wrongIndexAck.setInteger("term", 0);
+ wrongIndexAck.setInteger("pid", 2);
+ wrongIndexAck.setInteger("logIndex", 2);
+ wrongIndexAck.setInteger("targetPid", 7);
+
+ protocol.onServerRecv(wrongIndexAck);
verify(mockProcess, never()).log(anyString());
- assertEquals(1, ackPids.size());
+ assertEquals(2, ackPids.size());
+ assertTrue(ackPids.contains(2));
+ assertTrue(ackPids.contains(3));
assertEquals(0, getIntField("commitIndex"));
}
@@ -616,6 +670,11 @@ class VSRaftProtocolTest {
return field.get(protocol);
}
+ @SuppressWarnings("unchecked")
+ private ArrayList<Integer> getAckPids() throws Exception {
+ return (ArrayList<Integer>) getObjectField("ackPids");
+ }
+
private boolean getBooleanField(String fieldName) throws Exception {
Field field = VSRaftProtocol.class.getDeclaredField(fieldName);
field.setAccessible(true);