Title: | Surrogate-Assisted Feature Extraction |
---|---|
Description: | Provides a model-agnostic tool for training a white-box model on features extracted from a black-box model. For more information see: Gosiewska et al. (2020) <doi:10.1016/j.dss.2021.113556>. |
Authors: | Alicja Gosiewska [aut, cre], Anna Gierlak [aut], Przemyslaw Biecek [aut, ths], Michal Burdukiewicz [ctb] |
Maintainer: | Alicja Gosiewska |
License: | GPL-3 |
Version: | 0.1.4 |
Built: | 2025-01-07 03:30:36 UTC |
Source: | https://github.com/modeloriented/rsafe |
Datasets apartments and apartmentsTest are artificial and were generated from the same model. The structure of the dataset is copied from a real dataset in the PBImisc package, but the data were generated so as to mimic the effect of the Anscombe quartet for complex black-box models.
data(apartments)
A data frame with 1000 rows and 6 columns:
m2.price - price per square meter
surface - apartment area in square meters
no.rooms - number of rooms (correlated with surface)
district - district in which the apartment is located, factor with 10 levels (Bemowo, Bielany, Mokotow, Ochota, Praga, Srodmiescie, Ursus, Ursynow, Wola, Zoliborz)
floor - floor on which the apartment is located
construction.year - construction year
A dataset from the Kaggle competition Human Resources Analytics (https://www.kaggle.com/).
A data frame with 14999 rows and 10 variables:
satisfaction_level - level of satisfaction (0-1)
last_evaluation - time since last performance evaluation (in years)
number_project - number of projects completed while at work
average_monthly_hours - average monthly hours at the workplace
time_spend_company - number of years spent in the company
work_accident - whether the employee had a workplace accident
left - whether the employee left the workplace or not (1 or 0), factor
promotion_last_5years - whether the employee was promoted in the last five years
sales - department in which the employee works
salary - relative level of salary (high)
Source: dataset HR-analytics from https://www.kaggle.com
Plotting Transformations of the SAFE Extractor Object
## S3 method for class 'safe_extractor'
plot(x, ..., variable = NULL)
Argument | Description |
---|---|
x | safe_extractor object containing information about variable transformations, created with the safe_extraction() function |
... | other parameters |
variable | character, name of the variable to be plotted |
a plot object
Printing Summary of the SAFE Extractor Object
## S3 method for class 'safe_extractor'
print(x, ..., variable = NULL)
Argument | Description |
---|---|
x | safe_extractor object containing information about variable transformations, created with the safe_extraction() function |
... | other parameters |
variable | character, name of the variable to be printed. If this argument is not specified, transformations for all variables are printed |
No return value, prints the structure of the object
The safe_extraction() function creates a SAFE-extractor object which may be used later for surrogate feature extraction.
safe_extraction( explainer, response_type = "ale", grid_points = 50, N = 200, penalty = "MBIC", nquantiles = 10, no_segments = 2, method = "complete", B = 500, collapse = "_", interactions = FALSE, inter_param = 0.25, inter_threshold = 0.25, verbose = TRUE )
Argument | Description |
---|---|
explainer | DALEX explainer created with the explain() function |
response_type | character, type of response to be calculated, one of: "pdp", "ale". If features are uncorrelated, the "pdp" type can be used; otherwise "ale" is strongly recommended. |
grid_points | number of points on the x-axis used for creating the PD/ALE plot, default 50 |
N | number of observations from the dataset used for creating the PD/ALE plot, default 200 |
penalty | penalty for introducing another changepoint, one of "AIC", "BIC", "SIC", "MBIC", "Hannan-Quinn" or a non-negative numeric value |
nquantiles | the number of quantiles used in the integral approximation |
no_segments | numeric, the number of segments the variable is divided into if no breakpoints are found |
method | the agglomeration method to be used in hierarchical clustering, one of: "ward.D", "ward.D2", "single", "complete", "average", "mcquitty", "median", "centroid" |
B | number of reference datasets used to calculate the gap statistic |
collapse | a character string used to separate original levels when combining them into a new one |
interactions | logical, whether interactions between variables are to be taken into account |
inter_param | numeric, a positive value indicating which single-observation non-additive effects are to be regarded as significant; the higher the value, the larger a non-additive effect has to be to be taken into account |
inter_threshold | numeric, a value from |
verbose | logical, whether a progress bar is to be printed |
A safe_extractor object containing information about the variable transformations.
See also: safely_transform_categorical, safely_transform_continuous, safely_detect_interactions, safely_transform_data
library(DALEX)
library(randomForest)
library(rSAFE)
# fit a black-box model on the first 500 apartments
data <- apartments[1:500,]
set.seed(111)
model_rf <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms + district, data = data)
explainer_rf <- explain(model_rf, data = data[,2:6], y = data[,1], verbose = FALSE)
# extract SAFE transformations and inspect them
safe_extractor <- safe_extraction(explainer_rf, grid_points = 30, N = 100, verbose = FALSE)
print(safe_extractor)
plot(safe_extractor, variable = "construction.year")
The safely_detect_changepoints() function calculates the optimal number and positions of changepoints for the given data and penalty. It uses the PELT algorithm with a nonparametric cost function based on the empirical distribution. The implementation is inspired by the code available at https://github.com/rkillick/changepoint.
safely_detect_changepoints(data, penalty = "MBIC", nquantiles = 10)
Argument | Description |
---|---|
data | a vector within which you wish to find changepoints |
penalty | penalty for introducing another changepoint, one of "AIC", "BIC", "SIC", "MBIC", "Hannan-Quinn" or a non-negative numeric value |
nquantiles | the number of quantiles used in the integral approximation |
A vector of optimal changepoint positions (the last observation of each segment).
library(rSAFE)
# two constant segments
data <- rep(c(2,7), each=4)
safely_detect_changepoints(data)
# three noisy segments with different means
set.seed(123)
data <- c(rnorm(15, 0), rnorm(20, 2), rnorm(30, 8))
safely_detect_changepoints(data)
safely_detect_changepoints(data, penalty = 25)
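The PELT search described above can be illustrated with a small base-R version. This is a sketch under stated assumptions, not rSAFE's implementation: the hypothetical helper pelt_mean() replaces the nonparametric empirical-distribution cost with a squared-error (change-in-mean) cost, accepts only a numeric penalty, and returns changepoints as the last observation of each segment except the final one.

```r
# Hypothetical helper (not part of rSAFE): PELT with a squared-error
# change-in-mean cost instead of the package's nonparametric cost.
pelt_mean <- function(y, penalty) {
  n <- length(y)
  s1 <- c(0, cumsum(y))    # prefix sums of y
  s2 <- c(0, cumsum(y^2))  # prefix sums of y^2
  # Cost of segment y[(a + 1):b]: sum of squared deviations from its mean.
  # Vectorised over a, so all candidate start points are scored at once.
  seg_cost <- function(a, b) {
    len <- b - a
    s <- s1[b + 1] - s1[a + 1]
    ss <- s2[b + 1] - s2[a + 1]
    ss - s^2 / len
  }
  f <- numeric(n + 1)    # f[i + 1] = optimal penalised cost of y[1:i]
  f[1] <- -penalty
  last <- integer(n)     # last[i] = previous changepoint in the optimum for y[1:i]
  candidates <- 0L       # admissible positions of the previous changepoint
  for (i in seq_len(n)) {
    base <- f[candidates + 1] + seg_cost(candidates, i)
    costs <- base + penalty
    best <- which.min(costs)
    f[i + 1] <- costs[best]
    last[i] <- candidates[best]
    # Pruning: drop candidates that can never be optimal again
    candidates <- c(candidates[base <= f[i + 1]], i)
  }
  # Backtrack from the end to recover the changepoints
  cps <- integer(0)
  pos <- n
  while (pos > 0) {
    tau <- last[pos]
    if (tau > 0) cps <- c(tau, cps)
    pos <- tau
  }
  cps
}

pelt_mean(rep(c(2, 7), each = 4), penalty = 1)
```

With a penalty of 1 the final call returns 4, the last observation of the first segment. The squared-error cost keeps the pruning step exact because splitting a segment never increases its total cost; a nonparametric cost such as the package's can additionally respond to changes other than a shift in the mean.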
The safely_detect_interactions() function detects second-order interactions based on predictions made by a surrogate model. For each pair of features, it permutes their values in order to evaluate their non-additive effect.
safely_detect_interactions( explainer, inter_param = 0.5, inter_threshold = 0.5, verbose = TRUE )
Argument | Description |
---|---|
explainer | DALEX explainer created with the explain() function |
inter_param | numeric, a positive value indicating which single-observation non-additive effects are to be regarded as significant; the higher the value, the larger a non-additive effect has to be to be taken into account |
inter_threshold | numeric, a value from |
verbose | logical, whether a progress bar is to be printed |
A data frame containing interaction effects greater than or equal to the specified inter_threshold.
library(DALEX)
library(randomForest)
library(rSAFE)
data <- apartments[1:500,]
set.seed(111)
model_rf <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms + district, data = data)
explainer_rf <- explain(model_rf, data = data[,2:6], y = data[,1])
safely_detect_interactions(explainer_rf, inter_param = 0.25, inter_threshold = 0.2, verbose = TRUE)
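The permutation idea behind safely_detect_interactions() can be sketched in base R. The hypothetical helper nonadditive_effect() assumes one common construction, a four-term contrast between predictions on the original data and on copies with feature j, feature k, or both permuted; rSAFE's exact statistic, and the way inter_param and inter_threshold are applied to it, may differ. If the prediction function is additive in the two features, every term cancels for every observation.

```r
# Hypothetical helper (not part of rSAFE): per-observation non-additive
# effect of features j and k, estimated with a single permutation.
nonadditive_effect <- function(predict_fun, data, j, k,
                               perm = sample(nrow(data))) {
  d_j <- data
  d_j[[j]] <- data[[j]][perm]    # permute feature j only
  d_k <- data
  d_k[[k]] <- data[[k]][perm]    # permute feature k only
  d_jk <- d_j
  d_jk[[k]] <- data[[k]][perm]   # permute both features
  # zero everywhere when predict_fun is additive in j and k
  predict_fun(d_jk) - predict_fun(d_j) - predict_fun(d_k) + predict_fun(data)
}

set.seed(1)
toy <- data.frame(x1 = 1:10, x2 = 1:10)
additive <- function(d) d$x1 + 2 * d$x2
interacting <- function(d) d$x1 * d$x2
nonadditive_effect(additive, toy, "x1", "x2")     # all zeros
nonadditive_effect(interacting, toy, "x1", "x2")  # non-zero where rows moved
```

For the product model the contrast equals (x1[perm] - x1) * (x2[perm] - x2), so it is non-zero exactly for the rows the permutation moved.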
The safely_select_variables() function selects variables from the dataset returned by the safely_transform_data() function. For each original variable, exactly one variable is chosen: either the original one or the transformed one. The choice is based on the AIC value of a linear model (regression) or a logistic regression (classification).
safely_select_variables( safe_extractor, data, y = NULL, which_y = NULL, class_pred = NULL, verbose = TRUE )
Argument | Description |
---|---|
safe_extractor | object containing information about variable transformations, created with the safe_extraction() function |
data | the original dataset or the one returned by the safely_transform_data() function. If data does not contain transformed variables, the transformation is done inside this function using the 'safe_extractor' argument. Data may or may not contain the response variable: if it does, the 'which_y' argument must be given; otherwise the 'y' argument should be provided. |
y | vector of responses, must be given if data does not contain it |
which_y | numeric or character (optional), must be given if data contains response values |
class_pred | numeric or character, used only in multiclass classification problems. If the response vector has more than two levels, 'class_pred' should indicate the class of interest, which will denote failure; all other classes will stand for success. |
verbose | logical, whether a progress bar is to be printed |
A vector of variable names, selected based on AIC values.
library(DALEX)
library(randomForest)
library(rSAFE)
data <- apartments[1:500,]
set.seed(111)
model_rf <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms + district, data = data)
explainer_rf <- explain(model_rf, data = data[,2:6], y = data[,1])
safe_extractor <- safe_extraction(explainer_rf, verbose = FALSE)
# data contains the response, so which_y is given
safely_select_variables(safe_extractor, data, which_y = "m2.price", verbose = FALSE)
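The AIC-based choice can be illustrated for a single variable in the regression case. This is a simplified sketch: the hypothetical helper choose_by_aic() compares two one-variable linear models in isolation, whereas the package may compare candidates inside a model that also contains the other variables, and the logistic-regression branch used for classification is not shown.

```r
# Hypothetical helper (not part of rSAFE): keep whichever version of a
# variable gives the linear model with the lower AIC (regression only).
choose_by_aic <- function(y, original, transformed) {
  aic_original <- AIC(lm(y ~ original))
  aic_transformed <- AIC(lm(y ~ transformed))
  if (aic_transformed < aic_original) "transformed" else "original"
}

# Toy data: the response jumps once, between x = 10 and x = 11
x <- 1:20
y <- ifelse(x > 10, 5, 0) + rep(c(-0.1, 0.1), 10)
x_new <- cut(x, breaks = c(-Inf, 10, Inf))  # interval version of x
choose_by_aic(y, x, x_new)
```

For this step-shaped response the interval factor fits far better than a straight line, so the call returns "transformed"; for a response that is linear in x the same helper keeps the original variable.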
The safely_transform_categorical() function calculates a transformation function for a categorical variable using predictions obtained from the black-box model and hierarchical clustering. The gap statistic criterion is used to determine the optimal number of clusters.
safely_transform_categorical( explainer, variable, method = "complete", B = 500, collapse = "_" )
Argument | Description |
---|---|
explainer | DALEX explainer created with the explain() function |
variable | a feature for which the transformation function is to be computed |
method | the agglomeration method to be used in hierarchical clustering, one of: "ward.D", "ward.D2", "single", "complete", "average", "mcquitty", "median", "centroid" |
B | number of reference datasets used to calculate the gap statistic |
collapse | a character string used to separate original levels when combining them into a new one |
A list of information on the transformation of the given variable.
library(DALEX)
library(randomForest)
library(rSAFE)
data <- apartments[1:500,]
set.seed(111)
model_rf <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms + district, data = data)
explainer_rf <- explain(model_rf, data = data[,2:6], y = data[,1])
safely_transform_categorical(explainer_rf, "district")
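The level-merging step can be sketched with base R's hclust() and cutree(). This sketch assumes the input is a named vector of mean black-box predictions per level (a simplification of how the package summarises predictions), and the hypothetical helper merge_levels() takes the number of clusters k as an argument instead of choosing it with the gap statistic. Merged level names are joined with a separator, in the spirit of the collapse argument.

```r
# Hypothetical helper (not part of rSAFE): merge factor levels whose mean
# predictions are close, using hierarchical clustering with a fixed k.
merge_levels <- function(level_means, k, method = "complete", collapse = "_") {
  hc <- hclust(dist(level_means), method = method)
  groups <- cutree(hc, k = k)  # cluster id for each level, named by level
  # name of each merged level, e.g. "a_b" when levels a and b share a cluster
  merged <- vapply(split(names(groups), groups), paste, character(1),
                   collapse = collapse)
  setNames(as.vector(merged[as.character(groups)]), names(groups))
}

merge_levels(c(a = 1, b = 1.1, c = 5, d = 5.2), k = 2)
```

For the toy input, levels a and b fall into one cluster and c and d into the other, so a and b map to "a_b" and c and d map to "c_d".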
The safely_transform_continuous() function calculates a transformation function for a continuous variable using a PD/ALE plot obtained from the black-box model.
safely_transform_continuous( explainer, variable, response_type = "ale", grid_points = 50, N = 200, penalty = "MBIC", nquantiles = 10, no_segments = 2 )
Argument | Description |
---|---|
explainer | DALEX explainer created with the explain() function |
variable | a feature for which the transformation function is to be computed |
response_type | character, type of response to be calculated, one of: "pdp", "ale". If features are uncorrelated, the "pdp" type can be used; otherwise "ale" is strongly recommended. |
grid_points | number of points on the x-axis used for creating the PD/ALE plot, default 50 |
N | number of observations from the dataset used for creating the PD/ALE plot, default 200 |
penalty | penalty for introducing another changepoint, one of "AIC", "BIC", "SIC", "MBIC", "Hannan-Quinn" or a non-negative numeric value |
nquantiles | the number of quantiles used in the integral approximation |
no_segments | numeric, the number of segments the variable is divided into if no breakpoints are found |
A list of information on the transformation of the given variable.
See also: safe_extraction, safely_detect_changepoints
library(DALEX)
library(randomForest)
library(rSAFE)
data <- apartments[1:500,]
set.seed(111)
model_rf <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms + district, data = data)
explainer_rf <- explain(model_rf, data = data[,2:6], y = data[,1])
safely_transform_continuous(explainer_rf, "construction.year")
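The partial dependence (PD) curve this function starts from can be sketched in base R: fix the variable at each of grid_points values for N sampled observations and average the model's predictions. The hypothetical helper partial_dependence() assumes an equally spaced grid between the observed minimum and maximum. ALE, the default response_type, is computed differently (it accumulates averaged local prediction differences within small intervals of the variable) and is not shown.

```r
# Hypothetical helper (not part of rSAFE or DALEX): partial dependence of
# predict_fun on one variable, evaluated on an equally spaced grid.
partial_dependence <- function(predict_fun, data, variable,
                               grid_points = 50, N = 200) {
  # use at most N randomly chosen observations
  rows <- data[sample(nrow(data), min(N, nrow(data))), , drop = FALSE]
  grid <- seq(min(data[[variable]]), max(data[[variable]]),
              length.out = grid_points)
  yhat <- vapply(grid, function(value) {
    rows[[variable]] <- value  # fix the variable for every row (local copy)
    mean(predict_fun(rows))    # average prediction at this grid value
  }, numeric(1))
  data.frame(x = grid, yhat = yhat)
}

toy <- data.frame(x = 1:10, z = 10:1)
set.seed(1)
partial_dependence(function(d) 3 * d$x + d$z, toy, "x", grid_points = 5)
```

Because the toy model is additive with slope 3 in x, the resulting curve is 3 * x + 5.5, where 5.5 is the mean of z over the sampled rows.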
The safely_transform_data() function creates new variables in a dataset using a safe_extractor object.
safely_transform_data(safe_extractor, data, verbose = TRUE)
Argument | Description |
---|---|
safe_extractor | object containing information about variable transformations, created with the safe_extraction() function |
data | data for which features are to be transformed |
verbose | logical, whether a progress bar is to be printed |
The data with extra columns containing the newly created variables.
See also: safe_extraction, safely_select_variables
library(DALEX)
library(randomForest)
library(rSAFE)
data <- apartments[1:500,]
set.seed(111)
model_rf <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms + district, data = data)
explainer_rf <- explain(model_rf, data = data[,2:6], y = data[,1])
safe_extractor <- safe_extraction(explainer_rf, verbose = FALSE)
safely_transform_data(safe_extractor, data, verbose = FALSE)
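Conceptually, applying the extracted transformations adds a discretised or level-merged copy of each variable. The sketch below is an illustration under assumptions, not the package's code: the hypothetical helper apply_transformations() takes breakpoints and level mappings directly, and the "_new" column suffix, the breakpoints 1950 and 2000, and the district grouping are made-up values for a toy data frame.

```r
# Hypothetical helper (not part of rSAFE): add transformed copies of
# variables to a data frame. The "_new" suffix is an assumption.
apply_transformations <- function(data, breakpoints = list(),
                                  level_maps = list(), suffix = "_new") {
  # continuous variables: discretise into intervals between breakpoints
  for (v in names(breakpoints)) {
    data[[paste0(v, suffix)]] <- cut(data[[v]],
                                     breaks = c(-Inf, breakpoints[[v]], Inf),
                                     dig.lab = 10)
  }
  # categorical variables: map each original level to its merged level
  for (v in names(level_maps)) {
    merged <- level_maps[[v]][as.character(data[[v]])]
    data[[paste0(v, suffix)]] <- factor(unname(merged))
  }
  data
}

apartments_toy <- data.frame(
  construction.year = c(1930, 1960, 1995, 2005),
  district = c("Wola", "Ursus", "Bielany", "Wola")
)
apply_transformations(
  apartments_toy,
  breakpoints = list(construction.year = c(1950, 2000)),
  level_maps = list(district = c(Wola = "Wola",
                                 Ursus = "Bielany_Ursus",
                                 Bielany = "Bielany_Ursus"))
)
```

The result keeps the original columns and adds construction.year_new with three intervals and district_new with two merged levels.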