Skip to content

Commit

Permalink
updated based on Jeremy's comment
Browse files Browse the repository at this point in the history
  • Loading branch information
rezgarshakeri committed Sep 13, 2023
1 parent 5ba8f02 commit bb367a2
Showing 1 changed file with 36 additions and 33 deletions.
69 changes: 36 additions & 33 deletions interface/ceed-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -318,26 +318,21 @@ static int CeedBasisCreateProjectionMatrices(CeedBasis basis_from, CeedBasis bas
**/
int CeedBasisGetCollocatedGrad(CeedBasis basis, CeedScalar *collo_grad_1d) {
Ceed ceed;
CeedInt P_1d = (basis)->P_1d, Q_1d = (basis)->Q_1d;
CeedScalar *interp_1d, *interp_1d_pinv, *grad_1d, *tau;

CeedCall(CeedMalloc(Q_1d * P_1d, &interp_1d));
CeedCall(CeedMalloc(P_1d * Q_1d, &interp_1d_pinv));
CeedCall(CeedMalloc(Q_1d * P_1d, &grad_1d));
CeedCall(CeedMalloc(Q_1d, &tau));
memcpy(interp_1d, (basis)->interp_1d, Q_1d * P_1d * sizeof(basis)->interp_1d[0]);
memcpy(grad_1d, (basis)->grad_1d, Q_1d * P_1d * sizeof(basis)->interp_1d[0]);
CeedInt P_1d, Q_1d;
CeedScalar *interp_1d_pinv;

// QR Factorization, interp_1d = Q R
CeedCall(CeedBasisGetCeed(basis, &ceed));
CeedCall(CeedBasisGetNumNodes1D(basis, &P_1d));
CeedCall(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));

CeedCall(CeedMatrixPseudoinverse(ceed, interp_1d, Q_1d, P_1d, interp_1d_pinv));
CeedCall(CeedMatrixMatrixMultiply(ceed, (const CeedScalar *)interp_1d_pinv, (const CeedScalar *)grad_1d, collo_grad_1d, Q_1d, P_1d, P_1d));
// QR Factorization, interp_1d = Q R
CeedCall(CeedCalloc(P_1d * Q_1d, &interp_1d_pinv));

CeedCall(CeedMatrixPseudoinverse(ceed, basis->interp_1d, Q_1d, P_1d, interp_1d_pinv));
CeedCall(CeedMatrixMatrixMultiply(ceed, basis->grad_1d, (const CeedScalar *)interp_1d_pinv, collo_grad_1d, Q_1d, Q_1d, P_1d));

CeedCall(CeedFree(&interp_1d));
CeedCall(CeedFree(&interp_1d_pinv));
CeedCall(CeedFree(&grad_1d));
CeedCall(CeedFree(&tau));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -691,37 +686,45 @@ int CeedHouseholderApplyQ(CeedScalar *mat_A, const CeedScalar *mat_Q, const Ceed
/**
@brief Return pseudoinverse of a matrix
@param[in] ceed Ceed context for error handling
@param[in] mat Row-major matrix to be factorized in place
@param[in] m Number of rows
@param[in] n Number of columns
@param[in] ceed Ceed context for error handling
@param[in] mat Row-major matrix to be factorized in place
@param[in] m Number of rows
@param[in] n Number of columns
@param[out] mat_pinv Row-major pseudoinverse matrix
@return An error code: 0 - success, otherwise - failure
@ref Utility
**/
int CeedMatrixPseudoinverse(Ceed ceed, CeedScalar *mat, CeedInt m, CeedInt n, CeedScalar *mat_pinv) {
CeedScalar *tau, *I;
CeedScalar *tau, *I, *mat_copy;

CeedCall(CeedCalloc(m * m, &I));
CeedCall(CeedCalloc(m, &tau));
// -- QR Factorization, mat = Q R
CeedCall(CeedQRFactorization(ceed, mat, tau, m, n));
// -- mat_pinv = R_inv Q^T
for (CeedInt i = 0; i < m; i++) I[i * m + i] = 1.0;
// ---- Apply R_inv, mat_pinv = I R_inv
for (CeedInt i = 0; i < m; i++) { // Row i
mat_pinv[m * i] = I[n * i] / mat[0];
for (CeedInt j = 1; j < n; j++) { // Column j
mat_pinv[j + m * i] = I[j + n * i];
for (CeedInt k = 0; k < j; k++) mat_pinv[j + m * i] -= mat[j + n * k] * mat_pinv[k + m * i];
mat_pinv[j + m * i] /= mat[j + n * j];
CeedCall(CeedCalloc(n * n, &I));
CeedCall(CeedCalloc(m * n, &mat_copy));
memcpy(mat_copy, mat, m * n * sizeof mat[0]);

// QR Factorization, mat = Q R
CeedCall(CeedQRFactorization(ceed, mat_copy, tau, m, n));

// mat_pinv = R_inv Q^T
for (CeedInt i = 0; i < n; i++) I[i * n + i] = 1.0;
// -- Apply R_inv, mat_pinv = I R_inv
for (CeedInt i = 0; i < n; i++) { // Row i
mat_pinv[n * i] = I[n * i] / mat_copy[0];
for (CeedInt j = 1; j < m; j++) { // Column j
mat_pinv[j + n * i] = I[j + n * i];
for (CeedInt k = 0; k < j; k++) mat_pinv[j + n * i] -= mat_copy[j + m * k] * mat_pinv[k + n * i];
mat_pinv[j + n * i] /= mat_copy[j + n * j];
}
}
// ---- Apply Q^T, mat_pinv = R_inv Q^T
CeedCall(CeedHouseholderApplyQ(mat, mat_pinv, tau, CEED_NOTRANSPOSE, m, n, n, 1, n));
// -- Apply Q^T, mat_pinv = R_inv Q^T
CeedCall(CeedHouseholderApplyQ(mat_pinv, mat_copy, tau, CEED_NOTRANSPOSE, n, m, n, 1, n));

// Cleanup
CeedCall(CeedFree(&I));
CeedCall(CeedFree(&tau));
CeedCall(CeedFree(&mat_copy));
return CEED_ERROR_SUCCESS;
}

Expand Down

0 comments on commit bb367a2

Please sign in to comment.