Skip to content

Commit

Permalink
retries
Browse files Browse the repository at this point in the history
  • Loading branch information
mattstam committed Jan 10, 2025
1 parent 25b9e03 commit 0a26a3f
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 198 deletions.
16 changes: 8 additions & 8 deletions crates/perf/workflow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ CUDA_WORKLOADS=(

# Define the list of network workloads.
NETWORK_WORKLOADS=(
# "fibonacci-17k"
# "ssz-withdrawals"
# "tendermint"
# "rsp-20526624"
# "rsa"
# "regex"
# "chess"
# "json"
"fibonacci-17k"
"ssz-withdrawals"
"tendermint"
"rsp-20526624"
"rsa"
"regex"
"chess"
"json"
# "blobstream-01j6z63fgafrc8jeh0k12gbtvw"
# "blobstream-01j6z95bdme9svevmfyc974bja"
# "blobstream-01j6z9ak0ke9srsppgywgke6fj"
Expand Down
251 changes: 168 additions & 83 deletions crates/sdk/src/network/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@ use alloy_primitives::B256;
use alloy_signer::SignerSync;
use alloy_signer_local::PrivateKeySigner;
use anyhow::{Context, Ok, Result};
use async_trait::async_trait;
use reqwest_middleware::ClientWithMiddleware as HttpClientWithMiddleware;
use serde::{de::DeserializeOwned, Serialize};
use sp1_core_machine::io::SP1Stdin;
use sp1_prover::{HashableKey, SP1VerifyingKey};
use tonic::{
transport::{channel::ClientTlsConfig, Channel},
Code,
};
use tonic::{transport::Channel, Code};

use super::grpc;
use super::retry::{self, RetryableRpc, DEFAULT_RETRY_TIMEOUT};
use super::utils::Signable;
use crate::network::proto::artifact::{
artifact_store_client::ArtifactStoreClient, ArtifactType, CreateArtifactRequest,
Expand All @@ -39,6 +38,34 @@ pub struct NetworkClient {
pub(crate) rpc_url: String,
}

#[async_trait]
impl RetryableRpc for NetworkClient {
/// Execute an operation with retries using default timeout.
async fn with_retry<'a, T, F, Fut>(&'a self, operation: F, operation_name: &str) -> Result<T>
where
F: Fn() -> Fut + Send + Sync + 'a,
Fut: std::future::Future<Output = Result<T>> + Send,
T: Send,
{
self.with_retry_timeout(operation, DEFAULT_RETRY_TIMEOUT, operation_name).await
}

/// Execute an operation with retries using the specified timeout.
async fn with_retry_timeout<'a, T, F, Fut>(
&'a self,
operation: F,
timeout: Duration,
operation_name: &str,
) -> Result<T>
where
F: Fn() -> Fut + Send + Sync + 'a,
Fut: std::future::Future<Output = Result<T>> + Send,
T: Send,
{
retry::retry_operation(operation, Some(timeout), operation_name).await
}
}

impl NetworkClient {
/// Creates a new [`NetworkClient`] with the given private key and rpc url.
pub fn new(private_key: impl Into<String>, rpc_url: impl Into<String>) -> Self {
Expand All @@ -53,10 +80,17 @@ impl NetworkClient {

/// Get the latest nonce for this account's address.
pub async fn get_nonce(&self) -> Result<u64> {
let mut rpc = self.prover_network_client().await?;
let res =
rpc.get_nonce(GetNonceRequest { address: self.signer.address().to_vec() }).await?;
Ok(res.into_inner().nonce)
self.with_retry(
|| async {
let mut rpc = self.prover_network_client().await?;
let res = rpc
.get_nonce(GetNonceRequest { address: self.signer.address().to_vec() })
.await?;
Ok(res.into_inner().nonce)
},
"getting nonce",
)
.await
}

/// Get the verifying key hash from a verifying key.
Expand Down Expand Up @@ -89,12 +123,18 @@ impl NetworkClient {
/// # Details
/// Returns `None` if the program does not exist.
pub async fn get_program(&self, vk_hash: B256) -> Result<Option<GetProgramResponse>> {
let mut rpc = self.prover_network_client().await?;
match rpc.get_program(GetProgramRequest { vk_hash: vk_hash.to_vec() }).await {
StdOk(response) => Ok(Some(response.into_inner())),
Err(status) if status.code() == Code::NotFound => Ok(None),
Err(e) => Err(e.into()),
}
self.with_retry(
|| async {
let mut rpc = self.prover_network_client().await?;
match rpc.get_program(GetProgramRequest { vk_hash: vk_hash.to_vec() }).await {
StdOk(response) => Ok(Some(response.into_inner())),
Err(status) if status.code() == Code::NotFound => Ok(None),
Err(e) => Err(e.into()),
}
},
"getting program",
)
.await
}

/// Creates a new program on the network.
Expand All @@ -113,23 +153,29 @@ impl NetworkClient {
let vk_encoded = bincode::serialize(&vk)?;

// Send the request.
let mut rpc = self.prover_network_client().await?;
let nonce = self.get_nonce().await?;
let request_body = CreateProgramRequestBody {
nonce,
vk_hash: vk_hash.to_vec(),
vk: vk_encoded,
program_uri,
};

Ok(rpc
.create_program(CreateProgramRequest {
format: MessageFormat::Binary.into(),
signature: request_body.sign(&self.signer).into(),
body: Some(request_body),
})
.await?
.into_inner())
self.with_retry(
|| async {
let mut rpc = self.prover_network_client().await?;
let nonce = self.get_nonce().await?;
let request_body = CreateProgramRequestBody {
nonce,
vk_hash: vk_hash.to_vec(),
vk: vk_encoded.clone(),
program_uri: program_uri.clone(),
};

Ok(rpc
.create_program(CreateProgramRequest {
format: MessageFormat::Binary.into(),
signature: request_body.sign(&self.signer).into(),
body: Some(request_body),
})
.await?
.into_inner())
},
"creating program",
)
.await
}

/// Get all the proof requests that meet the filter criteria.
Expand All @@ -149,25 +195,37 @@ impl NetworkClient {
page: Option<u32>,
mode: Option<i32>,
) -> Result<GetFilteredProofRequestsResponse> {
let mut rpc = self.prover_network_client().await?;
let res = rpc
.get_filtered_proof_requests(GetFilteredProofRequestsRequest {
version,
fulfillment_status,
execution_status,
minimum_deadline,
vk_hash,
requester,
fulfiller,
from,
to,
limit,
page,
mode,
})
.await?
.into_inner();
Ok(res)
self.with_retry(
|| {
let version = version.clone();
let vk_hash = vk_hash.clone();
let requester = requester.clone();
let fulfiller = fulfiller.clone();

async move {
let mut rpc = self.prover_network_client().await?;
Ok(rpc
.get_filtered_proof_requests(GetFilteredProofRequestsRequest {
version,
fulfillment_status,
execution_status,
minimum_deadline,
vk_hash,
requester,
fulfiller,
from,
to,
limit,
page,
mode,
})
.await?
.into_inner())
}
},
"getting filtered proof requests",
)
.await
}

/// Get the status of a given proof.
Expand All @@ -177,14 +235,24 @@ impl NetworkClient {
pub async fn get_proof_request_status<P: DeserializeOwned>(
&self,
request_id: B256,
timeout: Option<Duration>,
) -> Result<(GetProofRequestStatusResponse, Option<P>)> {
let mut rpc = self.prover_network_client().await?;
let res = rpc
.get_proof_request_status(GetProofRequestStatusRequest {
request_id: request_id.to_vec(),
})
.await?
.into_inner();
// Get the status.
let res = self
.with_retry_timeout(
|| async {
let mut rpc = self.prover_network_client().await?;
Ok(rpc
.get_proof_request_status(GetProofRequestStatusRequest {
request_id: request_id.to_vec(),
})
.await?
.into_inner())
},
timeout.unwrap_or(DEFAULT_RETRY_TIMEOUT),
"getting proof request status",
)
.await?;

let status = FulfillmentStatus::try_from(res.fulfillment_status)?;
let proof = match status {
Expand Down Expand Up @@ -264,20 +332,11 @@ impl NetworkClient {
}

pub(crate) async fn artifact_store_client(&self) -> Result<ArtifactStoreClient<Channel>> {
let rpc_url = self.rpc_url.clone();
let mut endpoint = Channel::from_shared(rpc_url.clone())?;

// Check if the URL scheme is HTTPS and configure TLS.
if rpc_url.starts_with("https://") {
let tls_config = ClientTlsConfig::new().with_enabled_roots();
endpoint = endpoint.tls_config(tls_config)?;
}

let channel = endpoint.connect().await?;
Ok(ArtifactStoreClient::new(channel.clone()))
let channel = grpc::configure_endpoint(&self.rpc_url)?.connect().await?;
Ok(ArtifactStoreClient::new(channel))
}

pub(crate) async fn create_artifact_with_content<T: Serialize>(
pub(crate) async fn create_artifact_with_content<T: Serialize + Send + Sync>(
&self,
store: &mut ArtifactStoreClient<Channel>,
artifact_type: ArtifactType,
Expand All @@ -288,29 +347,55 @@ impl NetworkClient {
artifact_type: artifact_type.into(),
signature: signature.as_bytes().to_vec(),
};

// Create the artifact.
let response = store.create_artifact(request).await?.into_inner();

let presigned_url = response.artifact_presigned_url;
let uri = response.artifact_uri;

let response =
self.http.put(&presigned_url).body(bincode::serialize::<T>(item)?).send().await?;

if !response.status().is_success() {
log::debug!("Artifact upload failed with status: {}", response.status());
}
assert!(response.status().is_success());
// Upload the content.
self.with_retry(
|| async {
let response = self
.http
.put(&presigned_url)
.body(bincode::serialize::<T>(item)?)
.send()
.await?;

if !response.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to upload artifact: HTTP {}",
response.status()
));
}
Ok(())
},
"uploading artifact content",
)
.await?;

Ok(uri)
}

pub(crate) async fn download_artifact(&self, uri: &str) -> Result<Vec<u8>> {
let response = self.http.get(uri).send().await.context("Failed to download from URI")?;

if !response.status().is_success() {
return Err(anyhow::anyhow!("Failed to download artifact: HTTP {}", response.status()));
}

Ok(response.bytes().await.context("Failed to read response body")?.to_vec())
self.with_retry(
|| async {
let response =
self.http.get(uri).send().await.context("Failed to download from URI")?;

if !response.status().is_success() {
return Err(anyhow::anyhow!(
"Failed to download artifact: HTTP {}",
response.status()
));
}

Ok(response.bytes().await.context("Failed to read response body")?.to_vec())
},
"downloading artifact",
)
.await
}
}
1 change: 1 addition & 0 deletions crates/sdk/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub mod builder;
mod error;
mod grpc;
pub mod prove;
mod retry;
pub mod utils;

pub use error::*;
Expand Down
Loading

0 comments on commit 0a26a3f

Please sign in to comment.