Skip to content

Commit

Permalink
Merge pull request #1385 from CEED/sjg/magma-transpose-opt
Browse files Browse the repository at this point in the history
Improve transpose basis performance for `magma` backend
  • Loading branch information
sebastiangrimberg authored Oct 27, 2023
2 parents 0263b5c + 3e5ab5d commit a0804ae
Show file tree
Hide file tree
Showing 12 changed files with 210 additions and 245 deletions.
26 changes: 0 additions & 26 deletions backends/magma/ceed-magma-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,6 @@ static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTranspose
else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (t_mode == CEED_TRANSPOSE) {
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue);
} else {
magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue);
}
ceed_magma_queue_sync(data->queue);
}

// Apply basis operation
switch (e_mode) {
case CEED_EVAL_INTERP: {
Expand Down Expand Up @@ -289,19 +276,6 @@ static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, Ceed
else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

// Clear v for transpose operation
if (t_mode == CEED_TRANSPOSE) {
CeedSize length;

CeedCallBackend(CeedVectorGetLength(v, &length));
if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue);
} else {
magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue);
}
ceed_magma_queue_sync(data->queue);
}

// Apply basis operation
if (e_mode != CEED_EVAL_WEIGHT) {
const CeedScalar *d_b = NULL;
Expand Down
36 changes: 15 additions & 21 deletions include/ceed/jit-source/magma/magma-basis-grad-1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
// macros to abstract access of shared memory and reg. file
#define sT(i, j) sT[(j)*P + (i)]

//////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// grad basis action (1D)
template <typename T, int DIM, int NUM_COMP, int P, int Q>
static __device__ __inline__ void magma_grad_1d_device(const T *sT, magma_trans_t transT, T *sU[NUM_COMP], T *sV[NUM_COMP], const int tx) {
static __device__ __inline__ void magma_grad_1d_device(const T *sT, T *sU[NUM_COMP], T *sV[NUM_COMP], const int tx) {
// Assumptions
// 1. 1D threads of size max(P,Q)
// 2. sU[i] is 1xP: in shared memory
Expand All @@ -28,10 +28,9 @@ static __device__ __inline__ void magma_grad_1d_device(const T *sT, magma_trans_
// 6. Must sync before and after call
// 7. Note that the layout for U and V is different from 2D/3D problem

T rv;
if (tx < Q) {
for (int comp = 0; comp < NUM_COMP; comp++) {
rv = (transT == MagmaTrans) ? sV[comp][tx] : 0.0;
T rv = 0.0;
for (int i = 0; i < P; i++) {
rv += sU[comp][i] * sT(i, tx);
}
Expand All @@ -40,16 +39,15 @@ static __device__ __inline__ void magma_grad_1d_device(const T *sT, magma_trans_
}
}

//////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_1D)) __global__
void magma_gradn_1d_kernel(const CeedScalar *dTinterp, const CeedScalar *dTgrad, const CeedScalar *dU, const int estrdU, const int cstrdU,
const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;
magma_trans_t transT = MagmaNoTrans;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;

if (elem_id >= nelem) return;

Expand All @@ -72,30 +70,29 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_

// read T
if (ty == 0) {
dread_T_gm2sm<BASIS_P, BASIS_Q>(tx, transT, dTgrad, sT);
read_T_notrans_gm2sm<BASIS_P, BASIS_Q>(tx, dTgrad, sT);
}

// read U
read_1d<CeedScalar, BASIS_P, BASIS_NUM_COMP>(dU, cstrdU, sU, tx);

__syncthreads();
magma_grad_1d_device<CeedScalar, BASIS_DIM, BASIS_NUM_COMP, BASIS_P, BASIS_Q>(sT, transT, sU, sV, tx);
magma_grad_1d_device<CeedScalar, BASIS_DIM, BASIS_NUM_COMP, BASIS_P, BASIS_Q>(sT, sU, sV, tx);
__syncthreads();

// write V
write_1d<CeedScalar, BASIS_Q, BASIS_NUM_COMP>(sV, dV, cstrdV, tx);
}

//////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_1D)) __global__
void magma_gradt_1d_kernel(const CeedScalar *dTinterp, const CeedScalar *dTgrad, const CeedScalar *dU, const int estrdU, const int cstrdU,
const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;
magma_trans_t transT = MagmaTrans;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;

if (elem_id >= nelem) return;

Expand All @@ -118,17 +115,14 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_

// read T
if (ty == 0) {
dread_T_gm2sm<BASIS_Q, BASIS_P>(tx, transT, dTgrad, sT);
read_T_trans_gm2sm<BASIS_Q, BASIS_P>(tx, dTgrad, sT);
}

// read U
read_1d<CeedScalar, BASIS_Q, BASIS_NUM_COMP>(dU, cstrdU, sU, tx);

// read V
read_1d<CeedScalar, BASIS_P, BASIS_NUM_COMP>(dV, cstrdV, sV, tx);

__syncthreads();
magma_grad_1d_device<CeedScalar, BASIS_DIM, BASIS_NUM_COMP, BASIS_Q, BASIS_P>(sT, transT, sU, sV, tx);
magma_grad_1d_device<CeedScalar, BASIS_DIM, BASIS_NUM_COMP, BASIS_Q, BASIS_P>(sT, sU, sV, tx);
__syncthreads();

// write V
Expand Down
81 changes: 43 additions & 38 deletions include/ceed/jit-source/magma/magma-basis-grad-2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,32 @@
#define sT(i, j) sT[(j)*P + (i)]
#define sTmp(i, j, ldw) sTmp[(j) * (ldw) + (i)]

//////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Helper function to add or set into V
template <typename T, bool Add>
struct magma_grad_2d_device_accumulate;

template <typename T>
struct magma_grad_2d_device_accumulate<T, true> {
static __device__ __inline__ void op(T &rV, const T &rTmp) { rV += rTmp; }
};

template <typename T>
struct magma_grad_2d_device_accumulate<T, false> {
static __device__ __inline__ void op(T &rV, const T &rTmp) { rV = rTmp; }
};

////////////////////////////////////////////////////////////////////////////////
// grad basis action (2D)
// This function is called two times at a higher level for 2D
// DIM_U -- for the size of rU[DIM_U * NUM_COMP * MAX_P_Q]
// DIM_V -- for the size of rV[DIM_V * NUM_COMP * MAX_P_Q]
// i_DIM -- the index of the outermost loop over dimensions in grad
// i_DIM_U -- which dim index of rU is accessed (always 0 for notrans, 0 or 1 for trans)
// i_DIM_V -- which dim index of rV is accessed (0 or 1 for notrans, always 0 for trans)
// the scalar beta is used to specify whether to accumulate to rV, or overwrite it
template <typename T, int DIM_U, int DIM_V, int NUM_COMP, int P, int Q, int rU_SIZE, int rV_SIZE, int i_DIM, int i_DIM_U, int i_DIM_V>
template <typename T, int DIM_U, int DIM_V, int NUM_COMP, int P, int Q, int rU_SIZE, int rV_SIZE, int i_DIM, int i_DIM_U, int i_DIM_V, bool ADD>
static __device__ __inline__ void magma_grad_2d_device(const T *sTinterp, const T *sTgrad, T rU[DIM_U][NUM_COMP][rU_SIZE],
T rV[DIM_V][NUM_COMP][rV_SIZE], T beta, const int tx, T rTmp, T *swork) {
T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx, T rTmp, T *swork) {
// Assumptions
// 0. This device routine applies grad for one dim only (i_DIM), so it should be called twice for 2D
// 1. 1D threads of size max(P,Q)
Expand Down Expand Up @@ -68,24 +82,22 @@ static __device__ __inline__ void magma_grad_2d_device(const T *sTinterp, const
for (int i = 0; i < P; i++) {
rTmp += sTmp(tx, i, sld) * sT(i, j);
}
rV[i_DIM_V][comp][j] *= beta;
rV[i_DIM_V][comp][j] += rTmp;
magma_grad_2d_device_accumulate<T, ADD>::op(rV[i_DIM_V][comp][j], rTmp);
}
}
__syncthreads();
} // loop over NUM_COMP
}

//////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_2D)) __global__
void magma_gradn_2d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;
magma_trans_t transT = MagmaNoTrans;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;

if (elem_id >= nelem) return;

Expand All @@ -105,40 +117,38 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_

// read T
if (ty == 0) {
dread_T_gm2sm<BASIS_P, BASIS_Q>(tx, transT, dinterp1d, sTinterp);
dread_T_gm2sm<BASIS_P, BASIS_Q>(tx, transT, dgrad1d, sTgrad);
read_T_notrans_gm2sm<BASIS_P, BASIS_Q>(tx, dinterp1d, sTinterp);
read_T_notrans_gm2sm<BASIS_P, BASIS_Q>(tx, dgrad1d, sTgrad);
}

// No need to read V ( required only in transposed grad )
const CeedScalar beta = 0.0;

/* read U (idim = 0 for dU, i_DIM = 0 for rU) --
there is a sync at the end of this function */
readU_2d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);
read_U_2d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);

/* first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0) --
output from rV[0][][] into dV (idim = 0) */
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_P, BASIS_Q, BASIS_P, BASIS_Q, 0, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_P, BASIS_Q, BASIS_P, BASIS_Q, 0, 0, 0, false>(sTinterp, sTgrad, rU, rV, tx, rTmp,
sTmp);
/* there is a sync at the end of magma_grad_2d_device */
writeV_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);
write_V_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);

