Skip to content
Snippets Groups Projects
Verified Commit 35db4a42 authored by Gärber, Florian's avatar Gärber, Florian
Browse files

feat: Add `RandomForestSurrogates` function

parent 4529d2c1
No related branches found
No related tags found
No related merge requests found
# Generated by roxygen2: do not edit by hand
export(RandomForestSurrogates)
export(addLayer)
export(addSurrogates)
export(build.clusters)
......
# RFSurrogates (development version)
## New Features
* Added `RandomForestSurrogates()`.
* This functions aims to replace the first section of the variable selection and relation functions by creating a single reusable object which contains the random forest `ranger::ranger()` model, as well as the `trees` list with layers and surrogates added.
* Returns a `RFSurrogates` object, which serves as the base object for later analysis.
* Additional `...` params are passed directly to `ranger::ranger`.
* `s.pct` is a helper for calculating the number of surrogates as a fraction of number of variables (Default: 0.01). `s` can be set to overwrite this default.
* `mtry` supports the following values:
* One of the documented `string` values, which will cause the `mtry` passed to `ranger::ranger()` to be a function accepting the number of variables, and returning the specific transformation after flooring the result.
* A `function` which takes the number of variables as its first and only param, and returns the value of `mtry`.
* A `numeric` value for `mtry`.
* The default is `"^3/4"`.
* `type` also uses `match.arg()` and still defaults to `"regression"`.
* `num.threads` transparently defaults to `parallel::detectCores()`.
* `permutate` will, if set to `TRUE`, apply random permutation to the data in each feature. (This is used in permutation importance approaches.)
* `seed` is now a strongly recommended optional parameter (issuing a warning whenever it is not set).
* Setting `seed` will cause a call to `set.seed()` when permutating. It is also used as the `seed` param of the `ranger::ranger()` call.
* Requiring `seed` as a function parameter is preferred because it does not rely on global, non-reproducible state of the random number generator, if it was not seeded immediately before the function call.
* The inner call to `ranger::ranger()` includes the following defaults:
* `keep.inbag = TRUE`
* `respect.unordered.factors = "partition"`
* Data is passed as a data.frame with the special column `y`, and the optional special column `status` for survival forests.
* `x` must not contain the column names `y` or `status`, as this may lead to unexpected behavior.
* In general, input parameters are more strictly validated.
* The function uses `num.threads` to also parallelize creating the list of trees with layers.
## Changes
* `var.select.smd()`, `var.select.md()`, `var.relations()`, `var.relations.mfi()`: Made several improvements to developer experience:
* `create.forest` now defaults to `is.null(forest)`, so it will automatically be `TRUE` if no forest is provided, and `FALSE` otherwise.
* `x` is no longer required if `create.forest` is `FALSE`.
......
#' Create a random forest with surrogates.
#'
#' @inheritParams ranger::ranger
#'
#' @param s.pct,s Number of surrogate splits.
#' This can be defined either by setting `s.pct` to a number between
#' 0 and 1, or providing an exact value for `s`.
#' - `s.pct`: Percentage of variables to use for `s`. (Default: 0.01)
#' - `s`: Number of surrogate splits. (Default: Number of variables
#' multiplied by `s.pct`, which defaults to 0.01; If `s.pct` is
#' less than or equal to zero, or greater than 1: 0.01 is used instead.)
#'
#' @param mtry Number of variables to possibly split at in each node.
#' Default is the (rounded down) number of variables to the power
#' of three quarters (Ishwaran, 2011).
#' Alternatively, a single argument function returning an integer,
#' given the number of independent variables.
#'
#' @param type The type of random forest to create with ranger.
#' One of "regression" (Default), "classification" or "survival".
#'
#' @param min.node.size Minimal node size to split at. (Default: 1)
#'
#' @param permutate Enable to permutate `x` for [`MutualForestImpact`] (Default: FALSE).
#'
#' @param seed RNG seed. It is strongly recommended that you set this value.
#'
#' @param ... Other params passed on to [ranger::ranger()].
#'
#' @returns A RFSurrogates S3 object.
#' * `trees`: List of all trees with surrogate analysis. (Class: `SurrogateTrees`, `LayerTrees`, `RangerTrees`)
#' * `ranger`: [ranger::ranger] model used to obtain the trees.
#' * `s` = s: The `s` parameter.
#'
#' @keywords prep
#' @export
RandomForestSurrogates <- function(
x = NULL, y = NULL,
s.pct = 0.01,
s = ceiling(ncol(x) * ifelse(s.pct > 0 && s.pct <= 1, s.pct, 0.01)),
mtry = c("^3/4", "sqrt", "0.5"),
type = c("regression", "classification", "survival"),
status = NULL,
num.trees = 500,
num.threads = parallel::detectCores(),
min.node.size = 1,
permutate = FALSE,
seed = NULL,
...) {
if (length(y) != nrow(x)) {
stop(paste0("Different numbers of response variables and observations.\nFound: nrow(x) = ", nrow(x), ", length(y) = ", length(y), ", expected them to be equal."))
}
if (any(is.na(x))) {
stop("`x` contains missing values.")
}
if (is.null(seed)) {
warning("`seed` was not set. Your results may not be reproducible.")
}
if (permutate) {
if (!is.null(seed)) {
set.seed(seed)
}
x <- permutate_feature_values(x)
}
mtry_match_args <- c("^3/4", "sqrt", "0.5")
if (is.null(mtry) || is.character(mtry)) {
mtry <- match.arg(mtry, mtry_match_args)
mtry <- switch(mtry,
`^3/4` = function(nvar) floor((nvar)^(3 / 4)),
`sqrt` = function(nvar) floor(sqrt(nvar)),
`0.5` = function(nvar) floor(nvar * 0.5)
)
}
if (!is.numeric(mtry) && !is.function(mtry)) {
stop(paste0("`mtry` must be one of ", paste(paste0("\"", mtry_match_args, "\""), collapse = ", "), ", a numeric or a function."))
}
nvar <- ncol(x)
if (s > (nvar - 1)) {
warning(paste0("`s` was set to the maximum number that is reasonable (set to ", nvar - 1, ", was ", s, ")."))
s <- nvar - 1
}
type_match_args <- c("regression", "classification", "survival")
type <- match.arg(type, type_match_args)
if (type == "classification") {
y <- as.factor(y)
if (length(levels(y)) > 15) {
warning(paste0("Found ", length(levels(y)), " levels in `y`. Is `type = \"classification\"` the right choice?"))
}
}
if (type == "regression" && inherits(y, "factor")) {
stop("`y` must not be a factor with `type = \"regression\"`.")
}
if (any(c("y", "status") %in% colnames(x))) {
stop("`x` must not contain columns named `y` or `status`.")
}
if (type == "survival") {
if (length(y) != length(status)) {
stop(paste0("Different numbers of response variables and status variables.\nFound: length(status) = ", length(status), ", length(y) = ", length(y), ", expected them to be equal."))
}
if (!all(status %in% c(0, 1))) {
stop("`status` must contain only 1 or 0. Use 1 for event and 0 for censoring.")
}
data <- data.frame(x, y, status)
RF <- ranger::ranger(
data = data,
dependent.variable.name = "y",
status.variable.name = "status",
mtry = mtry,
keep.inbag = TRUE,
respect.unordered.factors = "partition",
num.trees = num.trees,
num.threads = num.threads,
min.node.size = min.node.size,
...
)
} else if (type == "classification" || type == "regression") {
data <- data.frame(x, y)
RF <- ranger::ranger(
data = data,
dependent.variable.name = "y",
mtry = mtry,
keep.inbag = TRUE,
respect.unordered.factors = "partition",
classification = type == "classification",
num.trees = num.trees,
num.threads = num.threads,
min.node.size = min.node.size,
...
)
} else {
stop(paste0("`type` must be one of ", paste(paste0("\"", type_match_args, "\""), collapse = ", "), "."))
}
trees <- getSurrogateTrees(
RF = RF,
s = s,
x = x,
num.threads = num.threads
)
result <- list(
trees = trees,
ranger = RF,
s = s
)
class(result) <- "RFSurrogates"
return(result)
}
#' @returns A `SurrogateTrees`, `LayerTrees`, `RangerTrees` object.
#'
#' @keywords internal
getSurrogateTrees <- function(RF, s, x, num.threads = parallel::detectCores()) {
if (!inherits(RF, "ranger")) {
stop("`RF` must be a `ranger` object.")
}
addSurrogates(
RF = RF,
trees = getTreeranger(
RF = RF,
add_layer = TRUE,
num.threads = num.threads
),
s = s,
Xdata = x,
num.threads = num.threads
)
}
#' @keywords internal
permutate_feature_values <- function(x) {
perm_names <- paste(colnames(x), "perm", sep = "_")
x <- data.frame(lapply(1:ncol(x), permute.variable, x = x))
colnames(x) <- perm_names
return(x)
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/RandomForestSurrogates.R
\name{RandomForestSurrogates}
\alias{RandomForestSurrogates}
\title{Create a random forest with surrogates.}
\usage{
RandomForestSurrogates(
x = NULL,
y = NULL,
s.pct = 0.01,
s = ceiling(ncol(x) * ifelse(s.pct > 0 && s.pct <= 1, s.pct, 0.01)),
mtry = c("^3/4", "sqrt", "0.5"),
type = c("regression", "classification", "survival"),
status = NULL,
num.trees = 500,
num.threads = parallel::detectCores(),
min.node.size = 1,
permutate = FALSE,
seed = NULL,
...
)
}
\arguments{
\item{x}{Predictor data (independent variables), alternative interface to data with formula or dependent.variable.name.}
\item{y}{Response vector (dependent variable), alternative interface to data with formula or dependent.variable.name. For survival use a \code{Surv()} object or a matrix with time and status.}
\item{s.pct, s}{Number of surrogate splits.
This can be defined either by setting \code{s.pct} to a number between
0 and 1, or providing an exact value for \code{s}.
\itemize{
\item \code{s.pct}: Percentage of variables to use for \code{s}. (Default: 0.01)
\item \code{s}: Number of surrogate splits. (Default: Number of variables
multiplied by \code{s.pct}, which defaults to 0.01; If \code{s.pct} is
less than or equal to zero, or greater than 1: 0.01 is used instead.)
}}
\item{mtry}{Number of variables to possibly split at in each node.
Default is the (rounded down) number of variables to the power
of three quarters (Ishwaran, 2011).
Alternatively, a single argument function returning an integer,
given the number of independent variables.}
\item{type}{The type of random forest to create with ranger.
One of "regression" (Default), "classification" or "survival".}
\item{num.trees}{Number of trees.}
\item{num.threads}{Number of threads. Default is number of CPUs available.}
\item{min.node.size}{Minimal node size to split at. (Default: 1)}
\item{permutate}{Enable to permutate \code{x} for \code{\link{MutualForestImpact}} (Default: FALSE).}
\item{seed}{RNG seed. It is strongly recommended that you set this value.}
\item{...}{Other params passed on to \code{\link[ranger:ranger]{ranger::ranger()}}.}
}
\value{
A RFSurrogates S3 object.
\itemize{
\item \code{trees}: List of all trees with surrogate analysis. (Class: \code{SurrogateTrees}, \code{LayerTrees}, \code{RangerTrees})
\item \code{ranger}: \link[ranger:ranger]{ranger::ranger} model used to obtain the trees.
\item \code{s} = s: The \code{s} parameter.
}
}
\description{
Create a random forest with surrogates.
}
\keyword{prep}
test_that("RFS", {
skip_on_ci()
data("SMD_example_data")
rfs <- RandomForestSurrogates(
x = SMD_example_data[, -1],
y = SMD_example_data[, 1],
num.trees = 50,
num.threads = 1,
seed = 42,
s = 3
)
expect(TRUE, "never")
})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment