Skip to content

Commit

Permalink
SYCL. Refactor on-device data structures (#10898)
Browse files Browse the repository at this point in the history
  • Loading branch information
razdoburdin authored Oct 18, 2024
1 parent acb64f7 commit 1b06da1
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 267 deletions.
4 changes: 2 additions & 2 deletions plugin/sycl/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
const auto* pgh = gpair_device.DataConst();
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const uint32_t* offsets = gmat.index.Offset();
const uint32_t* offsets = gmat.cut.cut_ptrs_.ConstDevicePointer();
const size_t nbins = gmat.nbins;

const size_t max_work_group_size =
Expand Down Expand Up @@ -210,7 +210,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
const GradientPair::ValueT* pgh =
reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const uint32_t* offsets = gmat.index.Offset();
const uint32_t* offsets = gmat.cut.cut_ptrs_.ConstDevicePointer();
FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
const size_t nbins = gmat.nbins;

Expand Down
2 changes: 1 addition & 1 deletion plugin/sycl/common/partition_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ inline ::sycl::event PartitionSparseKernel(::sycl::queue* qu,
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const size_t* rid = rid_span.begin;
const size_t range_size = rid_span.Size();
const uint32_t* cut_ptrs = gmat.cut_device.Ptrs().DataConst();
const uint32_t* cut_ptrs = gmat.cut.cut_ptrs_.ConstDevicePointer();

size_t* p_rid_buf = rid_buf->data();
return qu->submit([&](::sycl::handler& cgh) {
Expand Down
63 changes: 0 additions & 63 deletions plugin/sycl/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,69 +224,6 @@ class USMVector {
std::shared_ptr<T> data_;
};

/* Wrapper for DMatrix which stores all batches in a single USM buffer */
struct DeviceMatrix {
DMatrix* p_mat; // Pointer to the original matrix on the host
::sycl::queue* qu_;
USMVector<size_t, MemoryType::on_device> row_ptr;
USMVector<Entry, MemoryType::on_device> data;
size_t total_offset;

DeviceMatrix() = default;

void Init(::sycl::queue* qu, DMatrix* dmat) {
qu_ = qu;
p_mat = dmat;

size_t num_row = 0;
size_t num_nonzero = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
num_nonzero += batch.data.Size();
num_row += batch.Size();
}

row_ptr.Resize(qu_, num_row + 1);
size_t* rows = row_ptr.Data();
data.Resize(qu_, num_nonzero);

size_t data_offset = 0;
::sycl::event event;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
const auto& data_vec = batch.data.ConstHostVector();
const auto& offset_vec = batch.offset.ConstHostVector();
size_t batch_size = batch.Size();
if (batch_size > 0) {
const auto base_rowid = batch.base_rowid;
event = qu->memcpy(row_ptr.Data() + base_rowid, offset_vec.data(),
sizeof(size_t) * batch_size, event);
if (base_rowid > 0) {
qu->submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::id<1> pid) {
int row_id = pid[0];
rows[row_id] += base_rowid;
});
});
}
event = qu->memcpy(data.Data() + data_offset, data_vec.data(),
sizeof(Entry) * offset_vec[batch_size], event);
data_offset += offset_vec[batch_size];
qu->wait();
}
}
qu_->submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.single_task<>([=] {
rows[num_row] = data_offset;
});
});
qu_->wait();
total_offset = data_offset;
}

~DeviceMatrix() {
}
};
} // namespace sycl
} // namespace xgboost

Expand Down
123 changes: 61 additions & 62 deletions plugin/sycl/data/gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,51 +48,53 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
}
}

template <typename BinIdxType>
template <typename BinIdxType, bool isDense>
void GHistIndexMatrix::SetIndexData(::sycl::queue* qu,
BinIdxType* index_data,
const DeviceMatrix &dmat,
DMatrix *dmat,
size_t nbins,
size_t row_stride,
uint32_t* offsets) {
size_t row_stride) {
if (nbins == 0) return;
const xgboost::Entry *data_ptr = dmat.data.DataConst();
const bst_idx_t *offset_vec = dmat.row_ptr.DataConst();
const size_t num_rows = dmat.row_ptr.Size() - 1;
const bst_float* cut_values = cut_device.Values().DataConst();
const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst();
size_t* hit_count_ptr = hit_count_buff.Data();

// Sparse case only
if (!offsets) {
// sort_buff has type uint8_t
sort_buff.Resize(qu, num_rows * row_stride * sizeof(BinIdxType));
}
const bst_float* cut_values = cut.cut_values_.ConstDevicePointer();
const uint32_t* cut_ptrs = cut.cut_ptrs_.ConstDevicePointer();
size_t* hit_count_ptr = hit_count.DevicePointer();

BinIdxType* sort_data = reinterpret_cast<BinIdxType*>(sort_buff.Data());

auto event = qu->submit([&](::sycl::handler& cgh) {
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
const size_t i = pid.get_id(0);
const size_t ibegin = offset_vec[i];
const size_t iend = offset_vec[i + 1];
const size_t size = iend - ibegin;
const size_t start = i * row_stride;
for (bst_uint j = 0; j < size; ++j) {
uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]);
index_data[start + j] = offsets ? idx - offsets[j] : idx;
AtomicRef<size_t> hit_count_ref(hit_count_ptr[idx]);
hit_count_ref.fetch_add(1);
::sycl::event event;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
for (auto &batch : dmat->GetBatches<SparsePage>()) {
const xgboost::Entry *data_ptr = batch.data.ConstDevicePointer();
const bst_idx_t *offset_vec = batch.offset.ConstDevicePointer();
size_t batch_size = batch.Size();
if (batch_size > 0) {
const auto base_rowid = batch.base_rowid;
event = qu->submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::item<1> pid) {
const size_t i = pid.get_id(0);
const size_t ibegin = offset_vec[i];
const size_t iend = offset_vec[i + 1];
const size_t size = iend - ibegin;
const size_t start = (i + base_rowid) * row_stride;
for (bst_uint j = 0; j < size; ++j) {
uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]);
index_data[start + j] = isDense ? idx - cut_ptrs[j] : idx;
AtomicRef<size_t> hit_count_ref(hit_count_ptr[idx]);
hit_count_ref.fetch_add(1);
}
if constexpr (!isDense) {
// Sparse case only
mergeSort<BinIdxType>(index_data + start, index_data + start + size, sort_data + start);
for (bst_uint j = size; j < row_stride; ++j) {
index_data[start + j] = nbins;
}
}
});
});
}
if (!offsets) {
// Sparse case only
mergeSort<BinIdxType>(index_data + start, index_data + start + size, sort_data + start);
for (bst_uint j = size; j < row_stride; ++j) {
index_data[start + j] = nbins;
}
}
});
});
qu->memcpy(hit_count.data(), hit_count_ptr, nbins * sizeof(size_t), event);
}
}
qu->wait();
}

