Something went wrong on our end
Select Git revision
addSurrogates.R
-
Gärber, Florian authoredGärber, Florian authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
addSurrogates.R 5.82 KiB
#' Add surrogate information to a tree list.
#'
#' This function adds surrogate variables and adjusted agreement values to a forest that was created by [getTreeranger].
#'
#' @param RF A [ranger::ranger] object which was created with `keep.inbag = TRUE`.
#' @param trees List of trees created by [getTreeranger].
#' @param s Predefined number of surrogate splits (the actual number of surrogate splits may differ in individual nodes).
#' @param Xdata The data without the dependent variable (i.e. only the predictor columns).
#' @param num.threads (Default: [parallel::detectCores()]) Number of threads to spawn for parallelization.
#' @param preschedule.threads (Default: TRUE) Passed as `mc.preschedule` to [parallel::mclapply()].
#'
#' @returns A list of trees, where each tree is a list of nodes with the
#' elements:
#' * `nodeID`: ID of the respective node (referenced by the `leftdaughter` and `rightdaughter` elements below)
#' * `leftdaughter`: ID of the left daughter of this node
#' * `rightdaughter`: ID of the right daughter of this node
#' * `splitvariable`: ID of the split variable
#' * `splitpoint`: splitpoint of the split variable
#' * `status`: `0` for terminal and `1` for non-terminal
#' * `layer`: layer information (`0` means root node, `1` means 1 layer below root, etc)
#' * `surrogate_i`: numbered surrogate variables (number depending on s)
#' * `adj_i`: adjusted agreement of variable i
#'
#' @export
addSurrogates <- function(
RF,
trees,
s,
Xdata,
num.threads = parallel::detectCores(),
preschedule.threads = TRUE
) {
if (!inherits(RF, "ranger")) {
stop("`RF` must be a ranger object.")
}
if (!inherits(trees, "RangerTrees")) {
stop("`trees` must be a `getTreeranger` `RangerTrees` object.")
}
num.trees <- RF$num.trees
if (num.trees != length(trees)) {
stop("Number of trees in ranger model `RF` does not match number of extracted trees in `trees`.")
}
ncat <- sapply(sapply(Xdata, levels), length) # determine number of categories (o for continuous variables)
names(ncat) <- colnames(Xdata)
if (any(ncat) > 0) {
Xdata[, which(ncat > 0)] <- sapply(Xdata[, which(ncat > 0)], unclass)
}
# variables to find surrogates (control file similar as in rpart)
controls <- list(maxsurrogate = as.integer(s), sur_agree = 0)
trees.surr <- parallel::mclapply(
X = mapply(list, .a = trees, .b = RF$inbag.counts, SIMPLIFY = FALSE),
FUN = getSurrgate2,
mc.cores = num.threads,
mc.preschedule = preschedule.threads,
Xdata = Xdata,
controls = controls,
s = s,
ncat = ncat
)
class(trees.surr) <- c(class(trees), "SurrogateTrees")