/* second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0) --
output from rV[0][][] into dV (idim = 1) */
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_P, BASIS_Q, BASIS_P, BASIS_Q, 1, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_P, BASIS_Q, BASIS_P, BASIS_Q, 1, 0, 0, false>(sTinterp, sTgrad, rU, rV, tx, rTmp,
sTmp);
/* there is a sync at the end of magma_grad_2d_device */
writeV_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dV + (1 * dstrdV), cstrdV, rV, tx);
write_V_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dV + (1 * dstrdV), cstrdV, rV, tx);
}

//////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_2D)) __global__
void magma_gradt_2d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;
magma_trans_t transT = MagmaTrans;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int elem_id = (blockIdx.x * blockDim.y) + ty;

if (elem_id >= nelem) return;

Expand All @@ -158,32 +168,27 @@ extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_

// read T
if (ty == 0) {
dread_T_gm2sm<BASIS_Q, BASIS_P>(tx, transT, dinterp1d, sTinterp);
dread_T_gm2sm<BASIS_Q, BASIS_P>(tx, transT, dgrad1d, sTgrad);
read_T_trans_gm2sm<BASIS_Q, BASIS_P>(tx, dinterp1d, sTinterp);
read_T_trans_gm2sm<BASIS_Q, BASIS_P>(tx, dgrad1d, sTgrad);
}
__syncthreads();

