Skip to content

Commit

Permalink
Make dataset serialization optional
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Dec 5, 2023
1 parent ce12715 commit 0287c94
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions cpp/include/raft/neighbors/brute_force_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ auto static constexpr serialization_version = 0;
*
*/
template <typename T>
void serialize(raft::resources const& handle, std::ostream& os, const index<T>& index)
void serialize(raft::resources const& handle,
std::ostream& os,
const index<T>& index,
bool include_dataset = true)
{
RAFT_LOG_DEBUG(
"Saving brute force index, size %zu, dim %u", static_cast<size_t>(index.size()), index.dim());
Expand All @@ -68,7 +71,8 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T>&
serialize_scalar(handle, os, index.dim());
serialize_scalar(handle, os, index.metric());
serialize_scalar(handle, os, index.metric_arg());
serialize_mdspan(handle, os, index.dataset());
serialize_scalar(handle, os, include_dataset);
if (include_dataset) { serialize_mdspan(handle, os, index.dataset()); }
auto has_norms = index.has_norms();
serialize_scalar(handle, os, has_norms);
if (has_norms) { serialize_mdspan(handle, os, index.norms()); }
Expand Down Expand Up @@ -100,11 +104,14 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T>&
*
*/
template <typename T>
void serialize(raft::resources const& handle, const std::string& filename, const index<T>& index)
void serialize(raft::resources const& handle,
const std::string& filename,
const index<T>& index,
bool include_dataset = true)
{
auto os = std::ofstream{filename, std::ios::out | std::ios::binary};
RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str());
serialize(handle, os, index);
serialize(handle, os, index, include_dataset);
}

/**
Expand Down Expand Up @@ -147,7 +154,8 @@ auto deserialize(raft::resources const& handle, std::istream& is)
auto metric_arg = deserialize_scalar<T>(handle, is);

auto dataset_storage = raft::make_host_matrix<T>(rows, dim);
deserialize_mdspan(handle, is, dataset_storage.view());
auto include_dataset = deserialize_scalar<bool>(handle, is);
if (include_dataset) { deserialize_mdspan(handle, is, dataset_storage.view()); }

auto has_norms = deserialize_scalar<bool>(handle, is);
auto norms_storage = has_norms ? std::optional{raft::make_host_vector<T, std::int64_t>(rows)}
Expand Down

0 comments on commit 0287c94

Please sign in to comment.