From f12114c1d9ec50f20f3df3e9c6e335e00f186c10 Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Fri, 27 Mar 2026 06:19:38 +0200 Subject: Fix final Raft append review issues b85586a4-4eb9-4686-93c7-0ab14173baa5 --- .../protocols/implementations/VSRaftProtocol.java | 12 +- .../implementations/VSRaftProtocolTest.java | 137 ++++++++++++++++----- 2 files changed, 116 insertions(+), 33 deletions(-) diff --git a/src/main/java/protocols/implementations/VSRaftProtocol.java b/src/main/java/protocols/implementations/VSRaftProtocol.java index 8c36b68..d0066e0 100644 --- a/src/main/java/protocols/implementations/VSRaftProtocol.java +++ b/src/main/java/protocols/implementations/VSRaftProtocol.java @@ -387,11 +387,6 @@ public class VSRaftProtocol extends VSAbstractProtocol { if (messageTerm > currentTerm) { becomeFollower(messageTerm, messageLeaderId); - } else if (messageTerm == currentTerm) { - leaderId = messageLeaderId; - isLeader = false; - isCandidate = false; - resetElectionTimeout(); } else { return; } @@ -400,6 +395,13 @@ public class VSRaftProtocol extends VSAbstractProtocol { return; } + if (messageTerm == currentTerm) { + leaderId = messageLeaderId; + isLeader = false; + isCandidate = false; + resetElectionTimeout(); + } + logIndex = messageLogIndex; VSMessage appendAck = new VSMessage(); diff --git a/src/test/java/protocols/implementations/VSRaftProtocolTest.java b/src/test/java/protocols/implementations/VSRaftProtocolTest.java index 49f87f4..c9440b1 100644 --- a/src/test/java/protocols/implementations/VSRaftProtocolTest.java +++ b/src/test/java/protocols/implementations/VSRaftProtocolTest.java @@ -24,6 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -524,8 +525,8 @@ class VSRaftProtocolTest { protocol.onClientRecv(appendEntry); verify(mockProcess).sendMessage(messageCaptor.capture()); - verify(mockTaskManager, times(2)).removeAllTasks(any()); - verify(mockTaskManager).addTask(taskCaptor.capture()); + verify(mockTaskManager, times(3)).removeAllTasks(any()); + verify(mockTaskManager, times(2)).addTask(taskCaptor.capture()); VSMessage appendAck = messageCaptor.getValue(); assertEquals("appendAck", appendAck.getString("type")); @@ -548,6 +549,9 @@ class VSRaftProtocolTest { protocol.onClientInit(); clearInvocations(mockProcess, mockTaskManager); setIntField("currentTerm", 2); + setIntField("leaderId", 5); + setBooleanField("isCandidate", true); + long electionDeadline = getLongField("electionDeadline"); when(mockProcess.getTime()).thenReturn(600L); VSMessage appendEntry = new VSMessage(); @@ -560,49 +564,63 @@ class VSRaftProtocolTest { protocol.onClientRecv(appendEntry); verify(mockProcess, never()).sendMessage(any()); - verify(mockTaskManager).removeAllTasks(any()); - verify(mockTaskManager).addTask(any()); + verify(mockTaskManager, never()).removeAllTasks(any()); + verify(mockTaskManager, never()).addTask(any()); assertEquals(0, getIntField("logIndex")); assertEquals(2, getIntField("currentTerm")); - assertEquals(11, getIntField("leaderId")); + assertEquals(5, getIntField("leaderId")); + assertTrue(getBooleanField("isCandidate")); + assertEquals(electionDeadline, getLongField("electionDeadline")); } @Test - void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerAcks() + void testLeaderAppendQuorumStateDrainsAndCommitsAfterFollowerRoundTrips() throws Exception { - protocol.onStart(); - clearInvocations(mockProcess, mockTaskManager); + LeaderHarness leaderHarness = createLeaderHarness(11, 300L); + leaderHarness.protocol.onStart(); - ArrayList ackPids = getAckPids(); + ArrayList sentMessages = leaderHarness.protocol.getSentMessages(); + assertEquals(2, sentMessages.size()); + VSMessage appendEntry = sentMessages.get(1); + ArrayList ackPids = getAckPids(leaderHarness.protocol); assertEquals(2, ackPids.size()); - assertEquals(1, getIntField("logIndex")); + assertEquals(1, getIntField(leaderHarness.protocol, "logIndex")); - VSMessage firstAck = new VSMessage(); - firstAck.setString("type", "appendAck"); - firstAck.setInteger("term", 0); - firstAck.setInteger("pid", 2); - firstAck.setInteger("logIndex", 1); - firstAck.setInteger("targetPid", 7); + protocol.currentContextIsServer(false); + protocol.onClientInit(); + setIntField("currentTerm", -1); + clearInvocations(mockProcess, mockTaskManager); + when(mockProcess.getProcessID()).thenReturn(2); + when(mockProcess.getTime()).thenReturn(700L, 700L); - protocol.onServerRecv(firstAck); + ArgumentCaptor followerAckCaptor = + ArgumentCaptor.forClass(VSMessage.class); + protocol.onClientRecv(appendEntry); + verify(mockProcess).sendMessage(followerAckCaptor.capture()); - verify(mockProcess, never()).log(anyString()); + leaderHarness.protocol.onServerRecv(followerAckCaptor.getValue()); + + verify(leaderHarness.process, never()).log(anyString()); assertEquals(1, ackPids.size()); assertTrue(ackPids.contains(3)); - assertEquals(0, getIntField("commitIndex")); + assertEquals(0, getIntField(leaderHarness.protocol, "commitIndex")); + + protocol.onClientReset(); + protocol.currentContextIsServer(false); + protocol.onClientInit(); + setIntField("currentTerm", -1); + clearInvocations(mockProcess, mockTaskManager); + when(mockProcess.getProcessID()).thenReturn(3); + when(mockProcess.getTime()).thenReturn(800L, 800L); - VSMessage secondAck = new VSMessage(); - secondAck.setString("type", "appendAck"); - secondAck.setInteger("term", 0); - secondAck.setInteger("pid", 3); - secondAck.setInteger("logIndex", 1); - secondAck.setInteger("targetPid", 7); + protocol.onClientRecv(appendEntry); + verify(mockProcess).sendMessage(followerAckCaptor.capture()); - protocol.onServerRecv(secondAck); + leaderHarness.protocol.onServerRecv(followerAckCaptor.getAllValues().get(1)); - verify(mockProcess).log("Committed log index 1"); + verify(leaderHarness.process).log("Committed log index 1"); assertTrue(ackPids.isEmpty()); - assertEquals(1, getIntField("commitIndex")); + assertEquals(1, getIntField(leaderHarness.protocol, "commitIndex")); } @Test @@ -686,4 +704,67 @@ class VSRaftProtocolTest { field.setAccessible(true); return field.getLong(protocol); } + + private LeaderHarness createLeaderHarness(int pid, long time) { + CapturingRaftProtocol peerProtocol = new CapturingRaftProtocol(); + VSInternalProcess peerProcess = mock(VSInternalProcess.class); + VSSimulatorVisualization peerCanvas = mock(VSSimulatorVisualization.class); + VSTaskManager peerTaskManager = mock(VSTaskManager.class); + VSPrefs peerPrefs = mock(VSPrefs.class); + VSVectorTime peerVectorTime = mock(VSVectorTime.class); + + peerProtocol.process = peerProcess; + peerProtocol.prefs = peerPrefs; + peerProtocol.isServer(true); + peerProtocol.currentContextIsServer(true); + + when(peerProcess.getSimulatorCanvas()).thenReturn(peerCanvas); + when(peerCanvas.getTaskManager()).thenReturn(peerTaskManager); + when(peerCanvas.getNumProcesses()).thenReturn(3); + when(peerProcess.getPrefs()).thenReturn(peerPrefs); + when(peerProcess.getVectorTime()).thenReturn(peerVectorTime); + when(peerVectorTime.getCopy()).thenReturn(peerVectorTime); + when(peerPrefs.getString(anyString())).thenReturn("TestString"); + when(peerProcess.getTime()).thenReturn(time); + when(peerProcess.getProcessID()).thenReturn(pid); + when(peerProcess.getRandomPercentage()).thenReturn(25); + + peerProtocol.onServerInit(); + clearInvocations(peerProcess, peerTaskManager); + return new LeaderHarness(peerProtocol, peerProcess); + } + + private static final class CapturingRaftProtocol extends VSRaftProtocol { + private ArrayList sentMessages = new ArrayList(); + + @Override + public void sendMessage(VSMessage message) { + sentMessages.add(message); + } + + private ArrayList getSentMessages() { + return sentMessages; + } + } + + private int getIntField(VSRaftProtocol raftProtocol, String fieldName) + throws Exception { + Field field = VSRaftProtocol.class.getDeclaredField(fieldName); + field.setAccessible(true); + return field.getInt(raftProtocol); + } + + @SuppressWarnings("unchecked") + private ArrayList getAckPids(VSRaftProtocol raftProtocol) + throws Exception { + Field field = VSRaftProtocol.class.getDeclaredField("ackPids"); + field.setAccessible(true); + return (ArrayList) field.get(raftProtocol); + } + + private record LeaderHarness( + CapturingRaftProtocol protocol, + VSInternalProcess process + ) { + } } -- cgit v1.2.3