diff options
| -rw-r--r-- | src/main/java/protocols/implementations/VSRaftProtocol.java | 23 | ||||
| -rw-r--r-- | src/test/java/protocols/implementations/VSRaftProtocolTest.java | 129 |
2 files changed, 113 insertions, 39 deletions
diff --git a/src/main/java/protocols/implementations/VSRaftProtocol.java b/src/main/java/protocols/implementations/VSRaftProtocol.java index 983c8d3..8c36b68 100644 --- a/src/main/java/protocols/implementations/VSRaftProtocol.java +++ b/src/main/java/protocols/implementations/VSRaftProtocol.java @@ -277,6 +277,8 @@ public class VSRaftProtocol extends VSAbstractProtocol { * Sends a simplified append-entry request for the configured log entry. */ private void sendAppendEntry() { + ackPids.clear(); + if (getVectorKeySet().contains("pids")) { ackPids.addAll(getVector("pids")); } @@ -381,6 +383,7 @@ public class VSRaftProtocol extends VSAbstractProtocol { private void handleAppendEntry(VSMessage recvMessage) { int messageTerm = recvMessage.getInteger("term"); int messageLeaderId = recvMessage.getInteger("leaderId"); + int messageLogIndex = recvMessage.getInteger("logIndex"); if (messageTerm > currentTerm) { becomeFollower(messageTerm, messageLeaderId); @@ -393,13 +396,17 @@ public class VSRaftProtocol extends VSAbstractProtocol { return; } - logIndex++; + if (messageLogIndex != logIndex + 1) { + return; + } + + logIndex = messageLogIndex; VSMessage appendAck = new VSMessage(); appendAck.setString("type", "appendAck"); appendAck.setInteger("term", currentTerm); appendAck.setInteger("pid", process.getProcessID()); - appendAck.setInteger("logIndex", logIndex); + appendAck.setInteger("logIndex", messageLogIndex); appendAck.setInteger("targetPid", messageLeaderId); sendMessage(appendAck); } @@ -410,17 +417,25 @@ public class VSRaftProtocol extends VSAbstractProtocol { * @param recvMessage the append acknowledgement */ private void handleAppendAck(VSMessage recvMessage) { + int messageTerm = recvMessage.getInteger("term"); Integer responderPid = recvMessage.getIntegerObj("pid"); + int ackLogIndex = recvMessage.getInteger("logIndex"); + + if (messageTerm > currentTerm) { + becomeFollower(messageTerm, -1); + return; + } if (!isLeader || !isForMe(recvMessage) || responderPid == null || + messageTerm != currentTerm || ackLogIndex != logIndex || !ackPids.contains(responderPid)) { return; } ackPids.remove(responderPid); - if (ackPids.isEmpty()) { - commitIndex++; + if (ackPids.isEmpty() && commitIndex < ackLogIndex) { + commitIndex = ackLogIndex; log("Committed log index " + commitIndex); } } diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java index f3dc0d6..49f87f4 100644 --- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java +++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java @@ -16,6 +16,7 @@ import simulator.VSSimulatorVisualization; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.util.ArrayList; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -71,7 +72,7 @@ class VSRaftProtocolTest { } @Test - void testOnStartBecomesLeaderAndSendsHeartbeat() { + void testOnStartBecomesLeaderAndSendsHeartbeat() throws Exception { ArgumentCaptor<VSMessage> messageCaptor = ArgumentCaptor.forClass(VSMessage.class); ArgumentCaptor<VSTask> taskCaptor = ArgumentCaptor.forClass(VSTask.class); @@ -91,6 +92,9 @@ class VSRaftProtocolTest { assertEquals(7, appendEntry.getInteger("leaderId")); assertEquals("cmd1", appendEntry.getString("entry")); assertEquals(1, appendEntry.getInteger("logIndex")); + assertEquals(2, getAckPids().size()); + assertTrue(getAckPids().contains(2)); + assertTrue(getAckPids().contains(3)); assertEquals(1600L, taskCaptor.getValue().getTaskTime()); } @@ -538,23 +542,63 @@ class VSRaftProtocolTest { } @Test - void testAppendAckForLeaderCommitsOnceAllFollowersAck() throws Exception { - setBooleanField("isLeader", true); - setIntField("logIndex", 1); - @SuppressWarnings("unchecked") - java.util.ArrayList<Integer> ackPids = - (java.util.ArrayList<Integer>) getObjectField("ackPids"); - ackPids.clear(); - ackPids.add(2); - - VSMessage appendAck = new VSMessage(); - appendAck.setString("type", "appendAck"); - appendAck.setInteger("term", 1); - appendAck.setInteger("pid", 2); - appendAck.setInteger("logIndex", 1); - appendAck.setInteger("targetPid", 7); - - protocol.onServerRecv(appendAck); + void testAppendEntryOutOfSyncDoesNotAdvanceFollowerLogOrSendAck() + throws Exception { + protocol.currentContextIsServer(false); + protocol.onClientInit(); + clearInvocations(mockProcess, mockTaskManager); + setIntField("currentTerm", 2); + when(mockProcess.getTime()).thenReturn(600L); + + VSMessage appendEntry = new VSMessage(); + appendEntry.setString("type", "appendEntry"); + appendEntry.setInteger("term", 2); + appendEntry.setInteger("leaderId", 11); + appendEntry.setString("entry", "cmd2"); + appendEntry.setInteger("logIndex", 2); + + protocol.onClientRecv(appendEntry); + + verify(mockProcess, never()).sendMessage(any()); + verify(mockTaskManager).removeAllTasks(any()); + verify(mockTaskManager).addTask(any()); + assertEquals(0, getIntField("logIndex")); + assertEquals(2, getIntField("currentTerm")); + assertEquals(11, getIntField("leaderId")); + } + + @Test + void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerAcks() + throws Exception { + protocol.onStart(); + clearInvocations(mockProcess, mockTaskManager); + + ArrayList<Integer> ackPids = getAckPids(); + assertEquals(2, ackPids.size()); + assertEquals(1, getIntField("logIndex")); + + VSMessage firstAck = new VSMessage(); + firstAck.setString("type", "appendAck"); + firstAck.setInteger("term", 0); + firstAck.setInteger("pid", 2); + firstAck.setInteger("logIndex", 1); + firstAck.setInteger("targetPid", 7); + + protocol.onServerRecv(firstAck); + + verify(mockProcess, never()).log(anyString()); + assertEquals(1, ackPids.size()); + assertTrue(ackPids.contains(3)); + assertEquals(0, getIntField("commitIndex")); + + VSMessage secondAck = new VSMessage(); + secondAck.setString("type", "appendAck"); + secondAck.setInteger("term", 0); + secondAck.setInteger("pid", 3); + secondAck.setInteger("logIndex", 1); + secondAck.setInteger("targetPid", 7); + + protocol.onServerRecv(secondAck); verify(mockProcess).log("Committed log index 1"); assertTrue(ackPids.isEmpty()); @@ -562,25 +606,35 @@ class VSRaftProtocolTest { } @Test - void testAppendAckForDifferentLeaderTargetDoesNothing() throws Exception { - setBooleanField("isLeader", true); - @SuppressWarnings("unchecked") - java.util.ArrayList<Integer> ackPids = - (java.util.ArrayList<Integer>) getObjectField("ackPids"); - ackPids.clear(); - ackPids.add(2); - - VSMessage appendAck = new VSMessage(); - appendAck.setString("type", "appendAck"); - appendAck.setInteger("term", 1); - appendAck.setInteger("pid", 2); - appendAck.setInteger("logIndex", 1); - appendAck.setInteger("targetPid", 99); - - protocol.onServerRecv(appendAck); + void testAppendAckWithWrongTermOrLogIndexDoesNotDrainLeaderQuorum() + throws Exception { + protocol.onStart(); + clearInvocations(mockProcess, mockTaskManager); + + ArrayList<Integer> ackPids = getAckPids(); + + VSMessage wrongTermAck = new VSMessage(); + wrongTermAck.setString("type", "appendAck"); + wrongTermAck.setInteger("term", -1); + wrongTermAck.setInteger("pid", 2); + wrongTermAck.setInteger("logIndex", 1); + wrongTermAck.setInteger("targetPid", 7); + + protocol.onServerRecv(wrongTermAck); + + VSMessage wrongIndexAck = new VSMessage(); + wrongIndexAck.setString("type", "appendAck"); + wrongIndexAck.setInteger("term", 0); + wrongIndexAck.setInteger("pid", 2); + wrongIndexAck.setInteger("logIndex", 2); + wrongIndexAck.setInteger("targetPid", 7); + + protocol.onServerRecv(wrongIndexAck); verify(mockProcess, never()).log(anyString()); - assertEquals(1, ackPids.size()); + assertEquals(2, ackPids.size()); + assertTrue(ackPids.contains(2)); + assertTrue(ackPids.contains(3)); assertEquals(0, getIntField("commitIndex")); } @@ -616,6 +670,11 @@ class VSRaftProtocolTest { return field.get(protocol); } + @SuppressWarnings("unchecked") + private ArrayList<Integer> getAckPids() throws Exception { + return (ArrayList<Integer>) getObjectField("ackPids"); + } + private boolean getBooleanField(String fieldName) throws Exception { Field field = VSRaftProtocol.class.getDeclaredField(fieldName); field.setAccessible(true); |
