
Order of feature in formula seems to matter (last place, particularly) #733

Closed
tomauer opened this issue Jul 25, 2024 · 6 comments · Fixed by #738

Comments

@tomauer

tomauer commented Jul 25, 2024

I'm seeing some odd behavior where a feature appears not to be considered, depending on where it sits in the formula definition. If it's last, it ends up with a variable importance of 0, but if it's first, it has a high variable importance. This is a feature that we know always has a high variable importance. Increasing mtry does seem to bring the feature into consideration, which suggests it has something to do with how features are randomly chosen, but order in the formula definition shouldn't matter, right?

I've provided all the parts for a reproducible example here, with an RDS file of the input data attached. predicted_er is the feature of interest. Note that when it goes first, the new last feature (circular_day_of_year_cos) gets an importance of 0 instead, after having a high importance when it wasn't last.

library(ranger) # 0.16.0 or development branch
library(dplyr)

data <- readRDS("~/Downloads/example_data.rds")

always_split <- c("island", "mountain", "has_shoreline", "has_evi",
                  "has_ocean_chlorophyll", "has_ocean_sst", "has_ocean_ssta")

# predicted_er is the variable of interest
formula_last <- as.formula("count ~ is_stationary + year + solar_noon_diff + effort_hrs +
  effort_distance_km + effort_speed_kmph + num_observers +
  cci + moon_fraction + moon_altitude + cds_u10 + cds_v10 +
  cds_d2m + cds_t2m + cds_hcc + cds_i10fg + cds_mcc + cds_lcc +
  cds_sf + cds_rf + cds_slc + cds_msl + island + mountain +
  eastness_1km_median + eastness_1km_sd + eastness_90m_median +
  eastness_90m_sd + northness_1km_median + northness_1km_sd +
  northness_90m_median + northness_90m_sd + elevation_30m_median +
  elevation_30m_sd + bathymetry_elevation_median + bathymetry_elevation_sd +
  astwbd_c1_ed + astwbd_c1_pland + astwbd_c2_ed + astwbd_c2_pland +
  astwbd_c3_ed + astwbd_c3_pland + gsw_c2_pland + gsw_c2_ed +
  ntl_mean + ntl_sd + road_density_c1 + road_density_c2 + road_density_c3 +
  road_density_c4 + road_density_c5 + mcd12q1_lccs1_c1_ed +
  mcd12q1_lccs1_c1_pland + mcd12q1_lccs1_c11_ed + mcd12q1_lccs1_c11_pland +
  mcd12q1_lccs1_c13_ed + mcd12q1_lccs1_c13_pland + mcd12q1_lccs1_c14_ed +
  mcd12q1_lccs1_c14_pland + mcd12q1_lccs1_c15_ed + mcd12q1_lccs1_c15_pland +
  mcd12q1_lccs1_c16_ed + mcd12q1_lccs1_c16_pland + mcd12q1_lccs1_c21_ed +
  mcd12q1_lccs1_c21_pland + mcd12q1_lccs1_c22_ed + mcd12q1_lccs1_c22_pland +
  mcd12q1_lccs1_c31_ed + mcd12q1_lccs1_c31_pland + mcd12q1_lccs1_c32_ed +
  mcd12q1_lccs1_c32_pland + mcd12q1_lccs1_c41_ed + mcd12q1_lccs1_c41_pland +
  mcd12q1_lccs2_c25_ed + mcd12q1_lccs2_c25_pland + mcd12q1_lccs2_c36_ed +
  mcd12q1_lccs2_c36_pland + mcd12q1_lccs3_c27_ed + mcd12q1_lccs3_c27_pland +
  mcd12q1_lccs3_c50_ed + mcd12q1_lccs3_c50_pland + has_shoreline +
  shoreline_waveheight_mean + shoreline_waveheight_sd + shoreline_tidal_range_mean +
  shoreline_tidal_range_sd + shoreline_chlorophyll_mean + shoreline_chlorophyll_sd +
  shoreline_turbidity_mean + shoreline_turbidity_sd + shoreline_sinuosity_mean +
  shoreline_sinuosity_sd + shoreline_slope_mean + shoreline_slope_sd +
  shoreline_outflow_density_mean + shoreline_outflow_density_sd +
  shoreline_erodibility_n + shoreline_erodibility_c2_density +
  shoreline_erodibility_c4_density + shoreline_emu_physical_n +
  shoreline_emu_physical_c1_density + shoreline_emu_physical_c4_density +
  shoreline_emu_physical_c14_density + has_ocean_chlorophyll +
  ocean_chlorophyll + has_ocean_sst + ocean_sst + has_ocean_ssta +
  ocean_ssta + has_evi + evi_median + evi_sd + circular_day_of_year_sin +
  circular_day_of_year_cos + predicted_er")

# regression forest for count
rf_count <- ranger::ranger(
  formula = formula_last,
  num.trees = 150,
  importance = "impurity",
  num.threads = 1,
  respect.unordered.factors = "order",
  always.split.variables = always_split,
  data = data
)

data.frame(
  response = "count",
  predictor = names(rf_count$variable.importance),
  importance = rf_count$variable.importance,
  stringsAsFactors = FALSE
) %>% filter(predictor %in% c("predicted_er", "circular_day_of_year_cos"))

# my results
#                         response                predictor importance
#circular_day_of_year_cos    count circular_day_of_year_cos   570491.3
#predicted_er                count             predicted_er        0.0

formula_first <- as.formula("count ~ predicted_er + is_stationary + year + solar_noon_diff + effort_hrs +
  effort_distance_km + effort_speed_kmph + num_observers +
  cci + moon_fraction + moon_altitude + cds_u10 + cds_v10 +
  cds_d2m + cds_t2m + cds_hcc + cds_i10fg + cds_mcc + cds_lcc +
  cds_sf + cds_rf + cds_slc + cds_msl + island + mountain +
  eastness_1km_median + eastness_1km_sd + eastness_90m_median +
  eastness_90m_sd + northness_1km_median + northness_1km_sd +
  northness_90m_median + northness_90m_sd + elevation_30m_median +
  elevation_30m_sd + bathymetry_elevation_median + bathymetry_elevation_sd +
  astwbd_c1_ed + astwbd_c1_pland + astwbd_c2_ed + astwbd_c2_pland +
  astwbd_c3_ed + astwbd_c3_pland + gsw_c2_pland + gsw_c2_ed +
  ntl_mean + ntl_sd + road_density_c1 + road_density_c2 + road_density_c3 +
  road_density_c4 + road_density_c5 + mcd12q1_lccs1_c1_ed +
  mcd12q1_lccs1_c1_pland + mcd12q1_lccs1_c11_ed + mcd12q1_lccs1_c11_pland +
  mcd12q1_lccs1_c13_ed + mcd12q1_lccs1_c13_pland + mcd12q1_lccs1_c14_ed +
  mcd12q1_lccs1_c14_pland + mcd12q1_lccs1_c15_ed + mcd12q1_lccs1_c15_pland +
  mcd12q1_lccs1_c16_ed + mcd12q1_lccs1_c16_pland + mcd12q1_lccs1_c21_ed +
  mcd12q1_lccs1_c21_pland + mcd12q1_lccs1_c22_ed + mcd12q1_lccs1_c22_pland +
  mcd12q1_lccs1_c31_ed + mcd12q1_lccs1_c31_pland + mcd12q1_lccs1_c32_ed +
  mcd12q1_lccs1_c32_pland + mcd12q1_lccs1_c41_ed + mcd12q1_lccs1_c41_pland +
  mcd12q1_lccs2_c25_ed + mcd12q1_lccs2_c25_pland + mcd12q1_lccs2_c36_ed +
  mcd12q1_lccs2_c36_pland + mcd12q1_lccs3_c27_ed + mcd12q1_lccs3_c27_pland +
  mcd12q1_lccs3_c50_ed + mcd12q1_lccs3_c50_pland + has_shoreline +
  shoreline_waveheight_mean + shoreline_waveheight_sd + shoreline_tidal_range_mean +
  shoreline_tidal_range_sd + shoreline_chlorophyll_mean + shoreline_chlorophyll_sd +
  shoreline_turbidity_mean + shoreline_turbidity_sd + shoreline_sinuosity_mean +
  shoreline_sinuosity_sd + shoreline_slope_mean + shoreline_slope_sd +
  shoreline_outflow_density_mean + shoreline_outflow_density_sd +
  shoreline_erodibility_n + shoreline_erodibility_c2_density +
  shoreline_erodibility_c4_density + shoreline_emu_physical_n +
  shoreline_emu_physical_c1_density + shoreline_emu_physical_c4_density +
  shoreline_emu_physical_c14_density + has_ocean_chlorophyll +
  ocean_chlorophyll + has_ocean_sst + ocean_sst + has_ocean_ssta +
  ocean_ssta + has_evi + evi_median + evi_sd + circular_day_of_year_sin +
  circular_day_of_year_cos")

# regression forest for count
rf_count_first <- ranger::ranger(
  formula = formula_first,
  num.trees = 150,
  importance = "impurity",
  num.threads = 1,
  respect.unordered.factors = "order",
  always.split.variables = always_split,
  data = data
)

data.frame(
  response = "count",
  predictor = names(rf_count_first$variable.importance),
  importance = rf_count_first$variable.importance,
  stringsAsFactors = FALSE
) %>% filter(predictor %in% c("predicted_er", "circular_day_of_year_cos"))

# my results
#                         response                predictor importance
# predicted_er                count             predicted_er    6675684
# circular_day_of_year_cos    count circular_day_of_year_cos          0

# increasing mtry seems to help

# regression forest for count
rf_count_mtry <- ranger::ranger(
  formula = formula_last,
  num.trees = 150,
  importance = "impurity",
  num.threads = 1,
  respect.unordered.factors = "order",
  always.split.variables = always_split,
  data = data,
  mtry = round(sqrt(115)) * 2
)

data.frame(
  response = "count",
  predictor = names(rf_count_mtry$variable.importance),
  importance = rf_count_mtry$variable.importance,
  stringsAsFactors = FALSE
) %>% filter(predictor %in% c("predicted_er", "circular_day_of_year_cos"))

# my results
#                         response                predictor importance
# circular_day_of_year_cos    count circular_day_of_year_cos   630308.8
# predicted_er                count             predicted_er  9033731.5

example_data.rds.zip

@cmcrowley

cmcrowley commented Jul 26, 2024

The problem doesn't occur when mtry > p/10 or when always.split.variables = NULL. Could it have to do with the index skipping in drawWithoutReplacementSimple()?
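The mtry > p/10 threshold lines up with both runs in the original report if drawWithoutReplacementSkip() dispatches on num_samples < max / 10. The sketch below assumes that rule (check it against ranger's src/utility.cpp) and that the default mtry is the square root of the number of predictors, rounded down; the function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <string>

// Assumption: drawWithoutReplacementSkip() picks its sampler with a rule of
// this shape (integer division), consistent with the "mtry > p/10" observation.
std::string chosen_sampler(std::size_t num_samples, std::size_t max) {
  assert(max > 0);
  return num_samples < max / 10 ? "drawWithoutReplacementSimple"
                                : "drawWithoutReplacementFisherYates";
}

// Default mtry: square root of the number of predictors, rounded down.
std::size_t default_mtry(std::size_t num_predictors) {
  return static_cast<std::size_t>(
      std::floor(std::sqrt(static_cast<double>(num_predictors))));
}
```

With p = 115, the value used in the report's mtry override, default_mtry(115) is 10 and 115 / 10 is 11, so the default run would go through the Simple sampler, while round(sqrt(115)) * 2 = 22 would go through Fisher-Yates. That is consistent with raising mtry making predicted_er reappear.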

@sligocki
Contributor

sligocki commented Jul 29, 2024

Not sure if this is exactly the issue that is being hit above, but it appears that there is a bug in drawWithoutReplacementSkip() with multiple skipped indexes. Specifically, here's a failing test case:

#include <algorithm>
#include <cstdio>
#include <map>
#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "utility.h"  // declares drawWithoutReplacementSkip(); the path depends on the build setup

using namespace ranger;  // ranger's C++ helpers are declared in namespace ranger

TEST(drawWithoutReplacementSkip, small_small4) {

  std::vector<size_t> result;
  std::mt19937_64 random_number_generator;
  std::random_device random_device;
  random_number_generator.seed(random_device());
  std::map<size_t, unsigned int> counts;

  // Draw 4 of the indices 0..9, skipping { 7, 0, 1, 3 } (deliberately unsorted)
  size_t max = 9;
  std::vector<size_t> skip = std::vector<size_t>( { 7, 0, 1, 3 });
  size_t num_samples = 4;
  size_t num_replicates = 10000;

  // Each of the 6 non-skipped indices should be drawn equally often
  size_t expected_count = num_samples * num_replicates / (max + 1 - skip.size());

  for (size_t i = 0; i < num_replicates; ++i) {
    result.clear();
    drawWithoutReplacementSkip(result, random_number_generator, max + 1, skip, num_samples);
    EXPECT_EQ(num_samples, result.size());
    for (auto& idx : result) {
      EXPECT_LE(idx, max);
      ++counts[idx];
    }
  }

  int total = 0;
  for (const auto& c : counts) {
    printf("%zu : %u\n", c.first, c.second);
    total += c.second;
  }
  printf("Total %d\n", total);

  // Check that each count is within 5% of the expected count
  for (size_t c = 0; c <= max; ++c) {
    if (std::find(skip.begin(), skip.end(), c) == skip.end()) {
      // c should not be skipped
      EXPECT_NEAR(expected_count, counts[c], expected_count * 0.05);
    } else {
      // c should be skipped
      EXPECT_EQ(0, counts[c]);
    }
  }
}

Which prints the following:

1 : 6741
3 : 6723
4 : 6623
6 : 6693
8 : 6580
9 : 6640
Total 40000

In other words, the wrong indexes are being skipped (should have skipped { 7, 0, 1, 3 }, but actually skipped 0, 2, 5, 7).

@sligocki
Contributor

I think the issue here is that the draw >= skip_value check at https://github.com/imbs-hl/ranger/blob/master/src/utility.cpp#L139 should really compare the initially drawn value against skip_value, not the value as already shifted by earlier skip values.
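To see why the order of the skip vector matters for that check, here is a minimal sketch of the index shifting. It assumes the Simple sampler draws uniformly from [0, max - skip.size() - 1] and increments the running value past each skip value in turn; this is an illustration, not ranger's code, and the helper names are hypothetical. With skip sorted ascending, the draws map exactly onto the non-skipped indices. With the unsorted { 7, 0, 1, 3 } from the test above, skipped index 7 becomes reachable and index 9 never is.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Shift a uniform draw past every skipped index it reaches. The running value
// is compared, so the mapping is only correct for an ascending `skip`.
std::size_t shift_past_skips(std::size_t draw, const std::vector<std::size_t>& skip) {
  for (std::size_t skip_value : skip) {
    if (draw >= skip_value) {
      ++draw;
    }
  }
  return draw;
}

// Every index the sampler can ever return: the image of all possible draws.
std::set<std::size_t> reachable(std::size_t max, const std::vector<std::size_t>& skip) {
  assert(skip.size() <= max);
  std::set<std::size_t> out;
  for (std::size_t draw = 0; draw < max - skip.size(); ++draw) {
    out.insert(shift_past_skips(draw, skip));
  }
  return out;
}
```

For example, reachable(10, {0, 1, 3, 7}) is {2, 4, 5, 6, 8, 9}, while reachable(10, {7, 0, 1, 3}) is {2, 4, 5, 6, 7, 8}.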

@sligocki
Contributor

Hm, this actually appears to be a bug in drawWithoutReplacementFisherYates().
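The printed counts can be reproduced exactly by removing the skipped entries from the index list by position, in the order given. If the mtry > p/10 threshold applies, num_samples = 4 is not below 10 / 10 = 1, so this test would take the Fisher-Yates path. The sketch below assumes that path builds its candidate list this way; the helper name is hypothetical. Erasing positions 7, 0, 1, 3 from 0..9 leaves {1, 3, 4, 6, 8, 9}, the six indices in the output. Erasing the same indices in descending order leaves the correct {2, 4, 5, 6, 8, 9}.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Start from 0..max-1 and erase each skipped index by *position*. Every erase
// shifts later elements left, so the positions are only right when `skip` is
// processed from the largest index down, i.e. sorted in descending order.
std::vector<std::size_t> candidates_after_skip(std::size_t max,
                                               const std::vector<std::size_t>& skip) {
  std::vector<std::size_t> candidates(max);
  std::iota(candidates.begin(), candidates.end(), std::size_t{0});
  for (std::size_t skip_value : skip) {
    assert(skip_value < candidates.size());
    candidates.erase(candidates.begin() + static_cast<std::ptrdiff_t>(skip_value));
  }
  return candidates;
}
```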

@mnwright
Member

It seems the problem is actually in drawWithoutReplacementSimple(), not drawWithoutReplacementFisherYates(). Both functions work, but:

  • drawWithoutReplacementSimple() expects the skip vector in ascending order
  • drawWithoutReplacementFisherYates() expects the skip vector in descending order

and the always.split.variables are sorted in descending order.
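The symptom in the report (the last formula term never being split on) can be mimicked in a toy setting. This sketch assumes the Simple sampler draws uniformly, shifts the draw past each skip value, and rejects duplicates; it is an illustration of that technique, not ranger's code, and the function names are hypothetical. With 10 variables and a descending skip vector { 7, 3, 1, 0 }, index 9 (the last variable) is never returned in this example, and skipped index 7 is. The same indices in ascending order behave correctly.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Simple skip sampler: uniform draw, shift past each skip value, reject
// duplicates. Correct only for an ascending `skip`.
std::vector<std::size_t> draw_simple(std::mt19937_64& rng, std::size_t max,
                                     const std::vector<std::size_t>& skip,
                                     std::size_t num_samples) {
  assert(num_samples + skip.size() <= max);
  std::vector<bool> selected(max, false);
  std::uniform_int_distribution<std::size_t> unif(0, max - 1 - skip.size());
  std::vector<std::size_t> result;
  result.reserve(num_samples);
  while (result.size() < num_samples) {
    std::size_t draw = unif(rng);
    for (std::size_t skip_value : skip) {
      if (draw >= skip_value) ++draw;
    }
    if (!selected[draw]) {  // reject indices already drawn
      selected[draw] = true;
      result.push_back(draw);
    }
  }
  return result;
}

// How often each index is returned over `reps` calls with a fixed seed.
std::vector<std::size_t> selection_counts(std::size_t max, const std::vector<std::size_t>& skip,
                                          std::size_t num_samples, std::size_t reps,
                                          unsigned seed) {
  std::mt19937_64 rng(seed);
  std::vector<std::size_t> counts(max, 0);
  for (std::size_t r = 0; r < reps; ++r) {
    for (std::size_t idx : draw_simple(rng, max, skip, num_samples)) ++counts[idx];
  }
  return counts;
}
```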

Worst of all: none of this is documented. :(

@mnwright
Member

Now fixed in #738, which also adds some basic documentation stating that drawWithoutReplacement* expects sorted skip vectors.
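For callers that cannot guarantee an order, one alternative is a sampler that does not depend on the order of the skip vector at all. This is a hypothetical sketch, not the change made in #738: it marks skipped indices in a mask, keeps the remaining candidates, and runs a partial Fisher-Yates shuffle over them.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// Order-insensitive draw without replacement: the order of `skip` never matters.
std::vector<std::size_t> draw_without_replacement_any_skip_order(
    std::mt19937_64& rng, std::size_t max, const std::vector<std::size_t>& skip,
    std::size_t num_samples) {
  std::vector<bool> skipped(max, false);
  for (std::size_t s : skip) {
    assert(s < max);
    skipped[s] = true;
  }
  std::vector<std::size_t> candidates;
  candidates.reserve(max);
  for (std::size_t i = 0; i < max; ++i) {
    if (!skipped[i]) candidates.push_back(i);
  }
  assert(num_samples <= candidates.size());
  // Partial Fisher-Yates: position i gets a uniform pick from candidates[i..end).
  for (std::size_t i = 0; i < num_samples; ++i) {
    std::uniform_int_distribution<std::size_t> pick(i, candidates.size() - 1);
    std::swap(candidates[i], candidates[pick(rng)]);
  }
  candidates.resize(num_samples);
  return candidates;
}
```

With max = 10 and skip = { 7, 0, 1, 3 }, every returned index is one of {2, 4, 5, 6, 8, 9}, however the skip vector is ordered. The trade-off is O(max) work and memory per call, which a Simple-style sampler avoids when num_samples is small relative to max.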

Labels: None yet
Projects: None yet

4 participants