Skip to content

Commit

Permalink
test replacing ptr casts with normal cuda types
Browse files Browse the repository at this point in the history
  • Loading branch information
dkales committed Jul 25, 2024
1 parent a6124d7 commit c480df7
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 46 deletions.
36 changes: 12 additions & 24 deletions src/dot/share_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
task_monitor::TaskMonitor,
},
rng::chacha::ChaChaCudaRng,
threshold_ring::protocol::ChunkShare,
threshold_ring::protocol::{ChunkShare, ChunkShareView},
};
use axum::{routing::get, Router};
#[cfg(feature = "otp_encrypt")]
Expand All @@ -20,11 +20,12 @@ use cudarc::{
cublas::{result::gemm_ex, sys, CudaBlas},
driver::{
result::malloc_async, sys::CUdeviceptr, CudaFunction, CudaSlice, CudaStream, DevicePtr,
LaunchAsync, LaunchConfig,
DeviceSlice, LaunchAsync, LaunchConfig,

Check failure on line 23 in src/dot/share_db.rs

View workflow job for this annotation

GitHub Actions / doc

the name `DeviceSlice` is defined multiple times

Check failure on line 23 in src/dot/share_db.rs

View workflow job for this annotation

GitHub Actions / clippy

unused import: `DeviceSlice`

error: unused import: `DeviceSlice` --> src/dot/share_db.rs:23:9 | 23 | DeviceSlice, LaunchAsync, LaunchConfig, | ^^^^^^^^^^^

Check failure on line 23 in src/dot/share_db.rs

View workflow job for this annotation

GitHub Actions / clippy

the name `DeviceSlice` is defined multiple times

error[E0252]: the name `DeviceSlice` is defined multiple times --> src/dot/share_db.rs:23:9 | 18 | use cudarc::driver::{CudaView, DeviceSlice}; | ----------- previous import of the trait `DeviceSlice` here ... 23 | DeviceSlice, LaunchAsync, LaunchConfig, | ^^^^^^^^^^^-- | | | `DeviceSlice` reimported here | help: remove unnecessary import | = note: `DeviceSlice` must be defined only once in the type namespace of this module
},
nccl::{self, result, Comm, Id, NcclType},
nvrtc::compile_ptx,
};
use itertools::izip;
#[cfg(feature = "otp_encrypt")]
use itertools::Itertools;
use rayon::prelude::*;
Expand Down Expand Up @@ -847,28 +848,15 @@ impl ShareDB {
}
}

// TODO: this is very hacky
pub fn result_chunk_shares(&self, db_sizes: &[usize]) -> Vec<ChunkShare<u16>> {
let results_ptrs = self
.results
.iter()
.map(|x| *x.device_ptr())
.collect::<Vec<_>>();
let results_peer_ptrs = self
.results_peer
.iter()
.map(|x| *x.device_ptr())
.collect::<Vec<_>>();

device_ptrs_to_shares(
&results_ptrs,
&results_peer_ptrs,
&db_sizes
.iter()
.map(|e| e * self.query_length)
.collect::<Vec<_>>(),
self.device_manager.devices(),
)
pub fn result_chunk_shares(&self, db_sizes: &[usize]) -> Vec<ChunkShareView<u16>> {
izip!(self.results.iter(), self.results_peer.iter(), db_sizes)
.map(|(a, b, e)| ChunkShareView {
// SAFETY: we have ensured that the slices are of the correct length
a: unsafe { a.transmute(a.len() / 2).unwrap() }.slice(0..e * self.query_length),
// SAFETY: we have ensured that the slices are of the correct length
b: unsafe { b.transmute(b.len() / 2).unwrap() }.slice(0..e * self.query_length),
})
.collect()
}
}

