Let’s see an example of using the DALEX package with a classification model for the survival problem on the Titanic dataset. Here we use the titanic_imputed dataset available in the DALEX package. Note that this data was copied from the stablelearner package and modified for practicality.
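The preview of the data can be reproduced with a call along the following lines. This is a minimal sketch that assumes the DALEX package is installed; the name preview is only illustrative.

```r
library("DALEX")

# titanic_imputed ships with DALEX; head() returns its first six rows
preview <- head(titanic_imputed)
preview
```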
#> gender age class embarked fare sibsp parch survived
#> 1 male 42 3rd Southampton 7.11 0 0 0
#> 2 male 13 3rd Southampton 20.05 0 2 0
#> 3 male 16 3rd Southampton 20.05 1 1 0
#> 4 female 39 3rd Southampton 20.05 1 1 1
#> 5 female 16 3rd Southampton 7.13 0 0 1
#> 6 male 25 3rd Southampton 7.13 0 0 1
Ok, now it’s time to create a model. Let’s use the Random Forest model.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
                             fare + sibsp + parch,
                           data = titanic_imputed, probability = TRUE)
model_titanic_rf
#> Ranger result
#>
#> Call:
#> ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, probability = TRUE)
#>
#> Type: Probability estimation
#> Number of trees: 500
#> Sample size: 2207
#> Number of independent variables: 7
#> Mtry: 2
#> Target node size: 10
#> Variable importance mode: none
#> Splitrule: gini
#> OOB prediction error (Brier s.): 0.1420499
The third step (optional but useful) is to create a DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic_imputed[,-8],
                              y = titanic_imputed[,8],
                              label = "Random Forest")
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest
#> -> data : 2207 rows 7 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.ranger will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package ranger , ver. 0.16.0 , task classification ( default )
#> -> predicted values : numerical, min = 0.01180482 , mean = 0.3221012 , max = 0.9907698
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.7742579 , mean = 5.553535e-05 , max = 0.8805021
#> A new explainer has been created!
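Use feature_importance() on the explainer to assess the importance of particular features. Note that type = "difference" subtracts the loss of the full model from every dropout loss, so the values start at 0. The table of mean_dropout_loss values in this section reports raw losses, with the _full_model_ row as the reference level.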
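As a sketch of how such a call might look, the snippet below refits the forest and rebuilds the explainer so that it runs on its own, then requests importance with type = "difference". It assumes the ingredients package (listed in the session info) provides feature_importance(); the object names model_rf, explainer_rf and fi_diff, the seed, and verbose = FALSE are illustrative choices, not part of the original example.

```r
library("DALEX")
library("ranger")
library("ingredients")

set.seed(1313)  # arbitrary seed, used only to make this sketch repeatable

# Refit the model and rebuild the explainer so the snippet is self-contained
model_rf <- ranger(survived ~ gender + age + class + embarked +
                     fare + sibsp + parch,
                   data = titanic_imputed, probability = TRUE)
explainer_rf <- explain(model_rf,
                        data = titanic_imputed[, -8],
                        y = titanic_imputed[, 8],
                        label = "Random Forest",
                        verbose = FALSE)

# Permutation-based importance; "difference" subtracts the full-model loss
fi_diff <- ingredients::feature_importance(explainer_rf, type = "difference")
head(fi_diff)
plot(fi_diff)
```

The explicit ingredients:: prefix is a defensive choice here, so that the call does not depend on the order in which the packages were attached.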
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.3405669 Random Forest
#> 2 sibsp 0.3513237 Random Forest
#> 3 parch 0.3517331 Random Forest
#> 4 embarked 0.3529161 Random Forest
#> 5 age 0.3764347 Random Forest
#> 6 fare 0.3836626 Random Forest
As we see, the most important feature is gender. The next three important features are class, age and fare. Note that the table above lists only six rows, so gender and class do not appear in it. Let’s see the link between the model response and these features.
Such a univariate relation can be calculated with partial_dependence(). Kids 5 years old and younger have a much higher survival probability.
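A call of roughly the following shape would compute partial dependence profiles for age and fare, the two variables that appear in the printed profiles. This is a hedged sketch: it rebuilds the model and explainer so that it runs on its own, and the names model_rf, explainer_rf and pd_rf, the seed, and verbose = FALSE are assumptions made for illustration.

```r
library("DALEX")
library("ranger")
library("ingredients")

set.seed(1313)  # arbitrary seed, used only to make this sketch repeatable

# Refit the model and rebuild the explainer so the snippet is self-contained
model_rf <- ranger(survived ~ gender + age + class + embarked +
                     fare + sibsp + parch,
                   data = titanic_imputed, probability = TRUE)
explainer_rf <- explain(model_rf,
                        data = titanic_imputed[, -8],
                        y = titanic_imputed[, 8],
                        label = "Random Forest",
                        verbose = FALSE)

# Average model response as a function of age and of fare
pd_rf <- ingredients::partial_dependence(explainer_rf,
                                         variables = c("age", "fare"))
head(pd_rf)
plot(pd_rf)
```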
#> Top profiles :
#> _vname_ _label_ _x_ _yhat_ _ids_
#> 1 fare Random Forest 0.0000000 0.3622347 0
#> 2 age Random Forest 0.1666667 0.5196749 0
#> 3 age Random Forest 2.0000000 0.5515232 0
#> 4 age Random Forest 4.0000000 0.5624471 0
#> 5 fare Random Forest 6.1793080 0.3122499 0
#> 6 age Random Forest 7.0000000 0.5155496 0
Let’s look at an instance-level explanation of the model prediction for an 8-year-old male from the 1st class. Note that the code below sets the port of embarkation to Southampton.
First, Ceteris Paribus Profiles for numerical variables:
new_passenger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- ceteris_paribus(explain_titanic_rf, new_passenger)
plot(sp_rf) +
  show_observations(sp_rf)
And now for selected categorical variables. Note that sibsp is numerical but is presented here as a categorical variable.
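The categorical view might be produced with a plot call like the one sketched below. It assumes that the plot method for Ceteris Paribus profiles in the ingredients package accepts the variables and variable_type arguments; the snippet rebuilds the model, the explainer and the passenger so that it runs on its own, and the names model_rf, explainer_rf, cp_rf and p_cat are illustrative.

```r
library("DALEX")
library("ranger")
library("ingredients")

set.seed(1313)  # arbitrary seed, used only to make this sketch repeatable

# Refit the model and rebuild the explainer so the snippet is self-contained
model_rf <- ranger(survived ~ gender + age + class + embarked +
                     fare + sibsp + parch,
                   data = titanic_imputed, probability = TRUE)
explainer_rf <- explain(model_rf,
                        data = titanic_imputed[, -8],
                        y = titanic_imputed[, 8],
                        label = "Random Forest",
                        verbose = FALSE)

# The same passenger as in the text; factor levels are taken from the data
new_passenger <- data.frame(
  class    = factor("1st", levels = levels(titanic_imputed$class)),
  gender   = factor("male", levels = levels(titanic_imputed$gender)),
  age      = 8,
  sibsp    = 0,
  parch    = 0,
  fare     = 72,
  embarked = factor("Southampton", levels = levels(titanic_imputed$embarked))
)

cp_rf <- ingredients::ceteris_paribus(explainer_rf, new_passenger)

# Show the selected variables as categorical profiles
p_cat <- plot(cp_rf,
              variables = c("class", "embarked", "gender", "sibsp"),
              variable_type = "categorical")
p_cat
```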
It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite his being male.
# sample 100 passengers, compute their Ceteris Paribus profiles and cluster them into 3 groups
passengers <- select_sample(titanic, n = 100)
sp_rf <- ceteris_paribus(explain_titanic_rf, passengers)
clust_rf <- cluster_profiles(sp_rf, k = 3)
head(clust_rf)
#> Top profiles :
#> _vname_ _label_ _x_ _cluster_ _yhat_ _ids_
#> 1 fare Random Forest_1 0.0000000 1 0.2349731 0
#> 2 parch Random Forest_1 0.0000000 1 0.1673656 0
#> 3 sibsp Random Forest_1 0.0000000 1 0.1710698 0
#> 4 age Random Forest_1 0.1666667 1 0.4458216 0
#> 5 parch Random Forest_1 1.0000000 1 0.2444645 0
#> 6 sibsp Random Forest_1 1.0000000 1 0.1480347 0
#> R version 4.4.1 (2024-06-14)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.1 LTS
#>
#> Matrix products: default
#> BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
#> LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
#>
#> locale:
#> [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
#> [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
#> [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
#> [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
#> [9] LC_ADDRESS=C LC_TELEPHONE=C
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
#>
#> time zone: Etc/UTC
#> tzcode source: system (glibc)
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggplot2_3.5.1 ranger_0.16.0 ingredients_2.3.1 DALEX_2.5.1
#> [5] rmarkdown_2.28
#>
#> loaded via a namespace (and not attached):
#> [1] Matrix_1.7-1 gtable_0.3.6 jsonlite_1.8.9 highr_0.11
#> [5] compiler_4.4.1 Rcpp_1.0.13 jquerylib_0.1.4 scales_1.3.0
#> [9] yaml_2.3.10 fastmap_1.2.0 lattice_0.22-6 R6_2.5.1
#> [13] labeling_0.4.3 knitr_1.48 tibble_3.2.1 maketools_1.3.1
#> [17] munsell_0.5.1 bslib_0.8.0 pillar_1.9.0 rlang_1.1.4
#> [21] utf8_1.2.4 cachem_1.1.0 xfun_0.48 sass_0.4.9
#> [25] sys_3.4.3 cli_3.6.3 withr_3.0.2 magrittr_2.0.3
#> [29] digest_0.6.37 grid_4.4.1 lifecycle_1.0.4 vctrs_0.6.5
#> [33] evaluate_1.0.1 glue_1.8.0 farver_2.1.2 buildtools_1.0.0
#> [37] fansi_1.0.6 colorspace_2.1-1 tools_4.4.1 pkgconfig_2.0.3
#> [41] htmltools_0.5.8.1