Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BB refactoring #16

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Imports:
checkmate,
colorspace,
data.table,
ggplot2,
mlr3,
mlr3misc,
paradox,
Expand All @@ -34,8 +35,10 @@ Suggests:
mlr3learners,
mlr3pipelines,
patchwork,
rmarkdown
rmarkdown,
testthat (>= 3.0.0)
VignetteBuilder:
knitr
Encoding: UTF-8
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Config/testthat/edition: 3
15 changes: 7 additions & 8 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

S3method(as.data.table,DictionaryLoss)
S3method(as.data.table,DictionaryObjective)
S3method(as_visualizer,LossFunction)
S3method(as_visualizer,Objective)
S3method(as_visualizer,Task)
S3method(c,VisualizerLossFunction)
export(LearnerRegrLMFormula)
export(LossFunction)
export(Objective)
Expand All @@ -15,13 +11,15 @@ export(OptimizerMomentum)
export(OptimizerNAG)
export(Visualizer1D)
export(Visualizer1DModel)
export(Visualizer1DObjective)
export(Visualizer1DObj)
export(Visualizer2D)
export(Visualizer2DModel)
export(Visualizer2DObjective)
export(VisualizerLossFunction)
export(Visualizer2DObj)
export(Visualizer3D)
export(Visualizer3DModel)
export(Visualizer3DObjective)
export(VisualizerLossFuns)
export(as.data.table)
export(as_visualizer)
export(assertStepSizeControl)
export(colSampler)
export(data.table)
Expand All @@ -39,6 +37,7 @@ import(TestFunctions)
import(checkmate)
import(colorspace)
import(data.table)
import(ggplot2)
import(mlr3)
import(mlr3misc)
import(paradox)
Expand Down
36 changes: 19 additions & 17 deletions R/LossFunction.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#FIXME: doc API of funs better


