Skip to content
Snippets Groups Projects
Verified Commit 8e3be4e1 authored by Gärber, Florian's avatar Gärber, Florian
Browse files

refactor: Add S3 classes to `trees` list objects

parent dfad067f
No related branches found
No related tags found
No related merge requests found
Type: Package
Package: RFSurrogates
Title: Surrogate Minimal Depth Variable Importance
Version: 0.3.3.9003
Version: 0.3.3.9004
Authors@R: c(
person("Stephan", "Seifert", , "stephan.seifert@uni-hamburg.de", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2567-5728")),
......
......@@ -10,6 +10,10 @@
* Clarified default value for `num.threads` to be `parallel::detectCores()` by adding it as a default to the parameter
* Added assertion that `RF` is a `ranger` object.
* Added assertion that `RF$num.trees` and `length(trees)` are equal. This is not considered a breaking change since these values should always be equal when the function is used correctly.
* Added S3 classes to the `trees` list objects.
* `getTreeranger()` now returns a `RangerTrees` list.
* `addLayer()` and `getTreeranger(add_layer = TRUE)` add the `LayerTrees` class to the list (indicating presence of the `layer` list item). It now requires that its `trees` param inherits `RangerTrees`.
* `addSurrogates()` now adds the `SurrogateTrees` class. It now requires that its `trees` param inherits `RangerTrees`.
# RFSurrogates 0.3.3
......
......@@ -19,7 +19,15 @@
#'
#' @export
addLayer <- function(trees, num.threads = 1) {
parallel::mclapply(trees, add_layer_to_tree, mc.cores = num.threads)
if (!inherits(trees, "RangerTrees")) {
stop("`trees` must be a `getTreeranger` `RangerTrees` object.")
}
trees <- parallel::mclapply(trees, add_layer_to_tree, mc.cores = num.threads)
class(trees) <- c(class(trees), "LayerTrees")
return(trees)
}
#' Internal function
......
......@@ -26,6 +26,10 @@ addSurrogates <- function(RF, trees, s, Xdata, num.threads = parallel::detectCor
stop("`RF` must be a ranger object.")
}
if (!inherits(trees, "RangerTrees")) {
stop("`trees` must be a `getTreeranger` `RangerTrees` object.")
}
num.trees <- RF$num.trees
if (num.trees != length(trees)) {
......@@ -54,6 +58,9 @@ addSurrogates <- function(RF, trees, s, Xdata, num.threads = parallel::detectCor
ncat = ncat
)
)
class(trees.surr) <- c(class(trees), "SurrogateTrees")
return(trees.surr)
}
......
......@@ -20,11 +20,18 @@
#'
#' @export
getTreeranger <- function(RF, num.trees = RF$num.trees, add_layer = FALSE, num.threads = 1) {
parallel::mclapply(1:num.trees, getsingletree,
trees <- parallel::mclapply(1:num.trees, getsingletree,
mc.cores = num.threads,
RF = RF,
add_layer = add_layer
)
class(trees) <- "RangerTrees"
if (add_layer) {
class(trees) <- c(class(trees), "LayerTrees")
}
return(trees)
}
#' getsingletree
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment