# Review header for this patch. It sits ahead of the `diff --git` line, i.e.
# outside every hunk; the patch text below is unchanged.
#
# R/step-subset-expand.R
#   * expand.dtplyr_step() now hands argument processing to
#     prepare_expand_dots(). It then builds one table for the plain arguments
#     (expand_no_nesting()) and one table per nesting() call
#     (expand_nesting()), combines them with Reduce() over
#     left_join(by = group_vars(data)), and applies select() only when some
#     column has to be renamed or moved.
#   * prepare_expand_dots() drops NULL arguments, separates nesting() calls
#     from the remaining arguments, and returns list(simple, nesting, select);
#     `select` maps the repaired tidyr-facing names onto the data.table-side
#     names, ordered by original argument position after the group columns.
#   * get_nesting_vars() pulls the arguments out of a nesting() call, honours
#     an explicit `.name_repair` argument (default "check_unique") and
#     auto-names the rest.
#   * expand_nesting() returns NULL for empty input, otherwise one
#     distinct(data, ...) step per nesting() call.
#   * dt_dot_names() is the naming logic split out of the old function body;
#     part of its body lies between the hunks and is not shown here.
#   * expand_no_nesting(): without groups it emits a single
#     step_subset_j() whose j is CJ(..., unique = TRUE), keeping an argument
#     name only where it differs from the auto-generated one; with groups it
#     keeps the distinct() + .SD[CJ(...), on = .(...)] form.
#   * expand.data.table() is moved up in the file, body unchanged.
#
# tests/testthat/test-step-subset-expand.R
#   * Expected queries follow the new ungrouped form, e.g.
#     DT[, CJ(x, y, unique = TRUE)].
#   * For an unnamed vector argument, step$vars changes from c("x", "V2") to
#     c("x", "1:2").
#
# NOTE(review): prepare_expand_dots() tests `is_null(dots)` right after
#   `dots <- dots[!dot_is_null]`. If capture_dots() returns a list, the subset
#   is an empty list rather than NULL and this early return never fires --
#   confirm, and consider an emptiness check (is_empty()) instead.
# NOTE(review): expand.dtplyr_step() still guards with `length(dots) == 0`,
#   but `dots` is now either NULL or the three-element list returned by
#   prepare_expand_dots(), so the guard only triggers through the NULL return
#   questioned above. Verify that expand() with no arguments still returns
#   `data`.
# NOTE(review): expand_no_nesting() is called unconditionally, including when
#   every argument is a nesting() call and dots$simple is empty -- verify that
#   case (the ungrouped branch would build CJ() with no columns).
# NOTE(review): `relocated` compares unname(dots$select) with out$vars via
#   `!=`; confirm both always have the same length, otherwise R recycles.
# NOTE(review): dt_dot_names() and expand_no_nesting() declare `.name_repair`,
#   yet no call in this patch passes it and the only visible use is on a
#   removed line -- presumably unused; confirm against the hidden part of
#   dt_dot_names() before dropping it.
# NOTE(review): `# TODO handle factors` is introduced without a test, and two
#   superseded expectations are left as commented-out code in the test file;
#   consider resolving or deleting them before merge.
diff --git a/R/step-subset-expand.R b/R/step-subset-expand.R index fefa753fa..7efbae270 100644 --- a/R/step-subset-expand.R +++ b/R/step-subset-expand.R @@ -38,12 +38,101 @@ #' fruits %>% dplyr::right_join(all) # exported onLoad expand.dtplyr_step <- function(data, ..., .name_repair = "check_unique") { - dots <- capture_dots(data, ..., .j = FALSE) - dots <- dots[!vapply(dots, is_null, logical(1))] + dots <- prepare_expand_dots(data, ..., .name_repair = .name_repair) + + # TODO handle factors if (length(dots) == 0) { return(data) } + tbl_list <- c( + list(expand_no_nesting(data, dots$simple)), + expand_nesting(data, dots$nesting) + ) + + out <- Reduce(function(x, y) left_join(x, y, by = group_vars(data)), tbl_list) + + renamed <- names(dots$select) != unname(dots$select) + relocated <- unname(dots$select) != out$vars + if (any(renamed) || any(relocated)) { + out <- select(out, !!!dots$select) + } + + out +} + +# exported onLoad +expand.data.table <- function(data, ..., .name_repair = "check_unique") { + data <- lazy_dt(data) + tidyr::expand(data, ..., .name_repair = .name_repair) +} + +prepare_expand_dots <- function(data, ..., .name_repair) { + dots <- capture_dots(data, ..., .j = FALSE) + + dot_is_null <- vapply(dots, is_null, logical(1)) + dots <- dots[!dot_is_null] + dot_names_tidyr <- names(exprs(..., .named = TRUE))[!dot_is_null] + if (is_null(dots)) { + return(NULL) + } + + is_nesting <- vapply(dots, function(x) is_call(x, "nesting"), logical(1)) + dots_df <- tibble::tibble( + expr = dots, + position = seq_along(dots) + ) + + dots_df_nesting <- dots_df[is_nesting, ] + nesting_vars <- lapply(dots_df_nesting$expr, get_nesting_vars) + dots_df_nesting$name_tidyr <- lapply(nesting_vars, names) + dots_df_nesting$var <- lapply(nesting_vars, unlist) + + dots_df_simple <- dots_df[!is_nesting, ] + simple_vars <- dt_dot_names(dots_df_simple$expr) + dots_df_simple$name_dt <- names(simple_vars) + dots_df_simple$var <- simple_vars + dots_df_simple$name_tidyr <- 
dot_names_tidyr[!is_nesting] + + meta_df <- dplyr::bind_rows( + dots_df_simple, + tidyr::unnest(dots_df_nesting, "name_tidyr") + ) + groups <- group_vars(data) + names_dt <- c(groups, dplyr::coalesce(meta_df$name_dt, meta_df$name_tidyr)) + names_tidyr <- vctrs::vec_as_names( + c(groups, meta_df$name_tidyr), + repair = .name_repair + ) + order <- c(seq_along(groups), length(groups) + order(meta_df$position)) + + list( + simple = dots_df_simple$var, + nesting = dots_df_nesting$var, + select = set_names(names_dt, names_tidyr)[order] + ) +} + +get_nesting_vars <- function(expr) { + args <- call_args(expr) + + repair <- args[[".name_repair"]] %||% "check_unique" + args[[".name_repair"]] <- NULL + + vars <- exprs_auto_name(args) + nms <- vctrs::vec_as_names(names(vars), repair = repair) + set_names(vars, nms) +} + +expand_nesting <- function(data, vars) { + if (is_empty(vars)) { + return(NULL) + } + + lapply(vars, function(x) distinct(data, !!!x)) +} + +dt_dot_names <- function(dots, .name_repair) { named_dots <- have_name(dots) if (any(!named_dots)) { # Auto-names generated by enquos() don't always work with the CJ() step @@ -55,24 +144,31 @@ expand.dtplyr_step <- function(data, ..., .name_repair = "check_unique") { names(dots)[needs_v_name] <- v_names[needs_v_name] names(dots)[symbol_dots] <- lapply(dots[symbol_dots], as_name) } - names(dots) <- vctrs::vec_as_names(names(dots), repair = .name_repair) - - on <- names(dots) - cj <- expr(CJ(!!!syms(on), unique = TRUE)) + dots +} - out <- distinct(data, !!!syms(data$groups), !!!dots) +expand_no_nesting <- function(data, dots, .name_repair) { if (length(data$groups) == 0) { - out <- step_subset(out, i = cj, on = on) + dt_vars <- names(dots) + + dt_auto_names <- names(dt_dot_names(unname(dots))) + name_needed <- dt_auto_names != dt_vars + names(dots)[!name_needed] <- "" + + out <- step_subset_j( + parent = data, + vars = dt_vars, + j = expr(CJ(!!!dots, unique = TRUE)) + ) } else { + out <- distinct(data, 
!!!syms(data$groups), !!!dots) + + on <- names(dots) + cj <- expr(CJ(!!!syms(on), unique = TRUE)) + on <- call2(".", !!!syms(on)) out <- step_subset(out, j = expr(.SD[!!cj, on = !!on])) } out } - -# exported onLoad -expand.data.table <- function(data, ..., .name_repair = "check_unique") { - data <- lazy_dt(data) - tidyr::expand(data, ..., .name_repair = .name_repair) -} diff --git a/tests/testthat/test-step-subset-expand.R b/tests/testthat/test-step-subset-expand.R index 03133531b..53484e2c4 100644 --- a/tests/testthat/test-step-subset-expand.R +++ b/tests/testthat/test-step-subset-expand.R @@ -6,7 +6,7 @@ test_that("expand completes all values", { expect_equal( show_query(step), - expr(unique(DT)[CJ(x, y, unique = TRUE), on = .(x, y)]) + expr(DT[, CJ(x, y, unique = TRUE)]) ) expect_equal(step$vars, c("x", "y")) expect_equal(nrow(out), 4) @@ -29,9 +29,10 @@ test_that("works with unnamed vectors", { expect_equal( show_query(step), - expr(unique(DT[, .(x = x, V2 = 1:2)])[CJ(x, V2, unique = TRUE), on = .(x, V2)]) + # expr(unique(DT[, .(x = x, V2 = 1:2)])[CJ(x, V2, unique = TRUE), on = .(x, V2)]) + expr(DT[, CJ(x, 1:2, unique = TRUE)][, .(x, `1:2` = V2)]) ) - expect_equal(step$vars, c("x", "V2")) + expect_equal(step$vars, c("x", "1:2")) expect_equal(nrow(out), 4) }) @@ -43,7 +44,8 @@ test_that("works with named vectors", { expect_equal( show_query(step), - expr(unique(DT[, .(x = x, val = 1:2)])[CJ(x, val, unique = TRUE), on = .(x, val)]) + # expr(unique(DT[, .(x = x, val = 1:2)])[CJ(x, val, unique = TRUE), on = .(x, val)]) + expr(DT[, CJ(x, val = 1:2, unique = TRUE)]) ) expect_equal(step$vars, c("x", "val")) expect_equal(nrow(out), 4)