Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1757] Add retry when sending RPC to LifecycleManager #3008

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 60 additions & 63 deletions client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.*;
zaynt4606 marked this conversation as resolved.
Show resolved Hide resolved

import scala.Tuple2;
import scala.reflect.ClassTag$;
Expand Down Expand Up @@ -81,6 +78,7 @@ public class ShuffleClientImpl extends ShuffleClient {

private final int registerShuffleMaxRetries;
private final long registerShuffleRetryWaitMs;
private final int callLifecycleManagerMaxRetry;
private final int maxReviveTimes;
private final boolean testRetryRevive;
private final int pushBufferMaxSize;
Expand Down Expand Up @@ -179,6 +177,7 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
this.userIdentifier = userIdentifier;
registerShuffleMaxRetries = conf.clientRegisterShuffleMaxRetry();
registerShuffleRetryWaitMs = conf.clientRegisterShuffleRetryWaitMs();
callLifecycleManagerMaxRetry = conf.clientCallLifecycleManagerMaxRetry();
maxReviveTimes = conf.clientPushMaxReviveTimes();
testRetryRevive = conf.testRetryRevive();
pushBufferMaxSize = conf.clientPushBufferMaxSize();
Expand Down Expand Up @@ -534,6 +533,7 @@ private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
lifecycleManagerRef.askSync(
RegisterShuffle$.MODULE$.apply(shuffleId, numMappers, numPartitions),
conf.clientRpcRegisterShuffleAskTimeout(),
callLifecycleManagerMaxRetry,
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
}

Expand Down Expand Up @@ -1700,13 +1700,12 @@ private void mapEndInternal(
throws IOException {
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
PushState pushState = getPushState(mapKey);

zaynt4606 marked this conversation as resolved.
Show resolved Hide resolved
try {
limitZeroInFlight(mapKey, pushState);

zaynt4606 marked this conversation as resolved.
Show resolved Hide resolved
MapperEndResponse response =
lifecycleManagerRef.askSync(
new MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId),
callLifecycleManagerMaxRetry,
ClassTag$.MODULE$.apply(MapperEndResponse.class));
if (response.status() != StatusCode.SUCCESS) {
throw new CelebornIOException("MapperEnd failed! StatusCode: " + response.status());
Expand Down Expand Up @@ -1741,65 +1740,60 @@ public boolean cleanupShuffle(int shuffleId) {

protected Tuple2<ReduceFileGroups, String> loadFileGroupInternal(
int shuffleId, boolean isSegmentGranularityVisible) {
{
long getReducerFileGroupStartTime = System.nanoTime();
String exceptionMsg = null;
try {
if (lifecycleManagerRef == null) {
exceptionMsg = "Driver endpoint is null!";
logger.warn(exceptionMsg);
} else {
GetReducerFileGroup getReducerFileGroup =
new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible);

GetReducerFileGroupResponse response =
lifecycleManagerRef.askSync(
getReducerFileGroup,
conf.clientRpcGetReducerFileGroupAskTimeout(),
ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));
long getReducerFileGroupStartTime = System.nanoTime();
String exceptionMsg = null;
if (lifecycleManagerRef == null) {
exceptionMsg = "Driver endpoint is null!";
logger.warn(exceptionMsg);
return Tuple2.apply(null, exceptionMsg);
}
try {
GetReducerFileGroup getReducerFileGroup =
new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible);

switch (response.status()) {
case SUCCESS:
logger.info(
"Shuffle {} request reducer file group success using {} ms, result partition size {}.",
shuffleId,
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime),
response.fileGroup().size());
return Tuple2.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null);
case SHUFFLE_NOT_REGISTERED:
logger.warn(
"Request {} return {} for {}.",
getReducerFileGroup,
response.status(),
shuffleId);
// return empty result
return Tuple2.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null);
case STAGE_END_TIME_OUT:
case SHUFFLE_DATA_LOST:
exceptionMsg =
String.format(
"Request %s return %s for %s.",
getReducerFileGroup, response.status(), shuffleId);
logger.warn(exceptionMsg);
break;
default: // fall out
}
}
} catch (Exception e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e);
exceptionMsg = e.getMessage();
GetReducerFileGroupResponse response =
lifecycleManagerRef.askSync(
getReducerFileGroup,
conf.clientRpcGetReducerFileGroupAskTimeout(),
callLifecycleManagerMaxRetry,
ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));
switch (response.status()) {
case SUCCESS:
logger.info(
"Shuffle {} request reducer file group success using {} ms, result partition size {}.",
shuffleId,
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime),
response.fileGroup().size());
return Tuple2.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null);
case SHUFFLE_NOT_REGISTERED:
logger.warn(
"Request {} return {} for {}.", getReducerFileGroup, response.status(), shuffleId);
// return empty result
return Tuple2.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null);
case STAGE_END_TIME_OUT:
case SHUFFLE_DATA_LOST:
exceptionMsg =
String.format(
"Request %s return %s for %s.",
getReducerFileGroup, response.status(), shuffleId);
logger.warn(exceptionMsg);
break;
default: // fall out
}
return Tuple2.apply(null, exceptionMsg);
} catch (Exception e) {
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e);
exceptionMsg = e.getMessage();
}
return Tuple2.apply(null, exceptionMsg);
}