/* read V (since this is transposed mode --
idim = 0 for dV, i_DIM = 0 for rV) */
const CeedScalar beta = 1.0;
readV_2d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);

/* read U (idim = 0 for dU, i_DIM = 0 for rU) --
there is a sync at the end of this function */
readU_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);
read_U_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);
/* first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0) */
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_Q, BASIS_P, BASIS_Q, BASIS_P, 0, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_Q, BASIS_P, BASIS_Q, BASIS_P, 0, 0, 0, true>(sTinterp, sTgrad, rU, rV, tx, rTmp, sTmp);
/* there is a sync at the end of magma_grad_2d_device */

/* read U (idim = 1 for dU, i_DIM = 0 for rU) --
there is a sync at the end of this function */
readU_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dU + (1 * dstrdU), cstrdU, rU, sTmp, tx);
read_U_2d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dU + (1 * dstrdU), cstrdU, rU, sTmp, tx);
/* second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0) */
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_Q, BASIS_P, BASIS_Q, BASIS_P, 1, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
magma_grad_2d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_Q, BASIS_P, BASIS_Q, BASIS_P, 1, 0, 0, true>(sTinterp, sTgrad, rU, rV, tx, rTmp, sTmp);
/* there is a sync at the end of magma_grad_2d_device */

// write V
writeV_2d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);
write_V_2d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);
}

#endif // CEED_MAGMA_BASIS_GRAD_2D_H
Loading

0 comments on commit a0804ae

Please sign in to comment.