# Regression Example
```{r, include=FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
warning = FALSE,
message = FALSE,
fig.height = 4.75,
fig.width = 6.25,
fig.align = 'center')
```
## Background & Data
This regression example is inspired by the blog post "The Effect of Childhood Education on Wealth: Modeling with Bayesian Additive Regression Trees (BART)" on R-bloggers by Selcuk Disci. We explore how enrollment rates in early childhood education are associated with household net worth across member countries of the Organisation for Economic Co-operation and Development (OECD), using datasets provided by that organization. The OECD collected these data from its member countries every year from 2000 to 2020 through standardized surveys and national statistics, ensuring they are consistent and comparable across countries.
According to the OECD, the **enrollment rate** for each early childhood age group (3-year-olds, 4-year-olds, and 5-year-olds) is calculated by dividing the number of children of that age enrolled in early childhood education and care (ECEC) by the total population of that age group. Note that this calculation does not distinguish between full-time and part-time enrollment.
The **household net worth** indicator measures the overall financial position of households: the total value of their assets (both financial, like stocks and savings, and non-financial, like real estate) minus the total value of their outstanding debts (such as loans and mortgages). The result is then expressed as a **percentage** of the households' annual income. Essentially, this indicator provides a snapshot of the economic health and financial stability of households by showing how much wealth they hold relative to how much they earn each year. By definition, the unit of household net worth is percent.
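As a quick illustration of these two definitions, here is a minimal sketch with made-up numbers (none of these values come from the OECD data):
```{r}
# Hypothetical enrollment rate for one age group (made-up numbers)
n_enrolled <- 45000    # 4-year-olds enrolled in ECEC
n_age_group <- 50000   # total population of 4-year-olds
enrollment_rate <- n_enrolled / n_age_group * 100  # 90 percent

# Hypothetical household net worth (made-up numbers)
assets <- 500000  # financial + non-financial assets
debts <- 150000   # loans, mortgages, etc.
income <- 70000   # annual household income
net_worth <- (assets - debts) / income * 100  # 500 percent of annual income

c(enrollment_rate = enrollment_rate, net_worth = net_worth)
```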
Here, we demonstrate how to use the Bayesian Additive Regression Trees (BART) model to capture the relationship between enrollment rates in early childhood education in a country and its average household net worth.
\
\
## Data Wrangling
We start by loading the packages used for data handling and modeling, then clean and wrangle the data.
Note: Code chunks with the comment "Rbloggers code" are taken from ["The Effect of Childhood Education on Wealth: Modeling with Bayesian Additive Regression Trees (BART)" on R-bloggers by Selcuk Disci.](https://www.r-bloggers.com/2022/12/the-effect-of-childhood-education-on-wealth-modeling-with-bayesian-additive-regression-trees-bart/#google_vignette)
```{r}
# Rbloggers code
library(tidyverse)
library(tidymodels)
library(ggplot2)
library(countrycode)
library(plotly)
library(sysfonts)
library(showtext)
library(glue)
library(scales)
library(janitor)
library(DALEXtra)
library(dbarts)

# Loading the datasets
df_childhood <- read_csv("https://raw.githubusercontent.com/mesdi/blog/main/childhood.csv")
df_household <- read_csv("https://raw.githubusercontent.com/mesdi/blog/main/household.csv")

# Joining them by country and time
df <-
  df_childhood %>%
  left_join(df_household, by = c("country", "time")) %>%
  na.omit()

# Wrangling the dataset
df_tidy <-
  df %>%
  mutate(household = round(household, 2),
         childhood = round(childhood, 2),
         age = str_replace(age, "_", "-"),
         country_name = countrycode(country, "genc3c", "country.name"))

# Best 20 countries based on household net worth in their most recent year
df_tidy %>%
  group_by(country) %>%
  slice_max(time) %>%
  slice_max(household, n = 20) %>%
  mutate(age = fct_reorder(age, childhood, .desc = TRUE),
         country_name = fct_reorder(country_name, household, .desc = TRUE)) %>%
  ggplot(aes(x = country_name,
             y = childhood,
             fill = age,
             # Hover text of the barplot
             text = glue("{country}\n%{childhood}\n{age}\nChildhood education"))) +
  geom_col() +
  geom_line(aes(y = household/2, group = 1),
            color = "skyblue",
            linewidth = 1) +
  # Adding the household net worth as a second (dual) y-axis
  scale_y_continuous(sec.axis = sec_axis(~ . * 2)) +
  scale_fill_viridis_d(name = "") +
  xlab("") +
  ylab("") +
  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 60),
    axis.text.y = element_blank(),
    axis.text.y.right = element_blank(),
    panel.grid = element_blank(),
    legend.position = "none"
  ) -> p

# Adding Google font
font_add_google(name = "Henny Penny", family = "henny")
showtext_auto()

# Setting font family for ggplotly
font <- list(
  family = "Henny Penny",
  size = 5
)
```
Let's take a look at the cleaned dataset:
```{r}
head(df_tidy)
```
Before modeling the relationship between average household net worth and the enrollment rate of childhood education for different age groups, we can plot this information together to get some initial insights.
```{r}
# Rbloggers code
# Plotly chart
ggplotly(p, tooltip = c("text")) %>%
  # Hover text of the line
  style(text = glue("{unique(p$data$country)}\n%{unique(p$data$household)}\nHousehold net worth"),
        traces = 6) %>%
  layout(font = font)
```
In the graph above, the yellow portions of the bars represent the childhood enrollment rate for 3-year-olds, teal the rate for 4-year-olds, and purple the rate for 5-year-olds. The light blue trend line shows the average household net worth for each country in the dataset. Before any modeling, a tentative takeaway is that childhood enrollment rates appear to be positively related to average household net worth.
\
\
With that in mind, we will model this relationship with BART in the next section.
\
## Implementation
BART has been implemented in various R packages. To use the two we rely on in this tutorial, `dbarts` and `tidymodels`, we start by splitting our data into a training set and a test set. This way we can train and cross-validate the model on the training set and evaluate its performance on the held-out test set.
```{r}
# Rbloggers code
# Splitting the data into train and test sets
set.seed(1234)
df_split <-
  df_tidy %>%
  # Converting the levels to variables for modeling
  pivot_wider(names_from = age, values_from = childhood) %>%
  clean_names() %>%
  na.omit() %>%
  initial_split()

df_train <- training(df_split)
df_test <- testing(df_split)
```
### `dbarts` package
The function we use from the `dbarts` package is called `bart`. To fit a BART model, we pass the independent variables to the `x.train` argument and the response variable to the `y.train` argument.
In our case, the independent variables are the enrollment rates for 3-, 4-, and 5-year-olds, and the response variable is household net worth. To pass this information to the `bart` function, we specify the column indices where these variables live.
We also set `keeptrees = TRUE` so that the fitted model stores the splitting information and configuration of every individual tree. We do this because we would like to view examples of which variable a tree split on, what its splitting rule was, and how deep it grew before terminating.
The argument `ndpost` specifies the number of posterior draws to keep after the burn-in period, while the number of burn-in iterations (`nskip`) defaults to 100.
The printed output of the `bart` function shows the specific parameter values used to fit the BART model, which were introduced in our methodology chapter.
```{r}
# Fitting a BART model with the default 1,000 iterations of 200 trees
set.seed(4343)
bartFit <- bart(x.train = as.matrix(df_train[, 5:7]),
                y.train = as.numeric(unlist(df_train[, 3])),
                keeptrees = TRUE,
                ndpost = 1000)

# Extracting trees from the model
trees <- extract(bartFit, "trees")

# Looking at some examples of trees from the model
bartFit$fit$plotTree(chainNum = 1, sampleNum = 3, treeNum = 1)
bartFit$fit$plotTree(chainNum = 1, sampleNum = 3, treeNum = 140)
```
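Because we set `keeptrees = TRUE`, the fitted object can also produce posterior draws of predictions for new data through its `predict` method. Below is a minimal sketch using the held-out test set (the column indices mirror those used in training; treat this as illustrative rather than part of the original analysis):
```{r}
# Posterior predictions for the test set; each row is one posterior draw
pred_post <- predict(bartFit, newdata = as.matrix(df_test[, 5:7]))

# Posterior mean prediction and a 95% credible interval per observation
pred_mean <- colMeans(pred_post)
pred_ci <- apply(pred_post, 2, quantile, probs = c(0.025, 0.975))
head(pred_mean)
```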
\
\
### `bartMan` package
Using the `bartMan` package, we can produce richer visualizations that show model diagnostics as well as what the trees in a fitted model look like.
```{r, eval=FALSE}
# Loading packages for BART visualization
library(bartMan)
library(ggridges)
bartDiag(model = bartFit, response = "household", burnIn = 1000, data = df_tidy)
```
\
```{r, echo=FALSE, out.width = '80%', fig.align = 'center'}
knitr::include_graphics("image/regr-diag.png")
```
Shown above are six general diagnostic plots for the BART regression fit on our dataset. Top left: a QQ-plot of the residuals after fitting the model. Top right: a trace plot of $\sigma$ across MCMC iterations. Middle left: residuals versus fitted values, with 95% credible intervals. Middle right: a histogram of the residuals. Bottom left: actual versus fitted values, with 95% credible intervals. (This panel is unfortunately not rendering properly, for reasons we could not determine.) Bottom right: a variable importance plot with the 25% to 75% quantile interval shown. In our case, the BART model finds the enrollment rate of childhood education for 3-year-olds to be the most important variable.
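A related variable-importance signal can be read directly from the `dbarts` fit, whose `varcount` component records how many times each predictor appears in a splitting rule at each posterior draw; averaging over draws gives a rough importance score. A quick sketch:
```{r}
# Average number of splitting rules using each predictor, across posterior draws
colMeans(bartFit$varcount)
```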
\
\
Then, we want to visualize all the trees fitted across all iterations. To keep computation time and visual clutter down, we reduce both the number of trees and the number of iterations. With `dbarts`, we fit a new model on the same variables with 50 trees and 10 iterations.
```{r}
# Fitting another BART model with fewer trees and fewer iterations
set.seed(4343)
bartFit50 <- bart(x.train = as.matrix(df_train[, 5:7]),
                  y.train = as.numeric(unlist(df_train[, 3])),
                  keeptrees = TRUE,
                  ntree = 50,
                  ndpost = 10,
                  verbose = FALSE)
```
```{r, eval = FALSE}
# Extracting the tree data
trees_data50 <- extractTreeData(bartFit50, df_tidy)
# Visualizing what each of the 50 trees looks like over its 10 iterations
plotTrees(trees = trees_data50, fillBy = NULL, sizeNodes = TRUE)
```
```{r, echo=FALSE, out.width = '80%', fig.align = 'center'}
knitr::include_graphics("image/regr-all-tree.png")
```
In the plot above, each little box represents a single tree. There are 500 boxes in total, since the plot shows all 50 trees at each of the 10 post-burn-in iterations of the reduced model (50 $\cdot$ 10 = 500). Different colors represent the different variables a tree splits on, and gray represents a stump or leaf, i.e., a terminal node.
While this plot is comprehensive, it is hard to draw insight from because there is too much to take in at once. To deal with that, we can instead view all 10 iterations of a single tree. Here we use the $13^{th}$ tree as an example:
```{r, eval = FALSE}
# Viewing all 10 iterations of one tree
plotTrees(trees = trees_data50, treeNo = 13)
```
```{r, echo=FALSE, out.width = '80%', fig.align = 'center'}
knitr::include_graphics("image/regr-one-tree.png")
```
However, that shows the information for only one tree. A more efficient way to summarize the forest is the `treeBarPlot` function in the `bartMan` package. It creates a bar plot of how often each distinct tree structure, identified by which variables the tree splits on (but not by the split values), appears within the model. We also create a density plot showing, for each splitting variable, how frequently each splitting value is chosen.
```{r, eval = FALSE}
# Creating bar plot showing frequency of 8 most common trees from model
treeBarPlot(trees_data50, topTrees = 8, iter = NULL)
# Creating a density plot of variable split values
splitDensity(trees = trees_data50, data = df_tidy, display = 'ridge')
```
```{r, echo=FALSE, out.width = '80%', fig.align = 'center'}
knitr::include_graphics("image/regr-bar.png")
```
```{r, echo=FALSE, out.width = '80%', fig.align = 'center'}
knitr::include_graphics("image/regr-dens.png")
```
Note: These visuals are interesting to look at and interpret, but the colors are not consistent across the 500-tree plot, the density plot, and the bar plot. `treeBarPlot` has no legend argument, so we manually added legends to the latter plots, checking the [original code](https://github.com/AlanInglis/bartMan/blob/master/R/treeBarPlot.R) of `treeBarPlot` to make sure each variable matches the color described in its legend.
\
\
### `tidymodels` package
Next, we use the `tidymodels` package to fit our BART model and assess its overall performance. `tidymodels` is more comprehensive, providing a standardized workflow not only for training and testing the model but also for cross-validating it. The same steps apply to virtually any machine-learning model fit in R through this framework.
```{r}
# Rbloggers code
# Preprocessing
df_rec <- recipe(household ~ age_3 + age_4 + age_5, data = df_train)

# Modeling with BART
df_spec <-
  parsnip::bart() %>%
  set_engine("dbarts", keeptrees = TRUE) %>%
  set_mode("regression")

# Workflow
df_wf <-
  workflow() %>%
  add_recipe(df_rec) %>%
  add_model(df_spec)

# Cross-validation for resamples
set.seed(12345)
df_folds <- vfold_cv(df_train)

# Resampling for the accuracy metrics
set.seed(98765)
df_rs <-
  df_wf %>%
  fit_resamples(resamples = df_folds)

# Computing the accuracy metrics
collect_metrics(df_rs)
```
The output above shows quality metrics for our model. The average RMSE indicates that the observed and predicted values of household net worth differ by 85.2 percentage points on average (recall that household net worth is, by definition, a percentage). In other words, predictions from this BART model are off by 85.2 percentage points on average. Given that household net worth ranges from 134 percent to 824 percent in our dataset, we consider this decent performance.
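As a quick sanity check on that range (using `df_tidy` as built above):
```{r}
# Range of household net worth, used to contextualize the RMSE
range(df_tidy$household)
```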
Additionally, the $R^2$ value tells us how well the model fits the data. We observe a value of 0.528, meaning the fitted BART model explains 52.8% of the variability in household net worth. We would like this value to be closer to 1, so that the model explains more of the variability in the data.
\
\
To refine our BART model, we can also tune the priors instead of simply using the defaults. Here, we demonstrate how to tune them with a grid search. The tunable quantities are the number of trees (`trees`) and the coefficient (`prior_terminal_node_coef`) and exponent (`prior_terminal_node_expo`) of the prior probability that a node is terminal.
```{r}
# Rbloggers code
# Model tuning with grid search
df_spec <-
  parsnip::bart(
    trees = tune(),
    prior_terminal_node_coef = tune(),
    prior_terminal_node_expo = tune()
  ) %>%
  set_engine("dbarts") %>%
  set_mode("regression")

# Parameter object
rf_param <-
  workflow() %>%
  add_model(df_spec) %>%
  add_recipe(df_rec) %>%
  extract_parameter_set_dials() %>%
  finalize(df_train)

# Space-filling design with an integer grid argument
df_reg_tune <-
  workflow() %>%
  add_recipe(df_rec) %>%
  add_model(df_spec) %>%
  tune_grid(
    df_folds,
    grid = 20,
    param_info = rf_param,
    metrics = metric_set(rsq)
  )

# Selecting the best parameters according to the R-squared
rf_param_best <-
  select_best(df_reg_tune, metric = "rsq") %>%
  select(-.config)

# Final estimation with the object of best parameters
final_df_wflow <-
  workflow() %>%
  add_model(df_spec) %>%
  add_recipe(df_rec) %>%
  finalize_workflow(rf_param_best)

set.seed(12345)
final_df_fit <-
  final_df_wflow %>%
  last_fit(df_split)

# Computing the final accuracy metrics
collect_metrics(final_df_fit)
```
The average RMSE now shows the difference between the observed and predicted values of household net worth to be 84.5 percentage points, slightly better than the previous BART model fitted with `tidymodels`.
For this second model, we observe an $R^2$ of 0.645, a clear improvement over the first model.
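To eyeball the tuned fit, we can plot predicted against observed values on the held-out test set using the standard `tidymodels` accessors. This sketch is our addition, not part of the original analysis:
```{r}
# Predicted vs. observed household net worth on the held-out test set
final_df_fit %>%
  collect_predictions() %>%
  ggplot(aes(x = household, y = .pred)) +
  geom_point(alpha = 0.6) +
  geom_abline(linetype = "dashed") +
  labs(x = "Observed household net worth (%)",
       y = "Predicted household net worth (%)") +
  theme_minimal()
```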
\
\
## References
- AlanInglis. (n.d.). bartMan: Visualisations for posterior evaluation of BART models. GitHub. <https://github.com/AlanInglis/bartMan?tab=readme-ov-file>
- Bayesian additive regression trees (BART) - bart • parsnip. (n.d.). <https://parsnip.tidymodels.org/reference/bart.html>
- Disci, S. (2022, December 8). The effect of childhood education on wealth: Modeling with Bayesian additive regression trees (BART). R-bloggers. <https://www.r-bloggers.com/2022/12/the-effect-of-childhood-education-on-wealth-modeling-with-bayesian-additive-regression-trees-bart/#google_vignette>
- Inglis, A., Parnell, A. C., & Hurley, C. (2024). Visualisations for Bayesian Additive Regression Trees. Journal of Data Science, Statistics, and Visualisation, 4(1). <https://doi.org/10.52933/jdssv.v4i1.79>
- OECD (2024), Enrolment rate in early childhood education (indicator). doi: 10.1787/ce02d0f9-en (Accessed on 30 April 2024)
- OECD (2024), Household net worth (indicator). doi: 10.1787/2cc2469a-en (Accessed on 30 April 2024)