#' @title Loss Function
#'
Expand All @@ -22,7 +24,7 @@ LossFunction = R6::R6Class("LossFunction",

#' @field properties `character()`\cr
#' Additional properties of the loss function.
properties = NULL,
task_type = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
Expand All @@ -35,11 +37,11 @@ LossFunction = R6::R6Class("LossFunction",
#' Additional properties of the loss function.
#' @param fun (`function(y_true, y_pred, ...)`)\cr
#' Loss function.
initialize = function(id, label, properties, fun) {
initialize = function(id, label, task_type, fun) {
self$id = assert_character(id)
self$fun = assert_function(fun)
self$label = assert_character(label)
self$properties = assert_character(properties)
self$task_type = assert_choice(task_type, c("regr", "classif"))
self$fun = assert_function(fun)
}
)
)
Expand Down Expand Up @@ -68,29 +70,29 @@ lss = function(.key, ...) {
dict_loss$get(.key, ...)
}

dict_loss$add("l2_se", LossFunction$new("l2_se", "L2 Squared Error", "regr", function(y_true, y_pred) {
(y_true - y_pred)^2
dict_loss$add("l2_se", LossFunction$new("l2", "L2 Squared Error", "regr", function(r) {
(r)^2
}))

dict_loss$add("l1_ae", LossFunction$new("l1_ae", "L1 Absolute Error", "regr", function(y_true, y_pred) {
abs(y_true - y_pred)
dict_loss$add("l1_ae", LossFunction$new("l1", "L1 Absolute Error", "regr", function(r) {
abs(r)
}))

dict_loss$add("huber", LossFunction$new("huber", "Huber Loss", "regr", function(y_true, y_pred, delta = 1) {
a = abs(y_true - y_pred)
dict_loss$add("huber", LossFunction$new("huber", "Huber Loss", "regr", function(r, delta = 1) {
a = abs(r)
ifelse(a <= delta, 0.5 * a^2, delta * a - delta^2 / 2)
}))

dict_loss$add("log-cosh", LossFunction$new("log-cosh", "Log-Cosh Loss", "regr", function(y_true, y_pred) {
log(cosh(y_pred - y_true))
dict_loss$add("log-cosh", LossFunction$new("logcosh", "Log-Cosh Loss", "regr", function(r) {
log(cosh(r))
}))

dict_loss$add("cross-entropy", LossFunction$new("cross-entropy", "Cross-Entropy", "classif", function(y_true, y_pred) {
log(1 + exp(-y_true * y_pred))
dict_loss$add("cross-entropy", LossFunction$new("logloss", "Log Loss", "classif", function(r) {
log(1 + exp(-r))
}))

dict_loss$add("hinge", LossFunction$new("hinge", "Hinge Loss", "classif", function(y_true, y_pred) {
pmax(1 - y_true * y_pred, 0)
dict_loss$add("hinge", LossFunction$new("hinge", "Hinge Loss", "classif", function(r) {
pmax(1 - r, 0)
}))


Expand All @@ -101,7 +103,7 @@ as.data.table.DictionaryLoss = function(x, ..., objects = FALSE) {
setkeyv(map_dtr(x$keys(), function(key) {
t = x$get(key)
insert_named(
list(key = key, label = t$label, properties = list(t$properties)),
list(key = key, label = t$label, task_type = t$task_type),
if (objects) list(object = list(t))
)
}, .fill = TRUE), "key")[]
Expand Down
128 changes: 64 additions & 64 deletions R/Objective.R

Large diffs are not rendered by default.

187 changes: 105 additions & 82 deletions R/Visualizer1D.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,102 +12,125 @@
Visualizer1D = R6::R6Class("Visualizer1D",
public = list(

#' @field x (`vector()`)\cr
#' x-values.
x = NULL,
#' @field x (`numeric(n)`)\cr
#' x-values of function
fun_x = NULL,

#' @field y (`vector()`)\cr
#' y-values
y = NULL,
#' @field y (`numeric(n)`)\cr
#' y-values of function
fun_y = NULL,

#' @field plot_lab (character(1)\cr
#' Label of the plot.
plot_lab = NULL,
#' @field title (character(1)\cr
#' Title of plot
title = NULL,

#' @field x_lab (character(1)\cr
#' Label of the x axis.
x_lab = NULL,
#' @field lab_x (character(1)\cr
#' Label of x-axis
lab_x = NULL,
# FIXME: make consistent names with other visualizers

#' @field y_lab (character(1)\cr
#' Label of the y axis.
y_lab = NULL,
#' @field lab_y (character(1)\cr
#' Label of y-axis
lab_y = NULL,

#' @field x (`numeric(m)`)\cr
#' x-values of extra points to plot.
#' Use NULL if no points should be plotted.
points_x = NULL,

#' @field y (`numeric(m)`)\cr
#' y-values of extra points to plot.
#' Use NULL if no points should be plotted.
points_y = NULL,

#' @field line_col (character(1)\cr
#' Color of plotted line
line_col = NULL,

#' @field line_width (numeric(1)\cr
#' Width of plotted line
line_width = NULL,

#' @field line_type (character(1)\cr
#' Type of plotted line
line_type = NULL,

#' @field points_col (character(1)\cr
#' Color of plotted points
points_col = NULL,

#' @field points_size (numeric(1)\cr
#' Size of plotted points
points_size = NULL,

#' @field points_shape (integer(1)\cr
#' Shape of plotted points
points_shape = NULL,

#' @field points_alpha (numeric(1)\cr
#' Alpha blending of plotted points
points_alpha = NULL,

#FIXME: add point-size, point col, point-symbol

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param x (`numeric()`)\cr
#' x-values.
#' x-values of function
#' @param y (`numeric()`)\cr
#' y-values.
#' @param plot_lab (character(1)\cr
#' Label of the plot.
#' @param x_lab (character(1)\cr
#' Label of the x axis.
#' @param y_lab (character(1)\cr
#' Label of the y axis.
initialize = function(x, y, plot_lab = NULL, x_lab = "x", y_lab = "y") {
self$x = assert_numeric(x)
self$y = assert_numeric(y)
self$plot_lab = assert_character(plot_lab, null.ok = TRUE)
self$x_lab = assert_character(x_lab)
self$y_lab = assert_character(y_lab)
return(invisible(self))
},

#' @description
#' Initialize the plot with a line plot.
#'
#' @param ... (`any`)\cr
#' Further arguments passed to `add_trace(...)`.
init_layer_lines = function(...) {
private$.plot = plot_ly() %>%
add_trace(
name = self$plot_lab,
showlegend = FALSE,
x = self$x,
y = self$y,
type = "scatter",
mode = "lines",
...
) %>%
layout(
title = self$plot_lab,
xaxis = list(title = self$x_lab),
yaxis = list(title = self$y_lab))

return(invisible(self))
#' y-values of function
#' @param title (character(1)\cr
#' Title of plot
#' @param lab_x (character(1)\cr
#' Label of x-axis
#' @param lab_y (character(1)\cr
#' Label of y-axis
#' @param points_x (`numeric()`)\cr
#' x-values of extra points to plot.
#' Use NULL if no points should be plotted.
#' @param points_y (`numeric()`)\cr
#' y-values of extra points to plot.
#' Use NULL if no points should be plotted.
initialize = function(
fun_x,
fun_y,
title = NULL,
lab_x = "x",
lab_y = "y",
points_x = NULL,
points_y = NULL
) {
self$fun_x = assert_numeric(fun_x)
self$fun_y = assert_numeric(fun_y)
self$title = assert_character(title, null.ok = TRUE)
self$lab_x = assert_character(lab_x)
self$lab_y = assert_character(lab_y)
self$points_x = assert_numeric(points_x, null.ok = TRUE)
self$points_y = assert_numeric(points_y, null.ok = TRUE)
self$line_type = "solid"
self$line_col = "red"
self$line_width = 3
self$points_shape = 19
self$points_col = "black"
self$points_size = 2
self$points_alpha = 0.3
},

#' @description
#' Set the layout of the plotly plot.
#'
#' @param ... (`any`)\cr
#' Layout options directly passed to `layout(...)`.
setLayout = function(...) {
private$p_layout = list(...)
private$p_plot = private$p_plot %>% layout(...)

return(invisible(self))
},
# FIXME: set better defaults here to make plot nicer, maybe ask lukas

#' @description
#' Return the plot and hence plot it or do further processing.
plot = function() {
if (is.null(private$.plot)) self$init_layer_lines()
return(private$.plot)
},

#' @description
#' Save the plot by using plotlys `orca()` function.
#'
#' @param ... (`any`)\cr
#' Further arguments passed to `orca()`.
save = function(...) {
orca(private$.plot, ...)
dd = data.frame(x = self$fun_x, y = self$fun_y)
pl = ggplot(data = dd, aes(x = x, y = y))
pl = pl + geom_line(size = self$line_width, col = self$line_col, linetype = self$line_type)
# use specified axis labels and legend title
pl = pl + labs(title = self$title, x = self$lab_x, y = self$lab_y)
if (!is.null(self$points_x)) {
dd2 = data.frame(x = self$points_x, y = self$points_y)
pl = pl + geom_point(data = dd2, size = self$points_size, col = self$points_col,
shape = self$points_shape, alpha = self$points_alpha)
}
return(pl)
}
),

private = list(
.plot = NULL
)
)
Loading