package protocols.implementations;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import core.VSMessage;
import core.VSInternalProcess;
import protocols.VSAbstractProtocol;
/**
* Implementation of the Raft consensus algorithm.
*
* Raft is a consensus algorithm designed to be understandable. It ensures that
* a distributed system agrees on values even in the presence of failures.
*
* <p>
* The protocol has three states:
* <ul>
* <li>Follower - Passive state, responds to leaders</li>
* <li>Candidate - Actively requesting votes to become leader</li>
* <li>Leader - Manages the cluster and log replication</li>
* </ul>
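* <p>
* Key features implemented:
* <ul>
* <li>Leader election with randomized timeouts</li>
* <li>Log replication for state machine commands</li>
* <li>Safety through term numbers and log matching</li>
* <li>Membership changes are <em>not</em> implemented; the cluster is whatever
* {@code getNumProcesses()} reports</li>
* </ul>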
*
* @author Paul C. Buetow
*/
public class VSRaftProtocol extends VSAbstractProtocol {
    /** The three roles a Raft server can be in. */
    private enum State {
        FOLLOWER,
        CANDIDATE,
        LEADER
    }

    // Message types: the value stored under the "type" key of every message.
    private static final String MSG_REQUEST_VOTE = "REQUEST_VOTE";
    private static final String MSG_VOTE_RESPONSE = "VOTE_RESPONSE";
    private static final String MSG_APPEND_ENTRIES = "APPEND_ENTRIES";
    private static final String MSG_APPEND_RESPONSE = "APPEND_RESPONSE";
    private static final String MSG_CLIENT_REQUEST = "CLIENT_REQUEST";
    private static final String MSG_CLIENT_RESPONSE = "CLIENT_RESPONSE";

    // Timing constants (in simulation time units)
    private static final long HEARTBEAT_INTERVAL = 50;
    private static final long ELECTION_TIMEOUT_MIN = 150;
    private static final long ELECTION_TIMEOUT_MAX = 300;

    // Server state Raft treats as persistent (this simulation keeps it in memory only)
    private State currentState;
    private int currentTerm;
    /** Candidate voted for in {@code currentTerm}, or null if no vote was cast yet. */
    private Integer votedFor;
    /** Replicated log; index 0 holds a dummy entry so real entries start at index 1. */
    private List<LogEntry> log;

    // Server state (volatile)
    private int commitIndex;
    private int lastApplied;

    // Leader state (reinitialized after election), keyed by follower process number
    private Map<Integer, Integer> nextIndex;
    private Map<Integer, Integer> matchIndex;

    // Candidate state
    private Set<Integer> votesReceived;
    /** Absolute simulation time at which a follower/candidate starts a new election. */
    private long electionTimeout;

    // General state
    private Integer currentLeader;
    private long lastHeartbeat;

    // Client state
    private boolean clientHasScheduled = false;
    private int clientRequestCount = 0;

    /**
     * Log entry structure: the term in which the leader created the entry, the
     * client command, and the simulation time at which the leader received it.
     */
    private static class LogEntry {
        final int term;
        final String command;
        final long timestamp;

        LogEntry(int term, String command, long timestamp) {
            this.term = term;
            this.command = command;
            this.timestamp = timestamp;
        }

        @Override
        public String toString() {
            return String.format("LogEntry{term=%d, cmd='%s', time=%d}",
                    term, command, timestamp);
        }
    }

    public VSRaftProtocol() {
        super(VSAbstractProtocol.HAS_ON_SERVER_START);
    }

    /** Resets all server-side Raft state to that of a fresh follower in term 0. */
    @Override
    public void onServerInit() {
        currentState = State.FOLLOWER;
        currentTerm = 0;
        votedFor = null;
        log = new ArrayList<>();
        commitIndex = 0;
        lastApplied = 0;
        nextIndex = new ConcurrentHashMap<>();
        matchIndex = new ConcurrentHashMap<>();
        votesReceived = new HashSet<>();
        currentLeader = null;
        electionTimeout = 0;
        lastHeartbeat = 0;
        // Add a dummy entry at index 0 for easier indexing
        log.add(new LogEntry(0, "INIT", 0));
    }

    /** Arms the first election timeout. */
    @Override
    public void onServerStart() {
        resetElectionTimeout();
        raftLog("Raft node initialized as FOLLOWER");
        scheduleElectionTimeout();
    }

    /** Reinitializes server state, drops pending schedules and clears any leader highlight. */
    @Override
    public void onServerReset() {
        onServerInit();
        removeSchedules();
        // A server that was leader before the reset must not stay highlighted.
        if (process instanceof VSInternalProcess) {
            ((VSInternalProcess) process).highlightOff();
        }
    }

    /**
     * Dispatches an incoming message to the matching Raft handler.
     *
     * Client requests carry no "term" field, so they are dispatched before any
     * term bookkeeping. For server RPCs, a higher term in the message makes this
     * server adopt that term and step down to follower (Raft §5.1).
     */
    @Override
    public void onServerRecv(VSMessage message) {
        String msgType = message.getString("type");
        int senderId = message.getSendingProcess().getProcessNum();

        if (MSG_CLIENT_REQUEST.equals(msgType)) {
            handleClientRequest(message, senderId);
            return;
        }
        // Ignore non-RPC traffic (e.g. CLIENT_RESPONSE sent by another server) and
        // point-to-point RPCs that name a different receiver.
        if (!isServerRpc(msgType) || isAddressedToOtherProcess(message, msgType)) {
            return;
        }

        int term = message.getInteger("term");
        // If we receive a message with a higher term, become follower
        if (term > currentTerm) {
            currentTerm = term;
            votedFor = null;
            if (currentState != State.FOLLOWER) {
                becomeFollower();
            }
        }

        switch (msgType) {
            case MSG_REQUEST_VOTE:
                handleRequestVote(message, senderId);
                break;
            case MSG_VOTE_RESPONSE:
                handleVoteResponse(message, senderId);
                break;
            case MSG_APPEND_ENTRIES:
                handleAppendEntries(message, senderId);
                break;
            case MSG_APPEND_RESPONSE:
                handleAppendResponse(message, senderId);
                break;
            default:
                break;
        }
    }

    /**
     * Timer callback. A follower/candidate starts an election once its deadline has
     * passed; a leader sends heartbeats and reschedules itself.
     */
    @Override
    public void onServerSchedule() {
        long currentTime = process.getTime();
        switch (currentState) {
            case FOLLOWER:
            case CANDIDATE:
                if (currentTime >= electionTimeout) {
                    startElection();
                } else {
                    // The deadline was pushed back (heartbeat or granted vote) after this
                    // event was scheduled; re-arm, otherwise the timer would never fire again.
                    scheduleElectionTimeout();
                }
                break;
            case LEADER:
                sendHeartbeats();
                scheduleAt(currentTime + HEARTBEAT_INTERVAL);
                break;
        }
    }

    @Override
    public void onClientInit() {
        // Initialize client state
        clientHasScheduled = false;
        clientRequestCount = 0;
    }

    @Override
    public void onClientStart() {
        // This method is never called when using HAS_ON_SERVER_START
        // Clients will send requests in response to server heartbeats instead
    }

    @Override
    public void onClientReset() {
        removeSchedules();
        clientHasScheduled = false;
        clientRequestCount = 0;
    }

    /**
     * Logs responses addressed to this client. The first observed leader heartbeat
     * starts the client's request loop.
     */
    @Override
    public void onClientRecv(VSMessage message) {
        String msgType = message.getString("type");
        if (MSG_CLIENT_RESPONSE.equals(msgType)) {
            int receiver = message.getInteger("receiverNum");
            int self = process.getProcessNum();
            if (receiver != self) {
                return; // response meant for another client
            }
            boolean success = message.getBoolean("success");
            String result = message.getString("result");
            raftLog("Client received response: success=" + success + ", result=" + result);
        } else if (MSG_APPEND_ENTRIES.equals(msgType)) {
            // Client receives heartbeat from leader - good time to send a request
            if (!clientHasScheduled) {
                clientHasScheduled = true;
                // Schedule first client request after a short delay
                scheduleAt(process.getTime() + 100);
            }
        }
    }

    /** Sends one test command and schedules the next, up to 10 requests in total. */
    @Override
    public void onClientSchedule() {
        VSMessage request = new VSMessage();
        request.setString("type", MSG_CLIENT_REQUEST);
        request.setString("command", "SET x=" + process.getRandomPercentage());
        request.setLong("clientId", process.getProcessNum());
        // Per-client sequence number instead of wall-clock time, so simulation runs
        // stay reproducible and ids cannot collide within the same millisecond.
        request.setLong("requestId", clientRequestCount);
        sendMessage(request);
        raftLog("Client sent request #" + clientRequestCount + ": " + request.getString("command"));
        clientRequestCount++;
        if (clientRequestCount < 10) { // Limit number of requests for testing
            scheduleAt(process.getTime() + 1000 + process.getRandomPercentage() * 10);
        }
    }

    // --- Raft Algorithm Implementation ---

    /** Becomes candidate for the next term, votes for itself and broadcasts RequestVote. */
    private void startElection() {
        currentState = State.CANDIDATE;
        currentTerm++;
        votedFor = process.getProcessNum();
        votesReceived.clear();
        votesReceived.add(process.getProcessNum()); // Vote for self
        raftLog("Starting election for term " + currentTerm);

        VSMessage voteRequest = new VSMessage();
        voteRequest.setString("type", MSG_REQUEST_VOTE);
        voteRequest.setInteger("term", currentTerm);
        voteRequest.setInteger("candidateId", process.getProcessNum());
        voteRequest.setInteger("lastLogIndex", log.size() - 1);
        voteRequest.setInteger("lastLogTerm", log.get(log.size() - 1).term);
        sendMessage(voteRequest);

        resetElectionTimeout();
        scheduleElectionTimeout();
    }

    /**
     * Grants the vote if the candidate's term is current, we have not voted for
     * someone else this term, and the candidate's log is at least as up-to-date
     * as ours. Always replies to the candidate.
     */
    private void handleRequestVote(VSMessage message, int candidateId) {
        int term = message.getInteger("term");
        int lastLogIndex = message.getInteger("lastLogIndex");
        int lastLogTerm = message.getInteger("lastLogTerm");
        boolean voteGranted = false;
        if (term >= currentTerm &&
                (votedFor == null || votedFor == candidateId) &&
                isLogUpToDate(lastLogIndex, lastLogTerm)) {
            votedFor = candidateId;
            voteGranted = true;
            resetElectionTimeout();
            raftLog("Voted for candidate " + candidateId + " in term " + term);
        }

        VSMessage response = new VSMessage();
        response.setString("type", MSG_VOTE_RESPONSE);
        response.setInteger("term", currentTerm);
        response.setBoolean("voteGranted", voteGranted);
        response.setInteger("senderId", process.getProcessNum());
        response.setInteger("receiverNum", candidateId);
        sendMessage(response);
    }

    /** Counts a granted vote for the current election and becomes leader on a majority. */
    private void handleVoteResponse(VSMessage message, int senderId) {
        // Votes from an election we already abandoned must not count.
        int responseTerm = message.getInteger("term");
        if (currentState != State.CANDIDATE || responseTerm != currentTerm) {
            return;
        }
        boolean voteGranted = message.getBoolean("voteGranted");
        if (voteGranted) {
            votesReceived.add(senderId);
            raftLog("Received vote from " + senderId + " (total: " + votesReceived.size() + ")");
            int majority = (getNumProcesses() / 2) + 1;
            if (votesReceived.size() >= majority) {
                becomeLeader();
            }
        }
    }

    /** Initializes leader bookkeeping, sends heartbeats and starts the heartbeat timer. */
    private void becomeLeader() {
        currentState = State.LEADER;
        currentLeader = process.getProcessNum();
        raftLog("Became LEADER for term " + currentTerm);

        nextIndex.clear();
        matchIndex.clear();
        for (int i = 0; i < getNumProcesses(); i++) {
            if (i != process.getProcessNum()) {
                nextIndex.put(i, log.size());
                matchIndex.put(i, 0);
            }
        }

        sendHeartbeats();
        removeSchedules();
        scheduleAt(process.getTime() + HEARTBEAT_INTERVAL);

        // Highlight the leader visually
        if (process instanceof VSInternalProcess) {
            ((VSInternalProcess) process).highlightOn();
        }
    }

    /** Reverts to follower, clears the leader highlight and re-arms the election timer. */
    private void becomeFollower() {
        currentState = State.FOLLOWER;
        raftLog("Became FOLLOWER for term " + currentTerm);
        if (process instanceof VSInternalProcess) {
            ((VSInternalProcess) process).highlightOff();
        }
        resetElectionTimeout();
        scheduleElectionTimeout();
    }

    /** Sends AppendEntries (possibly empty, i.e. a heartbeat) to every other process. */
    private void sendHeartbeats() {
        for (int i = 0; i < getNumProcesses(); i++) {
            if (i != process.getProcessNum()) {
                sendAppendEntries(i);
            }
        }
    }

    /**
     * Sends {@code followerId} an AppendEntries RPC carrying every log entry from its
     * nextIndex onward. Entries are flattened into "entry_i_*" fields because the
     * message only stores scalar values.
     */
    private void sendAppendEntries(int followerId) {
        int nextIdx = nextIndex.getOrDefault(followerId, 1);
        int prevLogIndex = nextIdx - 1;
        int prevLogTerm = prevLogIndex >= 0 && prevLogIndex < log.size() ? log.get(prevLogIndex).term : 0;

        VSMessage appendEntries = new VSMessage();
        appendEntries.setString("type", MSG_APPEND_ENTRIES);
        appendEntries.setInteger("term", currentTerm);
        appendEntries.setInteger("leaderId", process.getProcessNum());
        appendEntries.setInteger("prevLogIndex", prevLogIndex);
        appendEntries.setInteger("prevLogTerm", prevLogTerm);
        appendEntries.setInteger("leaderCommit", commitIndex);

        int entryCount = Math.max(0, log.size() - nextIdx);
        appendEntries.setInteger("entryCount", entryCount);
        for (int i = 0; i < entryCount; i++) {
            LogEntry entry = log.get(nextIdx + i);
            appendEntries.setInteger("entry_" + i + "_term", entry.term);
            appendEntries.setString("entry_" + i + "_cmd", entry.command);
            appendEntries.setLong("entry_" + i + "_time", entry.timestamp);
        }
        appendEntries.setInteger("receiverNum", followerId);
        sendMessage(appendEntries);
    }

    /**
     * Handles AppendEntries following Raft Figure 2.
     *
     * A leader from an older term is rejected without touching the timer or log.
     * Otherwise the election timer is reset and a same-term candidate steps down.
     * If the log matches at prevLogIndex, the carried entries are merged. The commit
     * index then advances to min(leaderCommit, index of last new entry).
     */
    private void handleAppendEntries(VSMessage message, int leaderId) {
        int term = message.getInteger("term");
        int prevLogIndex = message.getInteger("prevLogIndex");
        int prevLogTerm = message.getInteger("prevLogTerm");
        int leaderCommit = message.getInteger("leaderCommit");
        boolean success = false;
        // On failure the leader ignores this value; report our last index as before.
        int reportedMatchIndex = log.size() - 1;

        if (term >= currentTerm) {
            // A legitimate leader exists for our term.
            if (currentState == State.CANDIDATE) {
                becomeFollower();
            }
            resetElectionTimeout();
            lastHeartbeat = process.getTime();
            currentLeader = leaderId;

            boolean logMatches = prevLogIndex == 0
                    || (prevLogIndex > 0 && prevLogIndex < log.size()
                        && log.get(prevLogIndex).term == prevLogTerm);
            if (logMatches) {
                success = true;
                int entryCount = message.getInteger("entryCount");
                mergeEntries(message, prevLogIndex, entryCount);
                reportedMatchIndex = prevLogIndex + entryCount;
                int newCommit = Math.min(leaderCommit, reportedMatchIndex);
                if (newCommit > commitIndex) {
                    commitIndex = newCommit;
                    applyStateMachine();
                }
            }
        }

        VSMessage response = new VSMessage();
        response.setString("type", MSG_APPEND_RESPONSE);
        response.setInteger("term", currentTerm);
        response.setBoolean("success", success);
        response.setInteger("senderId", process.getProcessNum());
        response.setInteger("matchIndex", reportedMatchIndex);
        response.setInteger("receiverNum", leaderId);
        sendMessage(response);
    }

    /**
     * Merges the entries carried by an AppendEntries message into the log after
     * {@code prevLogIndex}. Entries that already agree (same index and term) are kept.
     * The first conflicting entry and everything after it are deleted before the
     * remaining entries are appended. A stale or reordered RPC therefore cannot
     * erase entries it does not contradict.
     */
    private void mergeEntries(VSMessage message, int prevLogIndex, int entryCount) {
        for (int i = 0; i < entryCount; i++) {
            int index = prevLogIndex + 1 + i;
            int entryTerm = message.getInteger("entry_" + i + "_term");
            if (index < log.size()) {
                if (log.get(index).term == entryTerm) {
                    continue; // already present
                }
                log.subList(index, log.size()).clear(); // conflict: drop suffix
            }
            String entryCmd = message.getString("entry_" + i + "_cmd");
            long entryTime = message.getLong("entry_" + i + "_time");
            log.add(new LogEntry(entryTerm, entryCmd, entryTime));
            raftLog("Appended log entry: " + entryCmd);
        }
    }

    /**
     * Leader side of AppendEntries. On success, records the follower's match index
     * and tries to advance the commit index. On failure, backs nextIndex off by one
     * so the next heartbeat retries from an earlier point.
     */
    private void handleAppendResponse(VSMessage message, int followerId) {
        // Responses to RPCs sent in an earlier term are stale.
        int responseTerm = message.getInteger("term");
        if (currentState != State.LEADER || responseTerm != currentTerm) {
            return;
        }
        boolean success = message.getBoolean("success");
        if (success) {
            int matchIdx = message.getInteger("matchIndex");
            // Responses can arrive out of order; never move matchIndex backwards.
            int newMatch = Math.max(matchIndex.getOrDefault(followerId, 0), matchIdx);
            matchIndex.put(followerId, newMatch);
            nextIndex.put(followerId, newMatch + 1);
            updateCommitIndex();
        } else {
            int next = nextIndex.getOrDefault(followerId, 1);
            if (next > 1) {
                nextIndex.put(followerId, next - 1);
            }
        }
    }

    /**
     * A leader appends the command to its log and replies immediately. The reply
     * is optimistic: it is sent before the entry is replicated or committed.
     * Non-leaders reject the request and name the last known leader.
     */
    private void handleClientRequest(VSMessage message, int clientId) {
        if (currentState != State.LEADER) {
            VSMessage response = new VSMessage();
            response.setString("type", MSG_CLIENT_RESPONSE);
            response.setBoolean("success", false);
            response.setString("result", "Not leader. Leader is: " + currentLeader);
            response.setInteger("receiverNum", clientId);
            sendMessage(response);
            return;
        }
        String command = message.getString("command");
        LogEntry entry = new LogEntry(currentTerm, command, process.getTime());
        log.add(entry);
        raftLog("Leader received client request: " + command);

        VSMessage response = new VSMessage();
        response.setString("type", MSG_CLIENT_RESPONSE);
        response.setBoolean("success", true);
        response.setString("result", "Command logged: " + command);
        response.setInteger("receiverNum", clientId);
        sendMessage(response);
    }

    // --- Helper Methods ---

    /** True for the server-to-server RPC types handled by {@link #onServerRecv}. */
    private static boolean isServerRpc(String msgType) {
        return MSG_REQUEST_VOTE.equals(msgType)
                || MSG_VOTE_RESPONSE.equals(msgType)
                || MSG_APPEND_ENTRIES.equals(msgType)
                || MSG_APPEND_RESPONSE.equals(msgType);
    }

    /**
     * True if a point-to-point RPC names a receiver other than this process.
     *
     * NOTE(review): this assumes sendMessage() may deliver to every process, with the
     * intended receiver carried only in "receiverNum". If the simulator already filters
     * by receiver, this check is a harmless no-op. Only types that always set
     * receiverNum in this class are inspected.
     */
    private boolean isAddressedToOtherProcess(VSMessage message, String msgType) {
        if (MSG_REQUEST_VOTE.equals(msgType)) {
            return false; // broadcast, no receiverNum
        }
        int receiver = message.getInteger("receiverNum");
        int self = process.getProcessNum();
        return receiver != self;
    }

    /** Raft §5.4.1: compare last terms first, then log lengths. */
    private boolean isLogUpToDate(int lastLogIndex, int lastLogTerm) {
        int ourLastIndex = log.size() - 1;
        int ourLastTerm = log.get(ourLastIndex).term;
        return lastLogTerm > ourLastTerm ||
                (lastLogTerm == ourLastTerm && lastLogIndex >= ourLastIndex);
    }

    /** Picks a new random deadline in [now+MIN, now+MAX) simulation time units. */
    private void resetElectionTimeout() {
        if (process != null) {
            long timeout = ELECTION_TIMEOUT_MIN +
                    (long) (Math.random() * (ELECTION_TIMEOUT_MAX - ELECTION_TIMEOUT_MIN));
            electionTimeout = process.getTime() + timeout;
        }
    }

    /** Replaces all pending schedules with a single event at the election deadline. */
    private void scheduleElectionTimeout() {
        removeSchedules();
        scheduleAt(electionTimeout);
    }

    /**
     * Advances commitIndex to the highest current-term index replicated on a
     * majority (the leader counts itself). Only current-term entries are committed
     * by counting (Raft §5.4.2).
     */
    private void updateCommitIndex() {
        for (int n = log.size() - 1; n > commitIndex; n--) {
            if (log.get(n).term == currentTerm) {
                int replicatedCount = 1; // Leader has it
                for (int matchIdx : matchIndex.values()) {
                    if (matchIdx >= n) {
                        replicatedCount++;
                    }
                }
                if (replicatedCount > getNumProcesses() / 2) {
                    commitIndex = n;
                    applyStateMachine();
                    break;
                }
            }
        }
    }

    /** "Applies" (logs) every committed entry not yet applied, in index order. */
    private void applyStateMachine() {
        while (lastApplied < commitIndex) {
            lastApplied++;
            LogEntry entry = log.get(lastApplied);
            raftLog("Applied to state machine: " + entry.command);
        }
    }

    /** Logs with a "[STATE T:term N:process]" prefix; "CLIENT" when no server state exists. */
    private void raftLog(String message) {
        String stateStr = currentState != null ? currentState.toString() : "CLIENT";
        String prefix = String.format("[%s T:%d N:%d] ",
                stateStr, currentTerm, process.getProcessNum());
        process.log(prefix + message);
    }

    @Override
    public String toString() {
        return super.toString() + " - Raft Consensus";
    }
}