diff --git a/backends/cuda-shared/ceed-cuda-shared-basis.c b/backends/cuda-shared/ceed-cuda-shared-basis.c
index 5991559cd3..137e9f718c 100644
--- a/backends/cuda-shared/ceed-cuda-shared-basis.c
+++ b/backends/cuda-shared/ceed-cuda-shared-basis.c
@@ -628,11 +628,21 @@ int CeedBasisCreateH1_Cuda_shared(CeedElemTopology topo, CeedInt dim, CeedInt nu
   CeedBasis_Cuda_shared *data;
 
   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
-  CeedCallBackend(CeedCalloc(1, &data));
 
-  // Check max sizes
-  CeedCheck(dim <= 3, ceed, CEED_ERROR_BACKEND, "Backend does not implement nontensor bases with dim > 3");
-  CeedCheck(num_nodes * num_qpts * dim < 52 * 52 * 3, ceed, CEED_ERROR_BACKEND, "Backend does not implement nontensor bases with P * Q this large");
+  // Check shared memory size
+  {
+    Ceed_Cuda *cuda_data;
+
+    CeedCallBackend(CeedGetData(ceed, &cuda_data));
+    if (((size_t)num_nodes * (size_t)num_qpts * (size_t)dim + (size_t)CeedIntMax(num_nodes, num_qpts)) * sizeof(CeedScalar) >
+        cuda_data->device_prop.sharedMemPerBlock) {
+      CeedCallBackend(CeedBasisCreateH1Fallback(ceed, topo, dim, num_nodes, num_qpts, interp, grad, q_ref, q_weight, basis));
+      CeedCallBackend(CeedDestroy(&ceed));
+      return CEED_ERROR_SUCCESS;
+    }
+  }
+
+  CeedCallBackend(CeedCalloc(1, &data));
 
   // Copy basis data to GPU
   CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
diff --git a/backends/hip-shared/ceed-hip-shared-basis.c b/backends/hip-shared/ceed-hip-shared-basis.c
index 144af79c2f..d65c065ec2 100644
--- a/backends/hip-shared/ceed-hip-shared-basis.c
+++ b/backends/hip-shared/ceed-hip-shared-basis.c
@@ -697,6 +697,20 @@ int CeedBasisCreateH1_Hip_shared(CeedElemTopology topo, CeedInt dim, CeedInt num
   CeedBasis_Hip_shared *data;
 
   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
+
+  // Check shared memory size
+  {
+    Ceed_Hip *hip_data;
+
+    CeedCallBackend(CeedGetData(ceed, &hip_data));
+    if (((size_t)num_nodes * (size_t)num_qpts * (size_t)dim + (size_t)CeedIntMax(num_nodes, num_qpts)) * sizeof(CeedScalar) >
+        hip_data->device_prop.sharedMemPerBlock) {
+      CeedCallBackend(CeedBasisCreateH1Fallback(ceed, topo, dim, num_nodes, num_qpts, interp, grad, q_ref, q_weight, basis));
+      CeedCallBackend(CeedDestroy(&ceed));
+      return CEED_ERROR_SUCCESS;
+    }
+  }
+
   CeedCallBackend(CeedCalloc(1, &data));
 
   // Copy basis data to GPU
diff --git a/include/ceed/backend.h b/include/ceed/backend.h
index 3884501f4b..7f686660ed 100644
--- a/include/ceed/backend.h
+++ b/include/ceed/backend.h
@@ -338,6 +338,9 @@ CEED_EXTERN int CeedBasisGetFESpace(CeedBasis basis, CeedFESpace *fe_space);
 CEED_EXTERN int CeedBasisGetTopologyDimension(CeedElemTopology topo, CeedInt *dim);
 CEED_EXTERN int CeedBasisGetTensorContract(CeedBasis basis, CeedTensorContract *contract);
 CEED_EXTERN int CeedBasisSetTensorContract(CeedBasis basis, CeedTensorContract contract);
+CEED_EXTERN int CeedBasisCreateH1Fallback(Ceed ceed, CeedElemTopology topo, CeedInt num_comp, CeedInt num_nodes, CeedInt num_qpts,
+                                          const CeedScalar *interp, const CeedScalar *grad, const CeedScalar *q_ref, const CeedScalar *q_weight,
+                                          CeedBasis basis);
 
 CEED_EXTERN int CeedTensorContractCreate(Ceed ceed, CeedTensorContract *contract);
 CEED_EXTERN int CeedTensorContractApply(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *__restrict__ t,
diff --git a/interface/ceed-basis.c b/interface/ceed-basis.c
index 4a4f5fb180..3d5d51107f 100644
--- a/interface/ceed-basis.c
+++ b/interface/ceed-basis.c
@@ -600,6 +600,42 @@ static int CeedBasisApplyAtPoints_Core(CeedBasis basis, bool apply_add, CeedInt
 /// @addtogroup CeedBasisBackend
 /// @{
 
+/**
+  @brief Fallback to a reference implementation for a non tensor-product basis for \f$H^1\f$ discretizations.
+         This function may only be called inside of a backend `BasisCreateH1` function.
+         This is used by a backend when the specific parameters for a `CeedBasis` exceed the backend's support, such as
+         when the `interp` and `grad` matrices require too many bytes to fit into shared memory on a GPU.
+
+  @param[in]  ceed      `Ceed` object used to create the `CeedBasis`
+  @param[in]  topo      Topology of element, e.g. hypercube, simplex, etc
+  @param[in]  num_comp  Not read by this function; the dimension handed to the delegate is derived from `topo`
+  @param[in]  num_nodes Total number of nodes
+  @param[in]  num_qpts  Total number of quadrature points
+  @param[in]  interp    Row-major (`num_qpts * num_nodes`) matrix expressing the values of nodal basis functions at quadrature points
+  @param[in]  grad      Row-major (`dim * num_qpts * num_nodes`) matrix expressing derivatives of nodal basis functions at quadrature points
+  @param[in]  q_ref     Array of length `num_qpts * dim` holding the locations of quadrature points on the reference element
+  @param[in]  q_weight  Array of length `num_qpts` holding the quadrature weights on the reference element
+  @param[in,out] basis  `CeedBasis` supplied by the calling backend; its `ceed` member is re-pointed at the delegate `Ceed`
+
+  @return An error code: 0 - success, otherwise - failure (`CEED_ERROR_UNSUPPORTED` if `ceed` has no "Basis" delegate)
+
+  @ref Backend
+**/
+int CeedBasisCreateH1Fallback(Ceed ceed, CeedElemTopology topo, CeedInt num_comp, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
+                              const CeedScalar *grad, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
+  CeedInt P = num_nodes, Q = num_qpts, dim = 0;
+  Ceed    delegate;
+
+  CeedCall(CeedGetObjectDelegate(ceed, &delegate, "Basis"));
+  CeedCheck(delegate, ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement BasisCreateH1");
+  // Re-point the basis at the delegate, then call the delegate's BasisCreateH1 on it
+  CeedCall(CeedReferenceCopy(delegate, &(basis)->ceed));
+  CeedCall(CeedBasisGetTopologyDimension(topo, &dim));
+  CeedCall(delegate->BasisCreateH1(topo, dim, P, Q, interp, grad, q_ref, q_weight, basis));
+  CeedCall(CeedDestroy(&delegate));
+  return CEED_ERROR_SUCCESS;
+}
+
 /**
   @brief Return collocated gradient matrix
 
@@ -1493,7 +1529,7 @@ int CeedBasisCreateTensorH1Lagrange(Ceed ceed, CeedInt 
dim, CeedInt num_comp, Ce @param[in] num_qpts Total number of quadrature points @param[in] interp Row-major (`num_qpts * num_nodes`) matrix expressing the values of nodal basis functions at quadrature points @param[in] grad Row-major (`dim * num_qpts * num_nodes`) matrix expressing derivatives of nodal basis functions at quadrature points - @param[in] q_ref Array of length `num_qpts` * dim holding the locations of quadrature points on the reference element + @param[in] q_ref Array of length `num_qpts * dim` holding the locations of quadrature points on the reference element @param[in] q_weight Array of length `num_qpts` holding the quadrature weights on the reference element @param[out] basis Address of the variable where the newly created `CeedBasis` will be stored diff --git a/tests/t319-basis.c b/tests/t319-basis.c index e34e296fca..c314cb2e82 100644 --- a/tests/t319-basis.c +++ b/tests/t319-basis.c @@ -116,7 +116,7 @@ int main(int argc, char **argv) { for (CeedInt dim = 1; dim <= 3; dim++) { CeedVector x_corners, x_from, x_to, u_from, u_to, du_to; CeedBasis basis_x, basis_from, basis_to, basis_project; - CeedInt p_from = 3, p_to = 4, q = 4, x_dim = CeedIntPow(2, dim), p_from_dim = CeedIntPow(p_from, dim), p_to_dim = CeedIntPow(p_to, dim); + CeedInt p_from = 4, p_to = 5, q = 6, x_dim = CeedIntPow(2, dim), p_from_dim = CeedIntPow(p_from, dim), p_to_dim = CeedIntPow(p_to, dim); CeedVectorCreate(ceed, x_dim * dim, &x_corners); {