From 9a06b9e6f6f6203c5ba1e22f1bad65f5cd506535 Mon Sep 17 00:00:00 2001 From: xushuangbin Date: Tue, 31 Dec 2024 18:12:43 +0800 Subject: [PATCH] speed up runGLOBALBV --- R/methods-globalbv.R | 4 ++-- src/bisp.cpp | 47 +++++++++++++++++++++++++------------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/R/methods-globalbv.R b/R/methods-globalbv.R index c3cac20..46c67f2 100644 --- a/R/methods-globalbv.R +++ b/R/methods-globalbv.R @@ -251,7 +251,7 @@ setMethod("runGLOBALBV", "SVPExperiment", function( assay.type <- 1 } - x <- assay(data, assay.type) + x <- NULL if (is.null(features1) && is.null(features2) && is.null(gsvaexp.features)){ cli::cli_abort(c("The {.var gsvaexp} is specified, and the `data` is {.cls {class(data)}}.", "The `features1`, `features2` and `gsvaexp.features` should not be `NULL` simultaneously.")) @@ -270,9 +270,9 @@ setMethod("runGLOBALBV", "SVPExperiment", function( features2 <- gsvaexp.features if (length(features1) >= 1){ + x <- assay(data, assay.type) x <- x[features1, , drop=FALSE] } - x2 <- .extract_gsvaExp_assay(data, gsvaexp, gsvaexp.assay.type) x2 <- x2[features2, , drop=FALSE] x <- rbind(x, x2) diff --git a/src/bisp.cpp b/src/bisp.cpp index 5b822a0..363dfdd 100644 --- a/src/bisp.cpp +++ b/src/bisp.cpp @@ -29,27 +29,22 @@ arma::vec cal_global_lee_test( return(res); } -arma::vec cal_Pquant(sp_mat w, int n){ - double tt = 0.0; - arma::vec MII(n); - arma::vec diagM(n); - arma::vec diagMt(n); - arma::vec wv(n); - for (int i = 0; i < n; i++){ - for (int j = 0; j < n; j++){ - wv(j) = accu(w.col(i) % w.col(j)); +struct CalPquant : public Worker{ + const arma::sp_mat& w; + arma::mat& res; + + CalPquant(const arma::sp_mat& w, arma::mat& res): + w(w), res(res){} + + void operator()(std::size_t begin, std::size_t end){ + for (std::size_t i = begin; i < end; i++){ + arma::vec wv = w.t() * w.col(i).as_dense(); + res(i, 0) = accu(wv); + res(i, 1) = wv(i); + res(i, 2) = accu(pow(wv, 2.0)); } - tt += accu(wv); - MII(i) = accu(wv); - diagM(i) = wv(i); - diagMt(i) = accu(pow(wv, 2.0)); } - diagM = diagM/tt; - MII = MII/tt; - diagMt = diagMt/pow(tt, 2.0); - arma::vec res = cal_quant(diagM, diagMt, MII); - return(res); -} +}; struct RunGlobalLee : public Worker{ const arma::sp_mat& x; @@ -99,6 +94,18 @@ struct RunGlobalLee : public Worker{ } }; +arma::vec cal_Pquant_Parallel(sp_mat w, int n){ + arma::mat res(n, 3); + CalPquant runpquant(w, res); + parallelFor(0, n, runpquant); + double tt = accu(res.col(0)); + arma::vec diagM = res.col(1) / tt; + arma::vec MII = res.col(0) / tt; + arma::vec diagMt = res.col(2) / pow(tt, 2.0); + arma::vec Pq = cal_quant(diagM, diagMt, MII); + return(Pq); +} + //[[Rcpp::export]] Rcpp::List CalGlobalLeeParallel( arma::sp_mat& x, @@ -123,7 +130,7 @@ Rcpp::List CalGlobalLeeParallel( arma::vec P(6); if (permutation <= 10){ - P = cal_Pquant(wm, m); + P = cal_Pquant_Parallel(wm, m); } simple_progress p(n1 * n2);