Sample only for feature extraction #1130

Merged: 4 commits, Aug 27, 2024
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -38,7 +38,7 @@ Imports:
SqlRender (>= 1.9.0),
stringr,
tidyr (>= 1.2.0),
CohortGenerator (>= 0.8.0),
CohortGenerator (>= 0.10.0),
remotes,
scales
Suggests:
@@ -61,7 +61,7 @@ License: Apache License
VignetteBuilder: knitr
URL: https://ohdsi.github.io/CohortDiagnostics, https://github.com/OHDSI/CohortDiagnostics
BugReports: https://github.com/OHDSI/CohortDiagnostics/issues
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
Encoding: UTF-8
Language: en-US
StagedInstall: no
19 changes: 11 additions & 8 deletions R/CohortLevelDiagnostics.R
@@ -53,7 +53,7 @@ getCohortCounts <- function(connectionDetails = NULL,
)
counts <-
DatabaseConnector::querySql(connection, sql, snakeCaseToCamelCase = TRUE) %>%
tidyr::tibble()
tidyr::tibble()

if (length(cohortIds) > 0) {
cohortIdDf <- tidyr::tibble(cohortId = as.numeric(cohortIds))
@@ -97,7 +97,8 @@ computeCohortCounts <- function(connection,
cohorts,
exportFolder,
minCellCount,
databaseId) {
databaseId,
writeResult = TRUE) {
ParallelLogger::logInfo("Counting cohort records and subjects")
cohortCounts <- getCohortCounts(
connection = connection,
@@ -117,11 +117,13 @@
databaseId = databaseId
)

writeToCsv(
data = cohortCounts,
fileName = file.path(exportFolder, "cohort_count.csv"),
incremental = FALSE,
cohortId = cohorts$cohortId
)
if (writeResult) {
writeToCsv(
data = cohortCounts,
fileName = file.path(exportFolder, "cohort_count.csv"),
incremental = FALSE,
cohortId = cohorts$cohortId
)
}
return(cohortCounts)
}
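
For orientation, a minimal sketch of how the new writeResult flag might be used. computeCohortCounts() is an internal helper, and the connection, schema, table, and cohort objects below are placeholders rather than values taken from this PR:

# Hedged sketch: count records and subjects in a sampled cohort table without
# writing the result to cohort_count.csv (all names here are placeholders).
sampledCounts <- computeCohortCounts(
  connection = connection,                 # an open DatabaseConnector connection
  cohortDatabaseSchema = "my_results_schema",
  cohortTable = "cohort_cd_sample",
  cohorts = sampledCohortDefinitionSet,
  exportFolder = exportFolder,
  minCellCount = 5,
  databaseId = "MyDatabase",
  writeResult = FALSE                      # skip writing cohort_count.csv
)

Setting writeResult = FALSE lets the sampled counts feed feature extraction without overwriting the full-cohort counts exported earlier in the run.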
3 changes: 2 additions & 1 deletion R/Incremental.R
@@ -153,7 +153,7 @@ writeToCsv <- function(data, fileName, incremental = FALSE, ...) {
UseMethod("writeToCsv", data)
}


#' @noRd
writeToCsv.default <- function(data, fileName, incremental = FALSE, ...) {
colnames(data) <- SqlRender::camelCaseToSnakeCase(colnames(data))
if (incremental) {
@@ -186,6 +186,7 @@ writeToCsv.default <- function(data, fileName, incremental = FALSE, ...) {
}
}

#'@noRd
writeToCsv.tbl_Andromeda <-
function(data, fileName, incremental = FALSE, ...) {
if (incremental && file.exists(fileName)) {
3 changes: 2 additions & 1 deletion R/Private.R
@@ -317,8 +317,9 @@ getPrefixedTableNames <- function(tablePrefix) {
return(resultList)
}

#' @noRd

#' Internal utility function for logging execution of variables
#' @noRd
timeExecution <- function(exportFolder,
taskName,
cohortIds = NULL,
132 changes: 75 additions & 57 deletions R/RunDiagnostics.R
@@ -136,23 +136,18 @@ getDefaultCovariateSettings <- function() {
#' @param incremental Create only cohort diagnostics that haven't been created before?
#' @param incrementalFolder If \code{incremental = TRUE}, specify a folder where records are kept
#' of which cohort diagnostics has been executed.
#' @param runOnSample Logical. If TRUE, the function will operate on a sample of the data.
#' @param runFeatureExtractionOnSample Logical. If TRUE, the function will operate on a sample of the data.
#' Default is FALSE, meaning the function will operate on the full data set.
#'
#' @param sampleN Integer. The number of records to include in the sample if runOnSample is TRUE.
#' Default is 1000. Ignored if runOnSample is FALSE.
#' @param sampleN Integer. The number of records to include in the sample if runFeatureExtractionOnSample is TRUE.
#' Default is 1000. Ignored if runFeatureExtractionOnSample is FALSE.
#'
#' @param seed Integer. The seed for the random number generator used to create the sample.
#' This ensures that the same sample can be drawn again in future runs. Default is 64374.
#'
#' @param seedArgs List. Additional arguments to pass to the sampling function.
#' This can be used to control aspects of the sampling process beyond the seed and sample size.
#'
#' @param sampleIdentifierExpression Character. An expression that generates unique identifiers for each sample.
#' This expression can use the variables 'cohortId' and 'seed'.
#' Default is "cohortId * 1000 + seed", which ensures unique identifiers
#' as long as there are fewer than 1000 cohorts.

#' @examples
#' \dontrun{
#' # Load cohorts (assumes that they have already been instantiated)
@@ -228,11 +223,10 @@ executeDiagnostics <- function(cohortDefinitionSet,
irWashoutPeriod = 0,
incremental = FALSE,
incrementalFolder = file.path(exportFolder, "incremental"),
runOnSample = FALSE,
runFeatureExtractionOnSample = FALSE,
sampleN = 1000,
seed = 64374,
seedArgs = NULL,
sampleIdentifierExpression = "cohortId * 1000 + seed") {
seedArgs = NULL) {
# collect arguments that were passed to cohort diagnostics at initiation
callingArgs <- formals(executeDiagnostics)
callingArgsJson <-
@@ -250,7 +244,7 @@
incremental = callingArgs$incremental,
temporalCovariateSettings = callingArgs$temporalCovariateSettings
) %>%
RJSONIO::toJSON(digits = 23, pretty = TRUE)
RJSONIO::toJSON(digits = 23, pretty = TRUE)

exportFolder <- normalizePath(exportFolder, mustWork = FALSE)
incrementalFolder <- normalizePath(incrementalFolder, mustWork = FALSE)
@@ -279,25 +273,25 @@
errorMessage <- checkmate::makeAssertCollection()
checkmate::assertList(cohortTableNames, null.ok = FALSE, types = "character", add = errorMessage, names = "named")
checkmate::assertNames(names(cohortTableNames),
must.include = c(
"cohortTable",
"cohortInclusionTable",
"cohortInclusionResultTable",
"cohortInclusionStatsTable",
"cohortSummaryStatsTable",
"cohortCensorStatsTable"
),
add = errorMessage
must.include = c(
"cohortTable",
"cohortInclusionTable",
"cohortInclusionResultTable",
"cohortInclusionStatsTable",
"cohortSummaryStatsTable",
"cohortCensorStatsTable"
),
add = errorMessage
)
checkmate::assertDataFrame(cohortDefinitionSet, add = errorMessage)
checkmate::assertNames(names(cohortDefinitionSet),
must.include = c(
"json",
"cohortId",
"cohortName",
"sql"
),
add = errorMessage
must.include = c(
"json",
"cohortId",
"cohortName",
"sql"
),
add = errorMessage
)

cohortTable <- cohortTableNames$cohortTable
@@ -474,17 +468,17 @@
sort()
cohortTableColumnNamesExpected <-
getResultsDataModelSpecifications() %>%
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()
cohortTableColumnNamesRequired <-
getResultsDataModelSpecifications() %>%
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::filter(.data$isRequired == "Yes") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()
dplyr::filter(.data$tableName == "cohort") %>%
dplyr::filter(.data$isRequired == "Yes") %>%
dplyr::pull(.data$columnName) %>%
SqlRender::snakeCaseToCamelCase() %>%
sort()

expectedButNotObsevered <-
setdiff(x = cohortTableColumnNamesExpected, y = cohortTableColumnNamesObserved)
@@ -549,23 +543,6 @@
}
}

if (runOnSample & !isTRUE(attr(cohortDefinitionSet, "isSampledCohortDefinition"))) {
cohortDefinitionSet <-
CohortGenerator::sampleCohortDefinitionSet(
connection = connection,
cohortDefinitionSet = cohortDefinitionSet,
tempEmulationSchema = tempEmulationSchema,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTableNames = cohortTableNames,
n = sampleN,
seed = seed,
seedArgs = seedArgs,
identifierExpression = sampleIdentifierExpression,
incremental = incremental,
incrementalFolder = incrementalFolder
)
}

## CDM source information----
timeExecution(
exportFolder,
@@ -871,18 +848,59 @@
cohortIds,
parent = "executeDiagnostics",
expr = {

feCohortDefinitionSet <- cohortDefinitionSet
feCohortTable <- cohortTable
feCohortCounts <- cohortCounts

if (runFeatureExtractionOnSample & !isTRUE(attr(cohortDefinitionSet, "isSampledCohortDefinition"))) {
cohortTableNames$cohortSampleTable <- paste0(cohortTableNames$cohortTable, "_cd_sample")
CohortGenerator::createCohortTables(connection = connection,
cohortTableNames = cohortTableNames,
cohortDatabaseSchema = cohortDatabaseSchema,
incremental = TRUE)

feCohortTable <- cohortTableNames$cohortSampleTable
feCohortDefinitionSet <-
CohortGenerator::sampleCohortDefinitionSet(
connection = connection,
cohortDefinitionSet = cohortDefinitionSet,
tempEmulationSchema = tempEmulationSchema,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTableNames = cohortTableNames,
n = sampleN,
seed = seed,
seedArgs = seedArgs,
identifierExpression = "cohortId",
incremental = incremental,
incrementalFolder = incrementalFolder
)

feCohortCounts <- computeCohortCounts(
connection = connection,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTable = cohortTableNames$cohortSampleTable,
cohorts = feCohortDefinitionSet,
exportFolder = exportFolder,
minCellCount = minCellCount,
databaseId = databaseId,
writeResult = FALSE
)
}


executeCohortCharacterization(
connection = connection,
databaseId = databaseId,
exportFolder = exportFolder,
cdmDatabaseSchema = cdmDatabaseSchema,
cohortDatabaseSchema = cohortDatabaseSchema,
cohortTable = cohortTable,
cohortTable = feCohortTable,
covariateSettings = temporalCovariateSettings,
tempEmulationSchema = tempEmulationSchema,
cdmVersion = cdmVersion,
cohorts = cohortDefinitionSet,
cohortCounts = cohortCounts,
cohorts = feCohortDefinitionSet,
cohortCounts = feCohortCounts,
minCellCount = minCellCount,
instantiatedCohorts = instantiatedCohorts,
incremental = incremental,
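For orientation, a hedged sketch of calling executeDiagnostics() with the renamed sampling argument. Connection details, schema names, table names, and the cohort definition set are placeholders, and arguments not shown in this diff follow the package documentation rather than this PR:

# Minimal sketch: run diagnostics but characterize only a sample of each cohort
# (placeholder values throughout; only the sampling-related arguments matter here).
executeDiagnostics(
  cohortDefinitionSet = cohortDefinitionSet,
  connectionDetails = connectionDetails,
  cdmDatabaseSchema = "main",
  cohortDatabaseSchema = "main",
  cohortTableNames = CohortGenerator::getCohortTableNames(cohortTable = "cohort"),
  databaseId = "MyCdm",
  exportFolder = file.path(getwd(), "export"),
  runFeatureExtractionOnSample = TRUE,  # sample only for feature extraction
  sampleN = 1000,                       # records to include in the sample
  seed = 64374                          # fixed seed so reruns draw the same sample
)

With this PR, sampling no longer replaces the cohorts used by the other diagnostics: a separate "_cd_sample" cohort table is generated and only the feature-extraction step reads from it.
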
16 changes: 5 additions & 11 deletions man/executeDiagnostics.Rd

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions tests/testthat/test-1-ResultsDataModel.R
@@ -123,7 +123,7 @@ VALUES ('Synthea','Synthea','OHDSI Community','SyntheaTM is a Synthetic Patient
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
},
"CDM Source table has more than one record while only one is expected."
@@ -149,7 +149,7 @@ VALUES ('Synthea','Synthea','OHDSI Community','SyntheaTM is a Synthetic Patient
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
}

6 changes: 3 additions & 3 deletions tests/testthat/test-2-againstCdm.R
@@ -45,7 +45,7 @@ test_that("Cohort diagnostics in incremental mode", {
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
)

@@ -76,7 +76,7 @@ test_that("Cohort diagnostics in incremental mode", {
incremental = TRUE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
)
# generate sqlite file
@@ -123,7 +123,7 @@ test_that("Cohort diagnostics in incremental mode", {
incremental = FALSE,
incrementalFolder = file.path(folder, "incremental"),
temporalCovariateSettings = temporalCovariateSettings,
runOnSample = TRUE
runFeatureExtractionOnSample = TRUE
)
})
