Skip to content

Commit

Permalink
Update hip basis code to conform to vector interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
zatkins-dev committed Jan 24, 2025
1 parent bddbe1e commit 1972990
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 59 deletions.
55 changes: 13 additions & 42 deletions backends/hip-ref/ceed-hip-ref-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,12 @@ static int CeedBasisApplyCore_Hip(CeedBasis basis, bool apply_add, const CeedInt
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes, num_qpts;
CeedSize length;
Ceed_Hip *hip_data;

CeedCallBackend(CeedGetData(ceed, &hip_data));

CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp));
if (hip_data->has_unified_addressing) memset(d_v, 0, length * sizeof(CeedScalar));
else CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
else {

This comment has been minimized.

Copy link
@jeremylt

jeremylt Jan 24, 2025

Member

minor style nit, I've been using

if () {

} else {

}

if either block needs braces per our conventions

// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}

CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
CeedCallBackend(CeedBasisGetDimension(basis, &dim));

Expand Down Expand Up @@ -208,20 +196,10 @@ static int CeedBasisApplyAtPointsCore_Hip(CeedBasis basis, bool apply_add, const
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;

CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
if (hip_data->has_unified_addressing) memset(d_v, 0, length * sizeof(CeedScalar));
else CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
else {
// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}

// Basis action
Expand Down Expand Up @@ -293,17 +271,10 @@ static int CeedBasisApplyNonTensorCore_Hip(CeedBasis basis, bool apply_add, cons
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedSize length;
Ceed_Hip *hip_data;

CeedCallBackend(CeedGetData(ceed, &hip_data));
CeedCallBackend(CeedVectorGetLength(v, &length));
if (hip_data->has_unified_addressing) memset(d_v, 0, length * sizeof(CeedScalar));
else CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
else {
// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}

// Apply basis operation
Expand Down
21 changes: 4 additions & 17 deletions backends/hip-shared/ceed-hip-shared-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -366,23 +366,10 @@ static int CeedBasisApplyAtPointsCore_Hip_shared(CeedBasis basis, bool apply_add
if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (is_transpose && !apply_add) {
CeedInt num_comp, q_comp, num_nodes;
CeedSize length;
Ceed_Hip *hip_data;

CeedCallBackend(CeedGetData(ceed, &hip_data));

CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
length =
(CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
if (hip_data->has_unified_addressing) memset(d_v, 0, length * sizeof(CeedScalar));
else CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
else {

This comment has been minimized.

Copy link
@jeremylt

jeremylt Jan 24, 2025

Member

Can you do this for CUDA too?

// Clear v for transpose operation
if (is_transpose) CeedCallBackend(CeedVectorSetValue(v, 0.0));
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
}

// Basis action
Expand Down

0 comments on commit 1972990

Please sign in to comment.