Skip to content

Commit

Permalink
refactor: change to RefBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
areyouok committed Jun 17, 2024
1 parent a6aaa99 commit 798f046
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,29 @@ public KvSnapshot(SnapshotInfo si, Supplier<KvStatus> statusSupplier, Consumer<S
}

@Override
public FiberFuture<Void> readNext(ByteBuffer buffer) {
public FiberFuture<Integer> readNext(ByteBuffer buffer) {
FiberGroup fiberGroup = FiberGroup.currentGroup();
KvStatus current = statusSupplier.get();
if (current.status != KvStatus.RUNNING || current.epoch != epoch) {
return FiberFuture.failedFuture(fiberGroup, new RaftException("the snapshot is expired"));
}

int startPos = buffer.position();
while (true) {
if (currentValue == null) {
nextValue();
}
if (currentValue == null) {
// no more data
buffer.flip();
return FiberFuture.completedFuture(fiberGroup, null);
return FiberFuture.completedFuture(fiberGroup, buffer.position() - startPos);
}

if (encodeStatus.writeToBuffer(buffer)) {
encodeStatus.reset();
currentValue = null;
} else {
// buffer is full
buffer.flip();
return FiberFuture.completedFuture(fiberGroup, null);
return FiberFuture.completedFuture(fiberGroup, buffer.position() - startPos);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
*/
package com.github.dtprj.dongting.raft.impl;

import com.github.dtprj.dongting.buf.ByteBufferPool;
import com.github.dtprj.dongting.buf.RefBuffer;
import com.github.dtprj.dongting.buf.RefBufferFactory;
import com.github.dtprj.dongting.codec.PbNoCopyDecoder;
import com.github.dtprj.dongting.common.DtTime;
import com.github.dtprj.dongting.common.DtUtil;
Expand Down Expand Up @@ -50,12 +51,10 @@
import com.github.dtprj.dongting.raft.store.RaftLog;
import com.github.dtprj.dongting.raft.store.StatusManager;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
Expand Down Expand Up @@ -571,7 +570,7 @@ class LeaderInstallFrame extends AbstractLeaderRepFrame {
private final RaftServerConfig serverConfig;
private final NioClient client;
private final ReplicateManager replicateManager;
private final ByteBufferPool heapPool;
private final RefBufferFactory heapPool;

private Snapshot snapshot;
private long nextPosAfterInstallFinish;
Expand All @@ -585,7 +584,7 @@ public LeaderInstallFrame(ReplicateManager replicateManager, RaftMember member)
this.serverConfig = replicateManager.serverConfig;
this.client = replicateManager.client;
this.replicateManager = replicateManager;
this.heapPool = groupConfig.getFiberGroup().getThread().getHeapPool().getPool();
this.heapPool = groupConfig.getFiberGroup().getThread().getHeapPool();
}

@Override
Expand Down Expand Up @@ -634,17 +633,18 @@ private FrameCallResult afterFirstReqFinished(Void unused) {
if (shouldStopReplicate()) {
return Fiber.frameReturn();
}
Supplier<ByteBuffer> bufferCreator = () -> heapPool.borrow(groupConfig.getReplicateSnapshotBufferSize());
Consumer<ByteBuffer> releaser = heapPool::release;
Supplier<RefBuffer> bufferCreator = () -> heapPool.create(groupConfig.getReplicateSnapshotBufferSize());

int readConcurrency = groupConfig.getSnapshotConcurrency();
int writeConcurrency = groupConfig.getReplicateSnapshotConcurrency();
SnapshotReader r = new SnapshotReader(snapshot, readConcurrency, writeConcurrency, this::readerCallback,
this::shouldStopReplicate, bufferCreator, releaser);
this::shouldStopReplicate, bufferCreator);
return Fiber.call(r, this::afterReaderFinish);
}

private FiberFuture<Void> readerCallback(ByteBuffer buf) {
private FiberFuture<Void> readerCallback(RefBuffer buf, Integer readBytes) {
buf.getBuffer().clear();
buf.getBuffer().limit(readBytes);
return sendInstallSnapshotReq(buf, false, false);
}

Expand All @@ -653,7 +653,7 @@ private FrameCallResult afterReaderFinish(Void unused) {
.await(this::justReturn);
}

private FiberFuture<Void> sendInstallSnapshotReq(ByteBuffer data, boolean start, boolean finish) {
private FiberFuture<Void> sendInstallSnapshotReq(RefBuffer data, boolean start, boolean finish) {
SnapshotInfo si = snapshot.getSnapshotInfo();
InstallSnapshotReq req = new InstallSnapshotReq();
req.groupId = groupId;
Expand All @@ -675,15 +675,14 @@ private FiberFuture<Void> sendInstallSnapshotReq(ByteBuffer data, boolean start,
req.nextWritePos = nextPosAfterInstallFinish;
}
req.data = data;
req.pool = heapPool;

// data buffer released in WriteFrame
InstallSnapshotReq.InstallReqWriteFrame wf = new InstallSnapshotReq.InstallReqWriteFrame(req);
wf.setCommand(Commands.RAFT_INSTALL_SNAPSHOT);
DtTime timeout = new DtTime(serverConfig.getRpcTimeout(), TimeUnit.MILLISECONDS);
CompletableFuture<ReadFrame<AppendRespCallback>> future = client.sendRequest(
member.getNode().getPeer(), wf, APPEND_RESP_DECODER, timeout);
int bytes = data == null ? 0 : data.remaining();
int bytes = data == null ? 0 : data.getBuffer().remaining();
snapshotOffset += bytes;
FiberFuture<Void> f = getFiberGroup().newFuture("install-" + groupId + "-" + req.offset);
future.whenCompleteAsync((rf, ex) -> afterInstallRpc(rf, ex, req, f), getFiberGroup().getExecutor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.github.dtprj.dongting.raft.impl;

import com.github.dtprj.dongting.buf.RefBuffer;
import com.github.dtprj.dongting.common.Pair;
import com.github.dtprj.dongting.fiber.Fiber;
import com.github.dtprj.dongting.fiber.FiberCondition;
Expand All @@ -27,10 +28,8 @@
import com.github.dtprj.dongting.log.DtLogs;
import com.github.dtprj.dongting.raft.sm.Snapshot;

import java.nio.ByteBuffer;
import java.util.LinkedList;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.function.Supplier;

/**
Expand All @@ -43,32 +42,29 @@ public class SnapshotReader extends FiberFrame<Void> {
private final Snapshot snapshot;
private final int maxReadConcurrency;
private final int maxWriteConcurrency;
private final Function<ByteBuffer, FiberFuture<Void>> callback;
private final BiFunction<RefBuffer, Integer, FiberFuture<Void>> callback;
private final Supplier<Boolean> cancel;
private final Supplier<ByteBuffer> bufferCreator;
private final Consumer<ByteBuffer> bufferReleaser;
private final Supplier<RefBuffer> bufferCreator;

private final FiberCondition cond;

private final LinkedList<Pair<ByteBuffer, FiberFuture<Void>>> readList = new LinkedList<>();
private final LinkedList<Pair<RefBuffer, FiberFuture<Integer>>> readList = new LinkedList<>();
private final LinkedList<FiberFuture<Void>> writeList = new LinkedList<>();

private Throwable firstEx;
private boolean readFinish;

// the callback should release the buffer
public SnapshotReader(Snapshot snapshot, int maxReadConcurrency, int maxWriteConcurrency,
Function<ByteBuffer, FiberFuture<Void>> callback,
BiFunction<RefBuffer, Integer, FiberFuture<Void>> callback,
Supplier<Boolean> cancel,
Supplier<ByteBuffer> bufferCreator,
Consumer<ByteBuffer> bufferReleaser) {
Supplier<RefBuffer> bufferCreator) {
this.snapshot = snapshot;
this.maxReadConcurrency = maxReadConcurrency;
this.maxWriteConcurrency = maxWriteConcurrency;
this.callback = callback;
this.cancel = cancel;
this.bufferCreator = bufferCreator;
this.bufferReleaser = bufferReleaser;
this.cond = FiberGroup.currentGroup().newCondition("snapshotReaderCond");
}

Expand All @@ -84,9 +80,9 @@ public FrameCallResult execute(Void input) throws Throwable {
if (!readFinish && firstEx == null && !cancel.get()) {
if (readList.size() < maxReadConcurrency) {
// fire read task
ByteBuffer buf = bufferCreator.get();
RefBuffer buf = bufferCreator.get();
addNewReadTask = true;
FiberFuture<Void> future = snapshot.readNext(buf);
FiberFuture<Integer> future = snapshot.readNext(buf.getBuffer());
readList.add(new Pair<>(buf, future));
future.registerCallback((v, ex) -> cond.signal());
}
Expand All @@ -113,12 +109,12 @@ public FrameCallResult execute(Void input) throws Throwable {
}

private void processReadResult() {
Pair<ByteBuffer, FiberFuture<Void>> pair = readList.peekFirst();
Pair<RefBuffer, FiberFuture<Integer>> pair = readList.peekFirst();
if (pair == null) {
return;
}
FiberFuture<Void> f = pair.getRight();
ByteBuffer buf = pair.getLeft();
FiberFuture<Integer> f = pair.getRight();
RefBuffer buf = pair.getLeft();
if (!f.isDone()) {
// header in list not finished, wait for next time
return;
Expand All @@ -127,16 +123,16 @@ private void processReadResult() {
// only record the first exception
firstEx = f.getEx();
}
if (!buf.hasRemaining()) {
if (f.getResult() != null && f.getResult() == 0) {
readFinish = true;
}
if (cancel.get() || firstEx != null || readFinish) {
readList.removeFirst();
bufferReleaser.accept(buf);
buf.release();
} else if (writeList.size() < maxWriteConcurrency) {
readList.removeFirst();
try {
FiberFuture<Void> writeFuture = callback.apply(buf);
FiberFuture<Void> writeFuture = callback.apply(buf, f.getResult());
writeList.add(writeFuture);
writeFuture.registerCallback((v, ex) -> cond.signal());
} catch (Throwable e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import com.github.dtprj.dongting.raft.server.ReqInfo;
import com.github.dtprj.dongting.raft.sm.RaftCodecFactory;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -468,8 +469,9 @@ private FrameCallResult applyConfigChange(Void unused) {

private FrameCallResult doInstall(RaftStatusImpl raftStatus, InstallSnapshotReq req) {
boolean finish = req.done;
ByteBuffer buf = req.data == null ? null : req.data.getBuffer();
FiberFuture<Void> f = gc.getStateMachine().installSnapshot(req.lastIncludedIndex,
req.lastIncludedTerm, req.offset, finish, req.data);
req.lastIncludedTerm, req.offset, finish, buf);
if (finish) {
return f.await(v -> finishInstall(req, raftStatus));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package com.github.dtprj.dongting.raft.rpc;

import com.github.dtprj.dongting.buf.ByteBufferPool;
import com.github.dtprj.dongting.buf.RefBuffer;
import com.github.dtprj.dongting.buf.RefBufferFactory;
import com.github.dtprj.dongting.codec.DecodeContext;
import com.github.dtprj.dongting.codec.EncodeContext;
Expand Down Expand Up @@ -61,14 +61,12 @@ public class InstallSnapshotReq {
public Set<Integer> preparedObservers;
public long lastConfigChangeIndex;

public ByteBuffer data;
public ByteBufferPool pool;
public RefBuffer data;

public void release() {
if (data != null && pool != null) {
pool.release(data);
public void release(){
if (data != null) {
data.release();
data = null;
pool = null;
}
}

Expand Down Expand Up @@ -152,12 +150,11 @@ public boolean readBytes(int index, ByteBuffer buf, int len, int currentPos) {
boolean end = buf.remaining() >= len - currentPos;
if (index == 15) {
if (currentPos == 0) {
result.data = heapPool.getPool().borrow(len);
result.pool = heapPool.getPool();
result.data = heapPool.create(len);
}
result.data.put(buf);
result.data.getBuffer().put(buf);
if (end) {
result.data.flip();
result.data.getBuffer().flip();
}
}
return true;
Expand Down Expand Up @@ -192,8 +189,9 @@ public InstallReqWriteFrame(InstallSnapshotReq req) {
x += calcFix32SetSize(12, req.preparedObservers);
x += PbUtil.accurateFix64Size(13, req.lastConfigChangeIndex);

if (req.data != null && req.data.hasRemaining()) {
this.bufferSize = req.data.remaining();
RefBuffer rb = req.data;
if (rb != null && rb.getBuffer().hasRemaining()) {
this.bufferSize = rb.getBuffer().remaining();
x += PbUtil.accurateLengthDelimitedSize(15, bufferSize);
} else {
this.bufferSize = 0;
Expand Down Expand Up @@ -245,8 +243,8 @@ protected boolean encodeBody(EncodeContext context, ByteBuffer dest) {
if (bufferSize == 0) {
return true;
}
dest.put(req.data);
return !req.data.hasRemaining();
dest.put(req.data.getBuffer());
return !req.data.getBuffer().hasRemaining();
}

private void writeSet(ByteBuffer buf, int index, Set<Integer> s) {
Expand Down
Loading

0 comments on commit 798f046

Please sign in to comment.