Skip to content

Commit

Permalink
gpu - fallback if nontensor shared uses too much mem
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Jan 14, 2025
1 parent d01feaa commit bfe4ad2
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 6 deletions.
17 changes: 13 additions & 4 deletions backends/cuda-shared/ceed-cuda-shared-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -628,11 +628,20 @@ int CeedBasisCreateH1_Cuda_shared(CeedElemTopology topo, CeedInt dim, CeedInt nu
CeedBasis_Cuda_shared *data;

CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
CeedCallBackend(CeedCalloc(1, &data));

// Check max sizes
CeedCheck(dim <= 3, ceed, CEED_ERROR_BACKEND, "Backend does not implement nontensor bases with dim > 3");
CeedCheck(num_nodes * num_qpts * dim < 52 * 52 * 3, ceed, CEED_ERROR_BACKEND, "Backend does not implement nontensor bases with P * Q this large");
// Check shared memory size
{
Ceed_Cuda *cuda_data;

CeedCallBackend(CeedGetData(ceed, &cuda_data));
if (((size_t)num_nodes * (size_t)num_qpts * (size_t)dim + (size_t)CeedIntMax(num_nodes, num_qpts)) * sizeof(CeedScalar) >
cuda_data->device_prop.sharedMemPerBlock) {
CeedCallBackend(CeedBasisCreateH1Fallback(ceed, topo, dim, num_nodes, num_qpts, interp, grad, q_ref, q_weight, basis));
return CEED_ERROR_SUCCESS;
}
}

CeedCallBackend(CeedCalloc(1, &data));

// Copy basis data to GPU
CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
Expand Down
13 changes: 13 additions & 0 deletions backends/hip-shared/ceed-hip-shared-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,19 @@ int CeedBasisCreateH1_Hip_shared(CeedElemTopology topo, CeedInt dim, CeedInt num
CeedBasis_Hip_shared *data;

CeedCallBackend(CeedBasisGetCeed(basis, &ceed));

// Check shared memory size
{
Ceed_Hip *hip_data;

CeedCallBackend(CeedGetData(ceed, &hip_data));
if (((size_t)num_nodes * (size_t)num_qpts * (size_t)dim + (size_t)CeedIntMax(num_nodes, num_qpts)) * sizeof(CeedScalar) >
hip_data->device_prop.sharedMemPerBlock) {
CeedCallBackend(CeedBasisCreateH1Fallback(ceed, topo, dim, num_nodes, num_qpts, interp, grad, q_ref, q_weight, basis));
return CEED_ERROR_SUCCESS;
}
}

CeedCallBackend(CeedCalloc(1, &data));

// Copy basis data to GPU
Expand Down
3 changes: 3 additions & 0 deletions include/ceed/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ CEED_EXTERN int CeedBasisGetFESpace(CeedBasis basis, CeedFESpace *fe_space);
CEED_EXTERN int CeedBasisGetTopologyDimension(CeedElemTopology topo, CeedInt *dim);
CEED_EXTERN int CeedBasisGetTensorContract(CeedBasis basis, CeedTensorContract *contract);
CEED_EXTERN int CeedBasisSetTensorContract(CeedBasis basis, CeedTensorContract contract);
CEED_EXTERN int CeedBasisCreateH1Fallback(Ceed ceed, CeedElemTopology topo, CeedInt num_comp, CeedInt num_nodes, CeedInt nqpts,
const CeedScalar *interp, const CeedScalar *grad, const CeedScalar *q_ref, const CeedScalar *q_weights,
CeedBasis basis);

CEED_EXTERN int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract);
CEED_EXTERN int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *__restrict__ t,
Expand Down
38 changes: 37 additions & 1 deletion interface/ceed-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,42 @@ static int CeedBasisApplyAtPoints_Core(CeedBasis basis, bool apply_add, CeedInt
/// @addtogroup CeedBasisBackend
/// @{

/**
@brief Fallback to a reference implementation for a non tensor-product basis for \f$H^1\f$ discretizations.
This function may only be called inside of a backend `BasisCreateH1` function.
This is used by a backend when the specific parameters for a `CeedBasis` exceed the backend's support, such as
when a `interp` and `grad` matrices require too many bytes to fit into shared memory on a GPU.
@param[in] ceed `Ceed` object used to create the `CeedBasis`
@param[in] topo Topology of element, e.g. hypercube, simplex, etc
@param[in] num_comp Number of field components (1 for scalar fields)
@param[in] num_nodes Total number of nodes
@param[in] num_qpts Total number of quadrature points
@param[in] interp Row-major (`num_qpts * num_nodes`) matrix expressing the values of nodal basis functions at quadrature points
@param[in] grad Row-major (`dim * num_qpts * num_nodes`) matrix expressing derivatives of nodal basis functions at quadrature points
@param[in] q_ref Array of length `num_qpts * dim` holding the locations of quadrature points on the reference element
@param[in] q_weight Array of length `num_qpts` holding the quadrature weights on the reference element
@param[out] basis Newly created `CeedBasis`
@return An error code: 0 - success, otherwise - failure
@ref User
**/
int CeedBasisCreateH1Fallback(Ceed ceed, CeedElemTopology topo, CeedInt num_comp, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
const CeedScalar *grad, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
CeedInt P = num_nodes, Q = num_qpts, dim = 0;
Ceed delegate;

CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Basis"));
CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement BasisCreateH1");

CeedCall(CeedReferenceCopy(delegate, &(basis)->ceed));
CeedCall(CeedBasisGetTopologyDimension(topo, &dim));
CeedCall(delegate->BasisCreateH1(topo, dim, P, Q, interp, grad, q_ref, q_weight, basis));
CeedCall(CeedDestroy(&delegate));
return CEED_ERROR_SUCCESS;
}

/**
@brief Return collocated gradient matrix
Expand Down Expand Up @@ -1493,7 +1529,7 @@ int CeedBasisCreateTensorH1Lagrange(Ceed ceed, CeedInt dim, CeedInt num_comp, Ce
@param[in] num_qpts Total number of quadrature points
@param[in] interp Row-major (`num_qpts * num_nodes`) matrix expressing the values of nodal basis functions at quadrature points
@param[in] grad Row-major (`dim * num_qpts * num_nodes`) matrix expressing derivatives of nodal basis functions at quadrature points
@param[in] q_ref Array of length `num_qpts` * dim holding the locations of quadrature points on the reference element
@param[in] q_ref Array of length `num_qpts * dim` holding the locations of quadrature points on the reference element
@param[in] q_weight Array of length `num_qpts` holding the quadrature weights on the reference element
@param[out] basis Address of the variable where the newly created `CeedBasis` will be stored
Expand Down
2 changes: 1 addition & 1 deletion tests/t319-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ int main(int argc, char **argv) {
for (CeedInt dim = 1; dim <= 3; dim++) {
CeedVector x_corners, x_from, x_to, u_from, u_to, du_to;
CeedBasis basis_x, basis_from, basis_to, basis_project;
CeedInt p_from = 3, p_to = 4, q = 4, x_dim = CeedIntPow(2, dim), p_from_dim = CeedIntPow(p_from, dim), p_to_dim = CeedIntPow(p_to, dim);
CeedInt p_from = 4, p_to = 5, q = 6, x_dim = CeedIntPow(2, dim), p_from_dim = CeedIntPow(p_from, dim), p_to_dim = CeedIntPow(p_to, dim);

CeedVectorCreate(ceed, x_dim * dim, &x_corners);
{
Expand Down

0 comments on commit bfe4ad2

Please sign in to comment.