Title: | Variable Importance via Oscillations |
---|---|
Description: | Provides an easy to calculate local variable importance measure based on Ceteris Paribus profile and global variable importance measure based on Partial Dependence Profiles. |
Authors: | Anna Kozak [aut, cre], Przemyslaw Biecek [aut, ths] |
Maintainer: | Anna Kozak <[email protected]> |
License: | GPL-2 |
Version: | 0.2.1 |
Built: | 2024-11-21 03:20:18 UTC |
Source: | https://github.com/modeloriented/vivo |
This function calculates candidate splits for each selected variable.
For numerical variables, splits are calculated as percentiles
(in general, uniform quantiles of length grid_points).
For all other variables, splits are calculated as unique values.
calculate_variable_split(data, variables = colnames(data), grid_points = 101)
data |
validation dataset, used to determine the distribution of observations. |
variables |
names of variables for which splits shall be calculated |
grid_points |
number of points used for response path |
A named list with splits for selected variables
This function is a copy of calculate_variable_split() from the ingredients package, with a small change.
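The splitting rule described for this function can be sketched in base R. This is a minimal illustration, not the package source: the helper name sketch_variable_split and the toy data are hypothetical, and the use of stats::quantile() with its default type is an assumption.

```r
# Hypothetical helper (not part of vivo or ingredients): candidate splits
# as uniform quantiles for numeric columns, unique values otherwise.
sketch_variable_split <- function(data, variables = colnames(data),
                                  grid_points = 101) {
  splits <- lapply(variables, function(var) {
    values <- data[[var]]
    if (is.numeric(values)) {
      # grid_points evenly spaced probabilities between 0 and 1
      probs <- seq(0, 1, length.out = grid_points)
      # duplicates collapse when the variable has few distinct values
      unique(unname(stats::quantile(values, probs = probs, na.rm = TRUE)))
    } else {
      sort(unique(values))
    }
  })
  names(splits) <- variables
  splits
}

# Toy data, invented for illustration only
toy <- data.frame(
  surface  = c(20, 35, 50, 80, 120),
  district = factor(c("A", "B", "A", "C", "B"))
)
splits <- sketch_variable_split(toy, grid_points = 5)
str(splits)
```

The result is a named list with one element per variable, matching the shape described in the return value.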
Przemyslaw Biecek
This function calculates an empirical density of the raw data based on the variable splits from Ceteris Paribus profiles. It then calculates weights for values generated by DALEX::predict_profile(), DALEX::individual_profile() or ingredients::ceteris_paribus().
calculate_weight(profiles, data, variable_split)
profiles |
|
data |
|
variable_split |
list generated by |
Returns a weight based on empirical density.
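One way to picture a density-based weight is to estimate the density of the raw data and read it off at the split points. The sketch below does this with stats::density() and stats::approx(). It is an assumption about the general idea, not a description of how calculate_weight() is implemented; the helper name sketch_density_weight and the simulated data are hypothetical.

```r
# Hypothetical helper: weights proportional to a kernel density estimate
# of the raw data, evaluated at the split points and normalised to sum to 1.
sketch_density_weight <- function(values, split_points) {
  dens <- stats::density(values, na.rm = TRUE)
  # linear interpolation of the estimated density at each split point;
  # rule = 2 uses the nearest density value for points outside the grid
  w <- stats::approx(dens$x, dens$y, xout = split_points, rule = 2)$y
  w / sum(w)
}

set.seed(1)
# Simulated variable, invented for illustration only
surface <- rnorm(200, mean = 60, sd = 15)
split_points <- stats::quantile(surface, probs = seq(0, 1, length.out = 11))
weights <- sketch_density_weight(surface, split_points)
round(weights, 3)
```

Split points in densely populated regions of the data receive larger weights than those in the tails.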
library("DALEX", warn.conflicts = FALSE, quietly = TRUE)
data(apartments)
split <- vivo::calculate_variable_split(apartments,
                                        variables = colnames(apartments),
                                        grid_points = 101)
library("randomForest", warn.conflicts = FALSE, quietly = TRUE)
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
                                      floor + no.rooms, data = apartments)
explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
                        y = apartmentsTest$m2.price)
new_apartment <- data.frame(construction.year = 1998, surface = 88,
                            floor = 2L, no.rooms = 3)
profiles <- predict_profile(explainer_rf, new_apartment)
library("vivo")
calculate_weight(profiles, data = apartments[, 2:5], variable_split = split)
This function calculates the global importance measure.
global_variable_importance(profiles)
profiles |
|
A data.frame of the class global_variable_importance.
It contains the calculated global variable importance measure.
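The intuition behind an oscillation-based global measure can be shown on a toy Partial Dependence profile: a flat profile suggests the variable has little effect, while a profile that moves a lot suggests a strong effect. The sketch below assumes the oscillation is summarised as the mean absolute deviation of the profile from its own average. That formula, the helper name sketch_global_oscillation and the profile values are assumptions for illustration, not the package's documented definition.

```r
# Hypothetical helper: mean absolute deviation of a profile from its mean.
sketch_global_oscillation <- function(pd_values) {
  mean(abs(pd_values - mean(pd_values)))
}

# Toy profiles, invented for illustration only
flat_profile  <- rep(10, 5)            # the prediction does not react
steep_profile <- c(0, 5, 10, 15, 20)   # the prediction reacts strongly

sketch_global_oscillation(flat_profile)
sketch_global_oscillation(steep_profile)
```

Under this assumption the flat profile scores zero and the steep profile scores higher, which is the ordering one would expect from an importance measure.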
library("DALEX")
data(apartments)
library("randomForest")
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
                                      floor + no.rooms, data = apartments)
explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
                        y = apartmentsTest$m2.price)
profiles <- model_profile(explainer_rf)
library("vivo")
global_variable_importance(profiles)
This function calculates the local importance measure in eight variants. The eight variants arise from the possible combinations of three logical parameters: absolute_deviation, point and density.
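The count of eight follows from three logical parameters with two settings each. A short base R sketch enumerates the combinations; the object name variants is hypothetical.

```r
# Every combination of the three logical parameters: 2 * 2 * 2 = 8 rows.
variants <- expand.grid(
  absolute_deviation = c(TRUE, FALSE),
  point              = c(TRUE, FALSE),
  density            = c(TRUE, FALSE)
)
variants
nrow(variants)
```

Each row corresponds to one call signature of local_variable_importance() with those three arguments set accordingly.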
local_variable_importance(
  profiles,
  data,
  absolute_deviation = TRUE,
  point = TRUE,
  density = TRUE,
  grid_points = 101
)
profiles |
|
data |
|
absolute_deviation |
logical parameter, if |
point |
logical parameter, if |
density |
logical parameter, if |
grid_points |
maximum number of points for profile calculations, the default value is 101, the same as in |
A data.frame of the class local_variable_importance.
It contains the calculated local variable importance measure.
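To make the three switches concrete, the sketch below computes a deviation-based measure for one toy Ceteris Paribus profile. The meaning given to each switch is an assumption, since the argument descriptions here are incomplete: absolute_deviation is taken to choose between a mean absolute deviation and a root mean square, point to choose between the model prediction at the observation and the profile average as the reference, and density to correspond to supplying non-uniform weights. The helper sketch_local_measure and all numbers are hypothetical.

```r
# Hypothetical helper illustrating one reading of the three switches.
sketch_local_measure <- function(profile, prediction, weights = NULL,
                                 absolute_deviation = TRUE, point = TRUE) {
  # assumed: reference is the prediction at the point, or the profile mean
  reference <- if (point) prediction else mean(profile)
  # assumed: uniform weights when no density weights are supplied
  if (is.null(weights)) weights <- rep(1, length(profile))
  weights <- weights / sum(weights)
  deviation <- profile - reference
  if (absolute_deviation) {
    sum(weights * abs(deviation))
  } else {
    sqrt(sum(weights * deviation^2))
  }
}

# Toy profile and prediction, invented for illustration only
profile    <- c(100, 110, 130, 120)
prediction <- 110

sketch_local_measure(profile, prediction)
sketch_local_measure(profile, prediction, absolute_deviation = FALSE)
sketch_local_measure(profile, prediction, weights = c(0.1, 0.6, 0.2, 0.1))
```

Passing weights that concentrate on the region near the observation lowers the measure here, because the large deviations sit where the weights are small.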
library("DALEX")
data(apartments)
library("randomForest")
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
                                      floor + no.rooms, data = apartments)
explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
                        y = apartmentsTest$m2.price)
new_apartment <- data.frame(construction.year = 1998, surface = 88,
                            floor = 2L, no.rooms = 3)
profiles <- predict_profile(explainer_rf, new_apartment)
library("vivo")
local_variable_importance(profiles, apartments[,2:5],
                          absolute_deviation = TRUE, point = TRUE, density = TRUE)
local_variable_importance(profiles, apartments[,2:5],
                          absolute_deviation = TRUE, point = TRUE, density = FALSE)
local_variable_importance(profiles, apartments[,2:5],
                          absolute_deviation = TRUE, point = FALSE, density = TRUE)
Function plot.global_importance plots the global importance measure based on Partial Dependence profiles.
## S3 method for class 'global_importance'
plot(x, ..., variables = NULL, type = NULL, title = "Variable importance")
x |
object returned from |
... |
other object returned from |
variables |
if not |
type |
a character. How variables shall be plotted? Either "bars" (default) or "lines". |
title |
the plot's title, by default |
a ggplot2 object
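Because the return value is a ggplot2 object, it should be possible to restyle it with ordinary ggplot2 layers. The sketch below shows the pattern on a stand-in bar chart built directly with ggplot2 rather than on the output of this plot method; the data frame toy_measure and its values are invented for illustration.

```r
library("ggplot2")

# Hypothetical importance values, invented purely for illustration
toy_measure <- data.frame(
  variable = c("surface", "floor", "no.rooms", "construction.year"),
  measure  = c(4, 3, 2, 1)
)

# Stand-in for a plot returned by a plotting function
p <- ggplot(toy_measure, aes(x = reorder(variable, measure), y = measure)) +
  geom_col() +
  coord_flip()

# Standard ggplot2 layers can be added to an existing plot object
p_custom <- p +
  theme_minimal() +
  labs(x = NULL, y = "importance", title = "Customised title")
```

The same additions would be applied to the object returned by plot(measure), assuming it behaves like any other ggplot2 plot.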
library("DALEX")
data(apartments)
library("randomForest")
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
                                      floor + no.rooms, data = apartments)
explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
                        y = apartmentsTest$m2.price)
profiles <- model_profile(explainer_rf)
library("vivo")
measure <- global_variable_importance(profiles)
plot(measure)
Function plot.local_importance plots the local importance measure based on Ceteris Paribus profiles.
## S3 method for class 'local_importance'
plot(
  x,
  ...,
  variables = NULL,
  color = NULL,
  type = NULL,
  title = "Local variable importance"
)
x |
object returned from |
... |
other object returned from |
variables |
if not |
color |
a character. How to aggregate the measure? Either "_label_method_" or "_label_model_". |
type |
a character. How variables shall be plotted? Either "bars" (default) or "lines". |
title |
the plot's title, by default |
a ggplot2 object
library("DALEX")
data(apartments)
library("randomForest")
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
                                      floor + no.rooms, data = apartments)
explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
                        y = apartmentsTest$m2.price)
new_apartment <- data.frame(construction.year = 1998, surface = 88,
                            floor = 2L, no.rooms = 3)
profiles <- predict_profile(explainer_rf, new_apartment)
library("vivo")
measure1 <- local_variable_importance(profiles, apartments[,2:5],
                                      absolute_deviation = TRUE, point = TRUE, density = FALSE)
plot(measure1)
measure2 <- local_variable_importance(profiles, apartments[,2:5],
                                      absolute_deviation = TRUE, point = TRUE, density = TRUE)
plot(measure1, measure2, color = "_label_method_", type = "lines")