Skip to content

Commit

Permalink
speed up runGLOBALBV
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangpin committed Dec 31, 2024
1 parent 93129d2 commit 9a06b9e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
4 changes: 2 additions & 2 deletions R/methods-globalbv.R
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ setMethod("runGLOBALBV", "SVPExperiment", function(
assay.type <- 1
}

x <- assay(data, assay.type)
x <- NULL
if (is.null(features1) && is.null(features2) && is.null(gsvaexp.features)){
cli::cli_abort(c("The {.var gsvaexp} is specified, and the `data` is {.cls {class(data)}}.",
"The `features1`, `features2` and `gsvaexp.features` should not be `NULL` simultaneously."))
Expand All @@ -270,9 +270,9 @@ setMethod("runGLOBALBV", "SVPExperiment", function(
features2 <- gsvaexp.features

if (length(features1) >= 1){
x <- assay(data, assay.type)
x <- x[features1, , drop=FALSE]
}

x2 <- .extract_gsvaExp_assay(data, gsvaexp, gsvaexp.assay.type)
x2 <- x2[features2, , drop=FALSE]
x <- rbind(x, x2)
Expand Down
47 changes: 27 additions & 20 deletions src/bisp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,22 @@ arma::vec cal_global_lee_test(
return(res);
}

arma::vec cal_Pquant(sp_mat w, int n){
double tt = 0.0;
arma::vec MII(n);
arma::vec diagM(n);
arma::vec diagMt(n);
arma::vec wv(n);
for (int i = 0; i < n; i++){
for (int j = 0; j < n; j++){
wv(j) = accu(w.col(i) % w.col(j));
struct CalPquant : public Worker{
const arma::sp_mat& w;
arma::mat& res;

CalPquant(const arma::sp_mat& w, arma::mat& res):
w(w), res(res){}

void operator()(std::size_t begin, std::size_t end){
for (std::size_t i = begin; i < end; i++){
arma::vec wv = w.t() * w.col(i).as_dense();
res(i, 0) = accu(wv);
res(i, 1) = wv(i);
res(i, 2) = accu(pow(wv, 2.0));
}
tt += accu(wv);
MII(i) = accu(wv);
diagM(i) = wv(i);
diagMt(i) = accu(pow(wv, 2.0));
}
diagM = diagM/tt;
MII = MII/tt;
diagMt = diagMt/pow(tt, 2.0);
arma::vec res = cal_quant(diagM, diagMt, MII);
return(res);
}
};

struct RunGlobalLee : public Worker{
const arma::sp_mat& x;
Expand Down Expand Up @@ -99,6 +94,18 @@ struct RunGlobalLee : public Worker{
}
};

arma::vec cal_Pquant_Parallel(sp_mat w, int n){
arma::mat res(n, 3);
CalPquant runpquant(w, res);
parallelFor(0, n, runpquant);
double tt = accu(res.col(0));
arma::vec diagM = res.col(1) / tt;
arma::vec MII = res.col(0) / tt;
arma::vec diagMt = res.col(2) / pow(tt, 2.0);
arma::vec Pq = cal_quant(diagM, diagMt, MII);
return(Pq);
}

//[[Rcpp::export]]
Rcpp::List CalGlobalLeeParallel(
arma::sp_mat& x,
Expand All @@ -123,7 +130,7 @@ Rcpp::List CalGlobalLeeParallel(

arma::vec P(6);
if (permutation <= 10){
P = cal_Pquant(wm, m);
P = cal_Pquant_Parallel(wm, m);
}

simple_progress p(n1 * n2);
Expand Down

0 comments on commit 9a06b9e

Please sign in to comment.