addSurrogates.R

    #' Add surrogate information to a tree list.
    #'
    #' This function adds surrogate variables and adjusted agreement values to the trees of a random forest that were extracted with [getTreeranger].
    #'
    #' @param RF A [ranger::ranger] object that was grown with `keep.inbag = TRUE`.
    #' @param trees A `RangerTrees` object (list of trees) created by [getTreeranger] from `RF`.
    #' @param s Number of surrogate splits to search for in each node (the actual number of surrogate splits can differ in individual nodes).
    #' @param Xdata Data frame of the predictor variables, i.e. the data used to grow `RF` without the dependent variable.
    #' @param num.threads (Default: [parallel::detectCores()]) Number of cores to use; passed as `mc.cores` to [parallel::mclapply()] (values greater than 1 are not supported on Windows).
    #' @param preschedule.threads (Default: TRUE) Passed as `mc.preschedule` to [parallel::mclapply()].
    #'
    #' @returns A list of trees of class `SurrogateTrees` (in addition to the classes of `trees`).
    #' Each tree is a list of nodes with the elements:
    #' * `nodeID`: ID of the node (referenced by `leftdaughter` and `rightdaughter`)
    #' * `leftdaughter`: ID of the left daughter of this node
    #' * `rightdaughter`: ID of the right daughter of this node
    #' * `splitvariable`: ID of the split variable
    #' * `splitpoint`: split point of the split variable
    #' * `status`: `0` for terminal and `1` for non-terminal
    #' * `layer`: layer information (`0` means root node, `1` means one layer below the root, etc.)
    #' * `surrogate_i`: `i`-th surrogate variable (the number of surrogates depends on `s`)
    #' * `adj_i`: adjusted agreement of surrogate variable `i`
    #'
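    #' @examples
    #' \dontrun{
    #' # Minimal sketch, not taken from the package tests. It ASSUMES that
    #' # getTreeranger() accepts the forest as its first argument and returns a
    #' # `RangerTrees` object; check ?getTreeranger for its actual arguments.
    #' library(ranger)
    #'
    #' # grow a small forest and keep the in-bag counts needed for the surrogates
    #' rf <- ranger(Species ~ ., data = iris, num.trees = 10, keep.inbag = TRUE)
    #'
    #' # extract the trees (assumed call, see above)
    #' trees <- getTreeranger(rf)
    #'
    #' # search for up to 3 surrogate splits per node on one core;
    #' # Xdata holds only the predictors (column 5 of iris is Species)
    #' trees.surr <- addSurrogates(
    #'   RF = rf,
    #'   trees = trees,
    #'   s = 3,
    #'   Xdata = iris[, -5],
    #'   num.threads = 1
    #' )
    #' }
    #'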
    #' @export
    addSurrogates <- function(
        RF,
        trees,
        s,
        Xdata,
        num.threads = parallel::detectCores(),
        preschedule.threads = TRUE
    ) {
      if (!inherits(RF, "ranger")) {
        stop("`RF` must be a ranger object.")
      }
    
      if (!inherits(trees, "RangerTrees")) {
        stop("`trees` must be a `RangerTrees` object created by `getTreeranger`.")
      }
    
      num.trees <- RF$num.trees
    
      if (num.trees != length(trees)) {
        stop("Number of trees in ranger model `RF` does not match number of extracted trees in `trees`.")
      }
    
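      # number of categories per variable (0 for non-factor variables);
      # lengths() avoids sapply() simplifying the levels into a matrix when
      # all factors have the same number of levels
      ncat <- lengths(lapply(Xdata, levels))
      names(ncat) <- colnames(Xdata)
    
      if (any(ncat > 0)) {
        # replace factor columns by their integer codes
        Xdata[ncat > 0] <- lapply(Xdata[ncat > 0], as.integer)
      }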
    
      # control parameters for the surrogate search (similar to rpart's control list)
      controls <- list(maxsurrogate = as.integer(s), sur_agree = 0)
    
      # process each tree together with its in-bag counts, in parallel
      trees.surr <- parallel::mclapply(
        X = mapply(list, .a = trees, .b = RF$inbag.counts, SIMPLIFY = FALSE),
        FUN = getSurrgate2,
    
        mc.cores = num.threads,
        mc.preschedule = preschedule.threads,
    
        Xdata = Xdata,
        controls = controls,
        s = s,
        ncat = ncat
      )
    
      class(trees.surr) <- c(class(trees), "SurrogateTrees")