diff --git a/DESCRIPTION b/DESCRIPTION index e5882f1929827ddf9f1e559855ca461f2eb9d715..2ceaf979220381daf728acf39815ad15b655dce9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Type: Package Package: RFSurrogates Title: Surrogate Minimal Depth Variable Importance -Version: 0.3.3.9007 +Version: 0.3.3.9008 Authors@R: c( person("Stephan", "Seifert", , "stephan.seifert@uni-hamburg.de", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-2567-5728")), diff --git a/R/RandomForestSurrogates.R b/R/RandomForestSurrogates.R index e57daa73752a9403e9a966730b19dacfa369b37b..b0cb40feb2eb210c3b73435c9f0ecce3fb5f4d7e 100644 --- a/R/RandomForestSurrogates.R +++ b/R/RandomForestSurrogates.R @@ -36,6 +36,8 @@ #' #' @param num.threads Number of threads to parallelize with. (Default: 1) #' +#' @param preschedule.threads (Default: TRUE) Passed as `mc.preschedule` to [parallel::mclapply()] in [addSurrogates()]. +#' #' @returns A RandomForestSurrogates S3 object. #' * `trees`: List of all trees with surrogate analysis. (Class: `SurrogateTrees`, `LayerTrees`, `RangerTrees`) #' * `ranger`: The [ranger::ranger] model used to obtain the trees. @@ -55,6 +57,7 @@ RandomForestSurrogates <- function( min.node.size = 1, permutate = FALSE, seed = NULL, + preschedule.threads = TRUE, ...) { if (length(y) != nrow(x)) { stop(paste0("Different numbers of response variables and observations.\nFound: nrow(x) = ", nrow(x), ", length(y) = ", length(y), ", expected them to be equal.")) @@ -161,7 +164,8 @@ RandomForestSurrogates <- function( RF = RF, s = s, x = x, - num.threads = num.threads + num.threads = num.threads, + preschedule.threads = preschedule.threads ) result <- list( @@ -177,7 +181,13 @@ RandomForestSurrogates <- function( #' @returns A `SurrogateTrees`, `LayerTrees`, `RangerTrees` object. #' #' @keywords internal -getSurrogateTrees <- function(RF, s, x, num.threads = parallel::detectCores()) { +getSurrogateTrees <- function( + RF, + s, + x, + num.threads = parallel::detectCores(), + preschedule.threads = TRUE +) { if (!inherits(RF, "ranger")) { stop("`RF` must be a `ranger` object.") } @@ -191,7 +201,8 @@ getSurrogateTrees <- function(RF, s, x, num.threads = parallel::detectCores()) { ), s = s, Xdata = x, - num.threads = num.threads + num.threads = num.threads, + preschedule.threads = preschedule.threads ) } diff --git a/R/addSurrogates.R b/R/addSurrogates.R index bea7a0787f7a501a82b6a787f8757919f6326627..4c802da4596afd39bd7b4cf15e17839e0f6d9b67 100644 --- a/R/addSurrogates.R +++ b/R/addSurrogates.R @@ -7,6 +7,7 @@ #' @param s Predefined number of surrogate splits (it may happen that the actual number of surrogate splits differs in individual nodes). #' @param Xdata data without the dependent variable. #' @param num.threads (Default: [parallel::detectCores()]) Number of threads to spawn for parallelization. +#' @param preschedule.threads (Default: TRUE) Passed as `mc.preschedule` to [parallel::mclapply()]. #' #' @returns A list of trees. #' A list of trees containing of lists of nodes with the elements: @@ -21,7 +22,14 @@ #' * `adj_i`: adjusted agreement of variable i #' #' @export -addSurrogates <- function(RF, trees, s, Xdata, num.threads = parallel::detectCores()) { +addSurrogates <- function( + RF, + trees, + s, + Xdata, + num.threads = parallel::detectCores(), + preschedule.threads = TRUE +) { if (!inherits(RF, "ranger")) { stop("`RF` must be a ranger object.") } @@ -46,17 +54,17 @@ addSurrogates <- function(RF, trees, s, Xdata, num.threads = parallel::detectCor # variables to find surrogates (control file similar as in rpart) controls <- list(maxsurrogate = as.integer(s), sur_agree = 0) - trees.surr <- parallel::mclapply(1:num.trees, - getSurrogate, + trees.surr <- parallel::mclapply( + X = mapply(list, .a = trees, .b = RF$inbag.counts, SIMPLIFY = FALSE), + FUN = getSurrgate2, + mc.cores = num.threads, - maxsurr = s, - surr.par = list( - inbag.counts = RF$inbag.counts, - Xdata = Xdata, - controls = controls, - trees = trees, - ncat = ncat - ) + mc.preschedule = preschedule.threads, + + Xdata = Xdata, + controls = controls, + s = s, + ncat = ncat ) class(trees.surr) <- c(class(trees), "SurrogateTrees") @@ -64,6 +72,35 @@ addSurrogates <- function(RF, trees, s, Xdata, num.threads = parallel::detectCor return(trees.surr) } +#' getSurrogate2 +#' +#' This is an internal function +#' +#' @param x List of length `num.trees` +#' +#' @keywords internal +getSurrgate2 <- function( + x, # list of length num.trees with [1] tree and [2] inbag.counts + Xdata, + controls, + s, + ncat +) { + tree <- x[[1]] + lapply( + X = seq_len(nrow(tree)), + FUN = SurrTree, + + tree = tree, + wt = x[[2]], + Xdata = Xdata, + controls = controls, + column.names = colnames(tree), + maxsurr = s, + ncat = ncat + ) +} + #' getSurrogate #' #' This is an internal function diff --git a/man/MFI.Rd b/man/MFI.Rd index 3c7910b23e5482192cd6e5ca11e55a84dd1c84f8..c085ca693b3668292cdc8263630a675ac2a6b47d 100644 --- a/man/MFI.Rd +++ b/man/MFI.Rd @@ -38,6 +38,7 @@ Use 1 for event and 0 for censoring. Length must match \code{y}.} \item{\code{min.node.size}}{Minimal node size to split at. (Default: 1)} \item{\code{permutate}}{Enable to permutate \code{x} for \code{\link[=MutualForestImpact]{MutualForestImpact()}} (Default: FALSE).} \item{\code{seed}}{RNG seed. It is strongly recommended that you set this value.} + \item{\code{preschedule.threads}}{(Default: TRUE) Passed as \code{mc.preschedule} to \code{\link[parallel:mclapply]{parallel::mclapply()}} in \code{\link[=addSurrogates]{addSurrogates()}}.} \item{\code{num.trees}}{Number of trees.} }} } diff --git a/man/RandomForestSurrogates.Rd b/man/RandomForestSurrogates.Rd index a42d22fa6d0eaf1d982f29a179341e29c0fcdb44..c9c6e3116daf9a893e6e9c20e747db73f2127530 100644 --- a/man/RandomForestSurrogates.Rd +++ b/man/RandomForestSurrogates.Rd @@ -17,6 +17,7 @@ RandomForestSurrogates( min.node.size = 1, permutate = FALSE, seed = NULL, + preschedule.threads = TRUE, ... ) } @@ -55,6 +56,8 @@ Use 1 for event and 0 for censoring. Length must match \code{y}.} \item{seed}{RNG seed. It is strongly recommended that you set this value.} +\item{preschedule.threads}{(Default: TRUE) Passed as \code{mc.preschedule} to \code{\link[parallel:mclapply]{parallel::mclapply()}} in \code{\link[=addSurrogates]{addSurrogates()}}.} + \item{...}{ Arguments passed on to \code{\link[ranger:ranger]{ranger::ranger}} \describe{ diff --git a/man/addSurrogates.Rd b/man/addSurrogates.Rd index 9dddcf59e15c4027c688d96a3e14056007cecf12..4841a646eb4380fcecd3adda410167ae04f49ff0 100644 --- a/man/addSurrogates.Rd +++ b/man/addSurrogates.Rd @@ -4,7 +4,14 @@ \alias{addSurrogates} \title{Add surrogate information to a tree list.} \usage{ -addSurrogates(RF, trees, s, Xdata, num.threads = parallel::detectCores()) +addSurrogates( + RF, + trees, + s, + Xdata, + num.threads = parallel::detectCores(), + preschedule.threads = TRUE +) } \arguments{ \item{RF}{A \link[ranger:ranger]{ranger::ranger} object which was created with \code{keep.inbag = TRUE}.} @@ -16,6 +23,8 @@ addSurrogates(RF, trees, s, Xdata, num.threads = parallel::detectCores()) \item{Xdata}{data without the dependent variable.} \item{num.threads}{(Default: \code{\link[parallel:detectCores]{parallel::detectCores()}}) Number of threads to spawn for parallelization.} + +\item{preschedule.threads}{(Default: TRUE) Passed as \code{mc.preschedule} to \code{\link[parallel:mclapply]{parallel::mclapply()}}.} } \value{ A list of trees. diff --git a/man/getSurrgate2.Rd b/man/getSurrgate2.Rd new file mode 100644 index 0000000000000000000000000000000000000000..25c5762e2a2a4139b75d37517923044b65d7832e --- /dev/null +++ b/man/getSurrgate2.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/addSurrogates.R +\name{getSurrgate2} +\alias{getSurrgate2} +\title{getSurrogate2} +\usage{ +getSurrgate2(x, Xdata, controls, s, ncat) +} +\arguments{ +\item{x}{List of length \code{num.trees}} +} +\description{ +This is an internal function +} +\keyword{internal}