Expand Down
31 changes: 15 additions & 16 deletions src/threshold_ring/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ use crate::{
use axum::{routing::get, Router};
use cudarc::{
driver::{
result::stream,
CudaDevice, CudaFunction, CudaSlice, CudaStream, CudaView, CudaViewMut, DevicePtr,
DeviceSlice, LaunchAsync, LaunchConfig,
result::stream, CudaDevice, CudaFunction, CudaSlice, CudaStream, CudaView, CudaViewMut,
DevicePtr, DeviceSlice, LaunchAsync, LaunchConfig,
},
nccl::{result, Comm, Id},
nvrtc::{self, Ptx},
Expand All @@ -34,9 +33,9 @@ pub struct ChunkShare<T> {
pub b: CudaSlice<T>,
}

pub struct ChunkShareView<'a, T> {
pub struct ChunkShareView<'a, 'b, T> {
pub a: CudaView<'a, T>,
pub b: CudaView<'a, T>,
pub b: CudaView<'b, T>,
}

impl<T> ChunkShare<T> {
Expand Down Expand Up @@ -88,12 +87,12 @@ impl<T> ChunkShare<T> {
}
}

pub struct ChunkShareViewMut<'a, T> {
pub struct ChunkShareViewMut<'a, 'b, T> {
pub a: CudaViewMut<'a, T>,
pub b: CudaViewMut<'a, T>,
pub b: CudaViewMut<'b, T>,
}

impl<'a, T> ChunkShareView<'a, T> {
impl<'a, 'b, T> ChunkShareView<'a, 'b, T> {
pub fn get_offset(&self, i: usize, chunk_size: usize) -> ChunkShareView<T> {
ChunkShareView {
a: self.a.slice(i * chunk_size..(i + 1) * chunk_size),
Expand Down Expand Up @@ -1616,7 +1615,7 @@ impl Circuits {

fn transpose_pack_u16_with_len(
&mut self,
inp: &[ChunkShare<u16>],
inp: &[ChunkShareView<u16>],
outp: &mut [ChunkShareView<u64>],
bitlen: usize,
streams: &[CudaStream],
Expand Down Expand Up @@ -1728,7 +1727,7 @@ impl Circuits {

fn lift_split(
&mut self,
inp: &[ChunkShare<u16>],
inp: &[ChunkShareView<u16>],
lifted: &mut [ChunkShareView<u32>],
inout1: &mut [ChunkShareView<u64>],
out2: &mut [ChunkShareView<u64>],
Expand Down Expand Up @@ -1773,7 +1772,7 @@ impl Circuits {
&mut self,
mask_lifted: &mut [ChunkShareView<u32>],
mask_correction: &[ChunkShareView<u16>],
code: &[ChunkShare<u16>],
code: &[ChunkShareView<u16>],
streams: &[CudaStream],
) {
assert_eq!(self.n_devices, mask_lifted.len());
Expand Down Expand Up @@ -1812,7 +1811,7 @@ impl Circuits {
// outputs the uncorrected lifted shares and the injected correction values
pub fn lift_mpc(
&mut self,
shares: &[ChunkShare<u16>],
shares: &[ChunkShareView<u16>],
xa: &mut [ChunkShareView<u32>],
injected: &mut [ChunkShareView<u16>],
streams: &[CudaStream],
Expand Down Expand Up @@ -2226,8 +2225,8 @@ impl Circuits {
// Result is in the first bit of the result buffer
pub fn compare_threshold_masked_many(
&mut self,
code_dots: &[ChunkShare<u16>],
mask_dots: &[ChunkShare<u16>],
code_dots: &[ChunkShareView<u16>],
mask_dots: &[ChunkShareView<u16>],
streams: &[CudaStream],
) {
assert_eq!(self.n_devices, code_dots.len());
Expand Down Expand Up @@ -2256,8 +2255,8 @@ impl Circuits {
// Result is in the lowest bit of the result buffer on the first gpu
pub fn compare_threshold_masked_many_with_or_tree(
&mut self,
code_dots: &[ChunkShare<u16>],
mask_dots: &[ChunkShare<u16>],
code_dots: &[ChunkShareView<u16>],
mask_dots: &[ChunkShareView<u16>],
streams: &[CudaStream],
) {
self.compare_threshold_masked_many(code_dots, mask_dots, streams);
Expand Down
2 changes: 1 addition & 1 deletion tests/extract_msb_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ async fn test_extract_msb_mod() -> eyre::Result<()> {
let mut x = to_view(&x_);
let correction_ = party.allocate_buffer::<u16>(INPUTS_PER_GPU_SIZE * 2);
let correction = to_view(&correction_);
let code_gpu = code_gpu.clone();
let code_gpu = code_gpu.iter().map(|x| x.as_view()).collect::<Vec<_>>();

let now = Instant::now();
party.lift_mul_sub(&mut x, &correction, &code_gpu, &streams);
Expand Down
2 changes: 1 addition & 1 deletion tests/lift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async fn test_lift() -> eyre::Result<()> {
let mut x = to_view(&x_);
let correction_ = party.allocate_buffer::<u16>(INPUTS_PER_GPU_SIZE * 2);
let mut correction = to_view(&correction_);
let mask_gpu = mask_gpu.clone();
let mask_gpu = mask_gpu.iter().map(|x| x.as_view()).collect::<Vec<_>>();

let now = Instant::now();
party.lift_mpc(&mask_gpu, &mut x, &mut correction, &streams);
Expand Down
4 changes: 2 additions & 2 deletions tests/threshold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ async fn test_threshold() -> eyre::Result<()> {
for _ in 0..10 {
server_tasks.check_tasks();

let code_gpu = code_gpu.clone();
let mask_gpu = mask_gpu.clone();
let code_gpu = code_gpu.iter().map(|x| x.as_view()).collect::<Vec<_>>();
let mask_gpu = mask_gpu.iter().map(|x| x.as_view()).collect::<Vec<_>>();

let now = Instant::now();
party.compare_threshold_masked_many(&code_gpu, &mask_gpu, &streams);
Expand Down
4 changes: 2 additions & 2 deletions tests/threshold_and_or_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ async fn test_threshold_and_or_tree() -> eyre::Result<()> {
for _ in 0..10 {
server_tasks.check_tasks();

let code_gpu = code_gpu.clone();
let mask_gpu = mask_gpu.clone();
let code_gpu = code_gpu.iter().map(|x| x.as_view()).collect::<Vec<_>>();
let mask_gpu = mask_gpu.iter().map(|x| x.as_view()).collect::<Vec<_>>();

let now = Instant::now();
party.compare_threshold_masked_many_with_or_tree(&code_gpu, &mask_gpu, &streams);
Expand Down

0 comments on commit c480df7

Please sign in to comment.