-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict_bart_df
137 lines (134 loc) · 5.26 KB
/
predict_bart_df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# EDITED FROM:
# edit(predict2.bart)
# https://rdrr.io/github/cjcarlson/embarcadero/src/R/predict.R
# Predict from a fitted dbarts model onto the rows of a data frame.
#
# Data-frame adaptation of a raster-based predict function: instead of a
# RasterStack, `x.layers` is a table with one column per predictor.
#
# Arguments:
#   object    Fitted model inheriting from class "bart" or "rbart".
#   x.layers  Data frame (or matrix) whose column names include every term
#             label stored in the model.
#   quantiles Optional numeric vector of probabilities; when non-empty, the
#             posterior quantiles are returned alongside the mean.
#   ri.data   "rbart" only: value(s) of the random-effect variable. A single
#             value is repeated for every row; a vector (or RasterLayer
#             values) with one entry per row of `x.layers` is aligned to the
#             rows that have complete predictors.
#   ri.name   "rbart" only: name of the random-effect variable in the model.
#   ri.pred   "rbart" only: if TRUE predictions use value = "ppd", otherwise
#             value = "bart".
#   splitby   If not 1, rows are predicted in chunks of
#             floor(n_complete / splitby) rows (at least 1 row per chunk).
#   quiet     If TRUE, no progress bar is drawn (the time estimate is still
#             printed when chunking).
#
# Returns:
#   With no quantiles: a vector with one element per row of `x.layers`
#   (NA where any predictor is missing). With quantiles: a data frame with
#   columns "pred" and "q<prob>" (dots removed, e.g. "q0025").
#
# Known limitation (from the original author's note): prediction fails for
# models with categorical predictors, because the model matrix then has one
# column per category instead of the input variable names.
predict_bart_df <- function (object, x.layers, quantiles = c(), ri.data = NULL, ri.name = NULL, ri.pred = FALSE, splitby = 1, quiet = FALSE)
{
  # inherits() is safe for objects carrying more than one class, unlike
  # `class(object) == "..."` inside if().
  is_rbart <- inherits(object, "rbart")
  is_bart <- !is_rbart && inherits(object, "bart")
  if (!is_rbart && !is_bart) {
    stop("ERROR: 'object' must be a model of class 'bart' or 'rbart'")
  }

  if (is_rbart) {
    if (is.null(ri.data)) {
      stop("ERROR: Input either a value or a raster in ri.data")
    }
    if (is.null(ri.name)) {
      stop("ERROR: Input the correct random effect variable name in the model object in ri.name")
    }
    xnames <- attr(object$fit[[1]]$data@x, "term.labels")
    available <- c(colnames(x.layers), ri.name)
    predictors <- xnames[xnames != ri.name]
  } else {
    xnames <- attr(object$fit$data@x, "term.labels")
    available <- colnames(x.layers)
    predictors <- xnames
  }
  if (!all(xnames %in% available)) {
    stop("Variable names of RasterStack don't match the requested names")
  }
  # Same data-frame column selection for both model classes; drop = FALSE
  # keeps a single predictor as a one-column table.
  x.layers <- x.layers[, predictors, drop = FALSE]

  input.matrix <- as.matrix(x.layers)
  n_rows <- nrow(x.layers)
  blankout <- data.frame(matrix(ncol = 1 + length(quantiles), nrow = n_rows))
  whichvals <- which(complete.cases(input.matrix))
  # drop = FALSE: a single predictor or a single complete row must stay a
  # matrix, otherwise nrow()/colnames() below stop working.
  input.matrix <- input.matrix[whichvals, , drop = FALSE]

  if (length(whichvals) == 0) {
    # Nothing can be predicted; fall through and return the all-NA result.
    warning("No rows with complete predictor values; returning NA predictions")
  } else {
    if (is_rbart) {
      from_raster <- inherits(ri.data, "RasterLayer")
      ri_values <- if (from_raster) values(ri.data) else ri.data
      if (length(ri_values) == n_rows) {
        # One value per input row: keep only the rows that survived the
        # complete-case filter so values stay aligned with input.matrix.
        ri_column <- ri_values[whichvals]
      } else if (from_raster) {
        ri_column <- ri_values
      } else {
        ri_column <- rep(ri_values, nrow(input.matrix))
      }
      input.matrix <- cbind(input.matrix, ri_column)
      colnames(input.matrix)[ncol(input.matrix)] <- ri.name
    }

    # Predict for one block of rows (matrix or data frame).
    predict_rows <- function(newdata) {
      if (is_rbart) {
        is_ri <- colnames(newdata) == ri.name
        dbarts:::predict.rbart(object, newdata[, !is_ri, drop = FALSE],
                               group.by = newdata[, ri.name],
                               value = if (ri.pred) "ppd" else "bart")
      } else {
        dbarts:::predict.bart(object, newdata)
      }
    }

    if (splitby == 1) {
      pred.summary <- dfextract(predict_rows(input.matrix), quant = quantiles)
    } else {
      # At least one row per chunk: avoids a zero divisor when splitby
      # exceeds the number of complete rows.
      chunk_size <- max(1, floor(nrow(input.matrix) / splitby))
      # check.names = FALSE keeps column names identical to the model's
      # term labels (and to ri.name).
      input.df <- data.frame(input.matrix, check.names = FALSE)
      chunks <- split(input.df, (seq_len(nrow(input.df)) - 1) %/% chunk_size)
      summaries <- vector("list", length(chunks))
      for (i in seq_along(chunks)) {
        if (i == 1) {
          start_time <- Sys.time()
        }
        summaries[[i]] <- dfextract(predict_rows(chunks[[i]]), quant = quantiles)
        if (i == 1) {
          # Fix the units explicitly: plain subtraction of POSIXct values
          # picks secs/mins/hours automatically.
          elapsed_secs <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
          cat("Estimated time to total prediction (mins):\n")
          cat(length(chunks) * elapsed_secs / 60)
          cat("\n")
          if (!quiet) {
            pb <- txtProgressBar(min = 0, max = length(chunks), style = 3)
          }
        }
        if (!quiet) {
          setTxtProgressBar(pb, i)
        }
      }
      if (!quiet) {
        close(pb)
      }
      if (length(quantiles) == 0) {
        pred.summary <- data.frame(means = unlist(summaries))
      } else {
        # NOTE(review): rbindlist is not defined in this file; presumably
        # data.table::rbindlist, which the caller must have attached.
        pred.summary <- rbindlist(summaries)
      }
    }
    # Write predictions back into their original row positions; rows with
    # missing predictors stay NA.
    blankout[whichvals, ] <- as.matrix(pred.summary)
  }

  output <- blankout
  if (ncol(output) > 1) {
    colnames(output) <- c("pred", paste0("q", quantiles))
    colnames(output) <- gsub("\\.", "", colnames(output))
  } else {
    output <- output[, 1]
  }
  output
}
dfextract <- function(df, quant) {
  # Summarise the columns of `df` (adapted from
  # https://rdrr.io/github/cjcarlson/embarcadero/src/R/predict.R).
  #
  # df:    matrix-like object whose columns are summarised.
  # quant: vector of probabilities; may be empty.
  #
  # Returns the column means alone when `quant` is empty; otherwise the
  # column means bound together with the column quantiles at `quant`.
  wants_quantiles <- length(quant) != 0
  if (!wants_quantiles) {
    return(colMeans(df))
  }
  # Both arguments stay inline on purpose: the resulting column names are
  # derived from these expressions.
  cbind(data.frame(colMeans(df)),
        colQuantiles(df, probs = quant))
}