@Override
Expand Down Expand Up @@ -1929,7 +1923,10 @@ public void shutdown() {
public void setupLifecycleManagerRef(String host, int port) {
logger.info("setupLifecycleManagerRef: host = {}, port = {}", host, port);
lifecycleManagerRef =
rpcEnv.setupEndpointRef(new RpcAddress(host, port), RpcNameConstants.LIFECYCLE_MANAGER_EP);
rpcEnv.setupEndpointRef(
new RpcAddress(host, port),
RpcNameConstants.LIFECYCLE_MANAGER_EP,
callLifecycleManagerMaxRetry);
initDataClientFactoryIfNeeded();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ private CelebornConf setupEnv(
RegisterShuffleResponse$.MODULE$.apply(
statusCode, new PartitionLocation[] {primaryLocation}));

when(endpointRef.askSync(any(), any(), any(Integer.class), any()))
.thenAnswer(
t ->
RegisterShuffleResponse$.MODULE$.apply(
statusCode, new PartitionLocation[] {primaryLocation}));

shuffleClient.setupLifecycleManagerRef(endpointRef);

ChannelFuture mockedFuture =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
new RpcTimeout(get(RPC_LOOKUP_TIMEOUT).milli, RPC_LOOKUP_TIMEOUT.key)
def rpcAskTimeout: RpcTimeout =
new RpcTimeout(get(RPC_ASK_TIMEOUT).milli, RPC_ASK_TIMEOUT.key)
def rpcTimeoutRetryWaitMs: Long = get(RPC_TIMEOUT_RETRY_WAIT)
def rpcInMemoryBoundedInboxCapacity(): Int = {
get(RPC_INBOX_CAPACITY)
}
Expand Down Expand Up @@ -901,6 +902,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def clientCloseIdleConnections: Boolean = get(CLIENT_CLOSE_IDLE_CONNECTIONS)
def clientRegisterShuffleMaxRetry: Int = get(CLIENT_REGISTER_SHUFFLE_MAX_RETRIES)
def clientRegisterShuffleRetryWaitMs: Long = get(CLIENT_REGISTER_SHUFFLE_RETRY_WAIT)
def clientCallLifecycleManagerMaxRetry: Int = get(CLIENT_CALL_LIFECYCLEMANAGER_MAX_RETRIES)
def clientReserveSlotsRackAwareEnabled: Boolean = get(CLIENT_RESERVE_SLOTS_RACKAWARE_ENABLED)
def clientReserveSlotsMaxRetries: Int = get(CLIENT_RESERVE_SLOTS_MAX_RETRIES)
def clientReserveSlotsRetryWait: Long = get(CLIENT_RESERVE_SLOTS_RETRY_WAIT)
Expand Down Expand Up @@ -4884,6 +4886,23 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("3s")

val RPC_TIMEOUT_RETRY_WAIT: ConfigEntry[Long] =
buildConf("celeborn.rpc.timeoutRetryWait")
zaynt4606 marked this conversation as resolved.
Show resolved Hide resolved
.categories("network")
.version("0.6.0")
.doc("Wait time before next retry if RpcTimeoutException.")
zaynt4606 marked this conversation as resolved.
Show resolved Hide resolved
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("1s")

val CLIENT_CALL_LIFECYCLEMANAGER_MAX_RETRIES: ConfigEntry[Int] =
buildConf("celeborn.client.callLifecycleManager.maxRetries")
.withAlternative("celeborn.callLifecycleManager.maxRetries")
zaynt4606 marked this conversation as resolved.
Show resolved Hide resolved
.categories("client")
.version("0.6.0")
.doc("Max retry times for client to reserve slots.")
zaynt4606 marked this conversation as resolved.
Show resolved Hide resolved
.intConf
.createWithDefault(3)

val CLIENT_RESERVE_SLOTS_MAX_RETRIES: ConfigEntry[Int] =
buildConf("celeborn.client.reserveSlots.maxRetries")
.withAlternative("celeborn.slots.reserve.maxRetries")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.celeborn.common.rpc

import java.util.Random
import java.util.concurrent.TimeUnit

import scala.concurrent.Future
import scala.reflect.ClassTag

Expand All @@ -30,6 +33,7 @@ abstract class RpcEndpointRef(conf: CelebornConf)
extends Serializable with Logging {

private[this] val defaultAskTimeout = conf.rpcAskTimeout
private[celeborn] val waitTimeBound = conf.rpcTimeoutRetryWaitMs.toInt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private[this] val defaultRetryWait


/**
* return the address for the [[RpcEndpointRef]]
Expand Down Expand Up @@ -88,4 +92,58 @@ abstract class RpcEndpointRef(conf: CelebornConf)
val future = ask[T](message, timeout)
timeout.awaitResult(future, address)
}

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
* default timeout, retry if timeout, throw an exception if this still fails.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
*
* @param message the message to send
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
def askSync[T: ClassTag](message: Any, retryCount: Int): T =
askSync(message, defaultAskTimeout, retryCount)

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
* specified timeout, retry if timeout, throw an exception if this still fails.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
*
* @param message the message to send
* @param timeout the timeout duration
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
def askSync[T: ClassTag](message: Any, timeout: RpcTimeout, retryCount: Int): T = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def askSync[T: ClassTag](message: Any, timeout: RpcTimeout, retryCount: Int, retryWait: Long)

var numRetries = retryCount
while (numRetries > 0) {
numRetries -= 1
try {
val future = ask[T](message, timeout)
return timeout.awaitResult(future, address)
} catch {
case e: RpcTimeoutException =>
if (numRetries > 0) {
val random = new Random
val retryWaitMs = random.nextInt(waitTimeBound)
try {
TimeUnit.MILLISECONDS.sleep(retryWaitMs)
} catch {
case _: InterruptedException =>
throw e
}
} else {
throw e
}
}
}
// should never be here
val future = ask[T](message, timeout)
timeout.awaitResult(future, address)
}
}
38 changes: 38 additions & 0 deletions common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.celeborn.common.rpc

import java.io.File
import java.util.Random
import java.util.concurrent.TimeUnit

import scala.concurrent.Future

Expand Down Expand Up @@ -104,6 +106,7 @@ object RpcEnv {
abstract class RpcEnv(config: RpcEnvConfig) {

private[celeborn] val defaultLookupTimeout = config.conf.rpcLookupTimeout
private[celeborn] val waitTimeBound = config.conf.rpcTimeoutRetryWaitMs.toInt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private[celeborn] val defaultRetryWait


/**
* Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
Expand Down Expand Up @@ -142,6 +145,41 @@ abstract class RpcEnv(config: RpcEnvConfig) {
setupEndpointRefByAddr(RpcEndpointAddress(address, endpointName))
}

/**
* Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName` with timeout retry.
* This is a blocking action.
*/
def setupEndpointRef(
address: RpcAddress,
endpointName: String,
retryCount: Int): RpcEndpointRef = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def setupEndpointRef(
      address: RpcAddress,
      endpointName: String,
      retryCount: Int,
      retryWait: Long)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sory for delay.
This pr has been updated with two configraions CLIENT_RPC_RETRY_WAIT and RPC_RETRY_WAIT.

var numRetries = retryCount
while (numRetries > 0) {
numRetries -= 1
try {
return setupEndpointRefByAddr(RpcEndpointAddress(address, endpointName))
} catch {
case e: RpcTimeoutException =>
if (numRetries > 0) {
val random = new Random
val retryWaitMs = random.nextInt(waitTimeBound)
try {
TimeUnit.MILLISECONDS.sleep(retryWaitMs)
} catch {
case _: InterruptedException =>
throw e
}
} else {
throw e
}
case e: RpcEndpointNotFoundException =>
throw e
}
}
// should never be here
null
}

/**
* Stop [[RpcEndpoint]] specified by `endpoint`.
*/
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ license: |
| celeborn.client.application.heartbeatInterval | 10s | false | Interval for client to send heartbeat message to master. | 0.3.0 | celeborn.application.heartbeatInterval |
| celeborn.client.application.unregister.enabled | true | false | When true, Celeborn client will inform celeborn master the application is already shutdown during client exit, this allows the cluster to release resources immediately, resulting in resource savings. | 0.3.2 | |
| celeborn.client.application.uuidSuffix.enabled | false | false | Whether to add UUID suffix for application id for unique. When `true`, add UUID suffix for unique application id. Currently, this only applies to Spark and MR. | 0.6.0 | |
| celeborn.client.callLifecycleManager.maxRetries | 3 | false | Max retry times for client to reserve slots. | 0.6.0 | celeborn.callLifecycleManager.maxRetries |
| celeborn.client.chunk.prefetch.enabled | false | false | Whether to enable chunk prefetch when creating CelebornInputStream. | 0.6.0 | |
| celeborn.client.closeIdleConnections | true | false | Whether client will close idle connections. | 0.3.0 | |
| celeborn.client.commitFiles.ignoreExcludedWorker | false | false | When true, LifecycleManager will skip workers which are in the excluded list. | 0.3.0 | |
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/network.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ license: |
| celeborn.rpc.lookupTimeout | 30s | false | Timeout for RPC lookup operations. | 0.2.0 | |
| celeborn.rpc.slow.interval | &lt;undefined&gt; | false | min interval (ms) for RPC framework to log slow RPC | 0.6.0 | |
| celeborn.rpc.slow.threshold | 1s | false | threshold for RPC framework to log slow RPC | 0.6.0 | |
| celeborn.rpc.timeoutRetryWait | 1s | false | Wait time before next retry if RpcTimeoutException. | 0.6.0 | |
| celeborn.shuffle.io.maxChunksBeingTransferred | &lt;undefined&gt; | false | The max number of chunks allowed to be transferred at the same time on shuffle service. Note that new incoming connections will be closed when the max number is hit. The client will retry according to the shuffle retry configs (see `celeborn.<module>.io.maxRetries` and `celeborn.<module>.io.retryWait`), if those limits are reached the task will fail with fetch failure. | 0.2.0 | |
| celeborn.ssl.&lt;module&gt;.enabled | false | false | Enables SSL for securing wire traffic. | 0.5.0 | |
| celeborn.ssl.&lt;module&gt;.enabledAlgorithms | &lt;undefined&gt; | false | A comma-separated list of ciphers. The specified ciphers must be supported by JVM.<br/>The reference list of protocols can be found in the "JSSE Cipher Suite Names" section of the Java security guide. The list for Java 11, for example, can be found at [this page](https://docs.oracle.com/en/java/javase/11/docs/specs/security/standard-names.html#jsse-cipher-suite-names)<br/>Note: If not set, the default cipher suite for the JRE will be used | 0.5.0 | |
Expand Down
Loading