Skip to content

Commit

Permalink
add petsc binary reader to csr
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Nov 15, 2024
1 parent dc8cfeb commit 50ae939
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
76 changes: 76 additions & 0 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

#include "ginkgo/core/matrix/csr.hpp"

#include <cstring>
#include <memory>

#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/executor.hpp>
Expand Down Expand Up @@ -100,6 +103,18 @@ GKO_REGISTER_OPERATION(aos_to_soa, components::aos_to_soa);
} // anonymous namespace
} // namespace csr

namespace {


// utilities for error checking
#define GKO_CHECK_STREAM(_stream, _message) \
if ((_stream).fail()) { \
throw GKO_STREAM_ERROR(_message); \
}


} // namespace


template <typename ValueType, typename IndexType>
std::unique_ptr<const Csr<ValueType, IndexType>>
Expand Down Expand Up @@ -531,6 +546,67 @@ void Csr<ValueType, IndexType>::read(const mat_data& data)
}


template <typename ValueType, typename IndexType>
template <typename InputValueType, typename InputIndexType>
void Csr<ValueType, IndexType>::read_petsc_binary(
std::istream& is, InputValueType dummy_valuetype,
InputIndexType dummy_indextype)
{
constexpr auto indextype_size = sizeof(InputIndexType);
constexpr auto valuetype_size = sizeof(InputValueType);
std::vector<char> header(4 * indextype_size);
GKO_CHECK_STREAM(is.read(header.data(), 16), "failed reading header");
InputIndexType matid{};
InputIndexType num_rows{};
InputIndexType num_cols{};
InputIndexType num_entries{};
std::memcpy(&matid, &header[0], indextype_size);
std::memcpy(&num_rows, &header[4], indextype_size);
std::memcpy(&num_cols, &header[8], indextype_size);
std::memcpy(&num_entries, &header[12], indextype_size);
auto size = gko::dim<2>(num_rows, num_cols);
auto exec = this->get_executor();
this->set_size(size);
auto row_ptrs = array<InputIndexType>{exec->get_master(), size[0] + 1};
// First fill with nnz_per_row, as PETSc binary format writes nnz_per_row
// instead of the prefix summed values
{
std::vector<char> block(num_rows * indextype_size);
GKO_CHECK_STREAM(is.read(block.data(), num_rows * indextype_size),
"failed reading nnz per row data");
std::memcpy(&row_ptrs.get_data()[0], &block[0],
num_rows * indextype_size);
exec->get_master()->run(csr::make_prefix_sum_nonnegative(
row_ptrs.get_data(), num_rows + 1));
GKO_ASSERT(num_entries == row_ptrs.get_data()[num_rows]);
}
auto col_idxs = array<InputIndexType>{exec->get_master(), num_entries};
{
std::vector<char> block(num_entries * indextype_size);
GKO_CHECK_STREAM(is.read(block.data(), num_entries * indextype_size),
"failed reading col_idxs data");
std::memcpy(&col_idxs.get_data()[0], &block[0],
num_entries * indextype_size);
}
auto values = array<InputValueType>{exec->get_master(), num_entries};
{
std::vector<char> block(num_entries * valuetype_size);
GKO_CHECK_STREAM(is.read(block.data(), num_entries * valuetype_size),
"failed reading values data");
std::memcpy(&values.get_data()[0], &block[0],
num_entries * valuetype_size);
}

this->row_ptrs_.resize_and_reset(size[0] + 1);
this->col_idxs_.resize_and_reset(num_entries);
this->values_.resize_and_reset(num_entries);
this->row_ptrs_ = row_ptrs;
this->col_idxs_ = col_idxs;
this->values_ = values;
this->make_srow();
}


template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::read(const device_mat_data& data)
{
Expand Down
4 changes: 4 additions & 0 deletions include/ginkgo/core/matrix/csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,10 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,

void read(device_mat_data&& data) override;

template <typename InputValueType, typename InputIndexType>
void read_petsc_binary(std::istream& is, InputValueType dummy_valuetype,
InputIndexType dummy_indextype);

void write(mat_data& data) const override;

std::unique_ptr<LinOp> transpose() const override;
Expand Down

0 comments on commit 50ae939

Please sign in to comment.