Expand All @@ -112,63 +114,60 @@ void GHistIndexMatrix::ResizeIndex(size_t n_index, bool isDense) {

void GHistIndexMatrix::Init(::sycl::queue* qu,
Context const * ctx,
const DeviceMatrix& p_fmat_device,
DMatrix *dmat,
int max_bins) {
nfeatures = p_fmat_device.p_mat->Info().num_col_;
nfeatures = dmat->Info().num_col_;

cut = xgboost::common::SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins);
cut_device.Init(qu, cut);
cut = xgboost::common::SketchOnDMatrix(ctx, dmat, max_bins);
cut.SetDevice(ctx->Device());

max_num_bins = max_bins;
const uint32_t nbins = cut.Ptrs().back();
this->nbins = nbins;
hit_count.resize(nbins, 0);
hit_count_buff.Resize(qu, nbins, 0);

this->p_fmat = p_fmat_device.p_mat;
const bool isDense = p_fmat_device.p_mat->IsDense();
hit_count.SetDevice(ctx->Device());
hit_count.Resize(nbins, 0);

this->p_fmat = dmat;
const bool isDense = dmat->IsDense();
this->isDense_ = isDense;

index.setQueue(qu);

row_stride = 0;
for (const auto& batch : p_fmat_device.p_mat->GetBatches<SparsePage>()) {
size_t n_rows = 0;
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
const auto& row_offset = batch.offset.ConstHostVector();
batch.data.SetDevice(ctx->Device());
batch.offset.SetDevice(ctx->Device());
n_rows += batch.Size();
for (auto i = 1ull; i < row_offset.size(); i++) {
row_stride = std::max(row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
}
}

const size_t n_offsets = cut_device.Ptrs().Size() - 1;
const size_t n_rows = p_fmat_device.row_ptr.Size() - 1;
const size_t n_offsets = cut.cut_ptrs_.Size() - 1;
const size_t n_index = n_rows * row_stride;
ResizeIndex(n_index, isDense);

CHECK_GT(cut_device.Values().Size(), 0U);

uint32_t* offsets = nullptr;
if (isDense) {
index.ResizeOffset(n_offsets);
offsets = index.Offset();
qu->memcpy(offsets, cut_device.Ptrs().DataConst(),
sizeof(uint32_t) * n_offsets).wait_and_throw();
}
CHECK_GT(cut.cut_values_.Size(), 0U);

if (isDense) {
BinTypeSize curent_bin_size = index.GetBinTypeSize();
if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) {
SetIndexData(qu, index.data<uint8_t>(), p_fmat_device, nbins, row_stride, offsets);
SetIndexData<uint8_t, true>(qu, index.data<uint8_t>(), dmat, nbins, row_stride);

} else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) {
SetIndexData(qu, index.data<uint16_t>(), p_fmat_device, nbins, row_stride, offsets);
SetIndexData<uint16_t, true>(qu, index.data<uint16_t>(), dmat, nbins, row_stride);
} else {
CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize);
SetIndexData(qu, index.data<uint32_t>(), p_fmat_device, nbins, row_stride, offsets);
SetIndexData<uint32_t, true>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
}
/* For sparse DMatrix we have to store index of feature for each bin
in index field to chose right offset. So offset is nullptr and index is not reduced */
} else {
SetIndexData(qu, index.data<uint32_t>(), p_fmat_device, nbins, row_stride, offsets);
sort_buff.Resize(qu, n_rows * row_stride * sizeof(uint32_t));
SetIndexData<uint32_t, false>(qu, index.data<uint32_t>(), dmat, nbins, row_stride);
}
}

Expand Down
Loading

0 comments on commit 1b06da1

Please sign in to comment.