Let’s consider the following problem. The model is defined as
y = x1 * x2 + x2
but x1 and x2 are correlated. How do XAI methods work for such a model?
In fact, this model is defined entirely by the predict function the_model_predict, so it does not matter what is passed as the first argument of the explain function.
#> Preparation of a new explainer is initiated
#> -> model label : numeric ( default )
#> -> data : 50 rows 2 cols
#> -> target variable : not specified! ( WARNING )
#> -> predict function : the_model_predict
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package Model of class: numeric package unrecognized , ver. Unknown , task regression ( default )
#> -> model_info : Model info detected regression task but 'y' is a NULL . ( WARNING )
#> -> model_info : By deafult regressions tasks supports only numercical 'y' parameter.
#> -> model_info : Consider changing to numerical vector.
#> -> model_info : Otherwise I will not be able to calculate residuals or loss function.
#> -> predicted values : numerical, min = -0.1726853 , mean = 7.70239 , max = 29.16158
#> -> residual function : difference between y and yhat ( default )
#> A new explainer has been created!
Use the ceteris_paribus() function to see Ceteris Paribus profiles. Clearly, this is not an additive model, because the effect of x1 depends on x2.
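A Ceteris Paribus profile varies one feature while all other features stay fixed. In ingredients, these profiles come from ceteris_paribus(explainer, new_observation). The sketch below does not use the package. It is a hand-rolled base-R version for two hypothetical observations (x2 = 0.5 and x2 = 2), and it shows the interaction directly: for this model, the slope of the x1 profile equals the fixed value of x2.

```r
# Same predict function as in the example: y = x1 * x2 + x2
the_model_predict <- function(model, newdata) {
  newdata$x1 * newdata$x2 + newdata$x2
}

# Ceteris Paribus profile of x1 for one observation: vary x1 over a grid
# while x2 stays fixed at that observation's value
cp_profile_x1 <- function(x2_value, grid) {
  the_model_predict(NULL, data.frame(x1 = grid, x2 = x2_value))
}

grid <- seq(-1, 1, by = 0.5)
profile_low  <- cp_profile_x1(x2_value = 0.5, grid = grid)  # hypothetical observation
profile_high <- cp_profile_x1(x2_value = 2,   grid = grid)  # hypothetical observation

# Slope of each (linear) profile: change in prediction per unit of x1
slope_low  <- diff(profile_low)[1]  / diff(grid)[1]
slope_high <- diff(profile_high)[1] / diff(grid)[1]
print(c(slope_low = slope_low, slope_high = slope_high))  # 0.5 and 2, the fixed x2 values
```

In an additive model, all profiles of x1 would be parallel. Here, they fan out as x2 changes.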
Let’s try Partial Dependence profiles, Conditional Dependence profiles and Accumulated Local Effects profiles. For the last two, we can try different smoothing factors (the span argument).
library("ingredients")
library("ggplot2")
pd_model <- partial_dependence(explain_the_model, variables = c("x1", "x2"))
pd_model$`_label_` = "PDP"
cd_model <- conditional_dependence(explain_the_model, variables = c("x1", "x2"))
cd_model$`_label_` = "CDP 0.25"
ad_model <- accumulated_dependence(explain_the_model, variables = c("x1", "x2"))
ad_model$`_label_` = "ALE 0.25"
plot(ad_model, cd_model, pd_model) +
ggtitle("Feature effects - PDP, CDP, ALE")
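To see what the three curves estimate, here is a simplified base-R sketch for x1. It uses binned averages, not the ingredients implementation, and hypothetical correlated data. The PDP averages over the marginal distribution of x2. The conditional curve averages predictions of observations near each x1 value. ALE accumulates local prediction differences inside bins.

```r
# Predict function of the example: y = x1 * x2 + x2
the_model_predict <- function(model, newdata) {
  newdata$x1 * newdata$x2 + newdata$x2
}

# Helper: copy of the data with x1 overwritten by a single value
set_x1 <- function(data, value) {
  data$x1 <- value
  data
}

# Hypothetical correlated data (not the vignette's data set)
set.seed(1313)
n  <- 1000
x1 <- runif(n, min = -1, max = 4)
x2 <- x1 + rnorm(n, sd = 0.5)
df <- data.frame(x1 = x1, x2 = x2)

# Ten quantile bins for x1, so every bin contains observations
edges <- quantile(df$x1, probs = seq(0, 1, length.out = 11), names = FALSE)
bin   <- cut(df$x1, breaks = edges, include.lowest = TRUE, labels = FALSE)
mids  <- (head(edges, -1) + tail(edges, -1)) / 2

# PDP: put every row at x1 = z and average over the marginal distribution of x2
pdp <- sapply(mids, function(z) mean(the_model_predict(NULL, set_x1(df, z))))

# Conditional (M-plot style) curve: mean prediction of the rows inside each bin,
# so x2 changes together with x1
cdp <- as.vector(tapply(the_model_predict(NULL, df), bin, mean))

# ALE: in each bin, average the prediction change between the bin edges with x2
# kept at the observed values, then accumulate and center the local effects
local_effect <- sapply(seq_len(length(edges) - 1), function(k) {
  rows <- df[bin == k, ]
  mean(the_model_predict(NULL, set_x1(rows, edges[k + 1])) -
       the_model_predict(NULL, set_x1(rows, edges[k])))
})
ale <- cumsum(local_effect)
ale <- ale - mean(ale)  # ALE is only defined up to an additive constant
```

In this toy setup, the three curves behave differently. The PDP is a straight line with slope mean(x2). The conditional curve is curved, with local slope near 2 * x1 + 1, because it also absorbs the effect of x2 rising with x1. The ALE has local slope near E[x2 | x1], which here is roughly x1.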
cd_model_1 <- conditional_dependence(explain_the_model, variables = c("x1", "x2"), span = 0.1)
cd_model_1$`_label_` = "CDP 0.1"
cd_model_5 <- conditional_dependence(explain_the_model, variables = c("x1", "x2"), span = 0.5)
cd_model_5$`_label_` = "CDP 0.5"
ad_model_1 <- accumulated_dependence(explain_the_model, variables = c("x1", "x2"), span = 0.1)
ad_model_1$`_label_` = "ALE 0.1"
ad_model_5 <- accumulated_dependence(explain_the_model, variables = c("x1", "x2"), span = 0.5)
ad_model_5$`_label_` = "ALE 0.5"
plot(ad_model, cd_model, pd_model, cd_model_1, cd_model_5, ad_model_1, ad_model_5) +
ggtitle("Feature effects - PDP, CDP, ALE")
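The span argument controls how strongly the conditional and accumulated profiles are smoothed. The sketch below is not the ingredients implementation. As a generic illustration, it smooths the predictions of hypothetical correlated data with a Gaussian-kernel weighted mean (a Nadaraya-Watson smoother). Its bandwidth is an assumed span * range(x1), and the sketch compares how much each curve moves.

```r
the_model_predict <- function(model, newdata) {
  newdata$x1 * newdata$x2 + newdata$x2
}

# Hypothetical correlated data (not the vignette's data set)
set.seed(1313)
n    <- 50
x1   <- runif(n, min = -1, max = 4)
x2   <- x1 + rnorm(n, sd = 0.5)
df   <- data.frame(x1 = x1, x2 = x2)
yhat <- the_model_predict(NULL, df)

# Gaussian-kernel weighted mean of the predictions around each grid point.
# Illustrative only: bandwidth = span * range(x1) is an assumption and is
# not necessarily how ingredients scales its span argument.
kernel_profile <- function(grid, x, y, span) {
  bw <- span * diff(range(x))
  sapply(grid, function(z) {
    w <- dnorm((x - z) / bw)
    sum(w * y) / sum(w)
  })
}

grid  <- seq(min(df$x1), max(df$x1), length.out = 101)
spans <- c(0.1, 0.25, 0.5)
profiles <- sapply(spans, function(s) kernel_profile(grid, df$x1, yhat, span = s))
colnames(profiles) <- paste("span", spans)

# Total variation of each curve: a larger span usually gives a flatter curve
roughness <- colSums(abs(diff(profiles)))
print(roughness)
```

A small span typically follows local fluctuations of the 50 points, while a large span pulls the curve toward the overall mean. This is the trade-off between detail and noise that the CDP and ALE curves with different span values let you inspect.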
Now let’s see how the grouping factor works. Below, x3 groups the observations by the sign of x2.
# add grouping variable
df$x3 <- factor(sign(df$x2))
# update the data argument
explain_the_model$data = df
# PDP in groups
pd_model_groups <- partial_dependence(explain_the_model,
variables = c("x1", "x2"),
groups = "x3")
plot(pd_model_groups) +
ggtitle("Partial Dependence")
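With groups = "x3", the PDP is averaged separately within each level of x3. The base-R sketch below uses hypothetical data, not ingredients. It shows why the two group curves differ: for this model, the PDP of x1 in a group is a straight line whose slope is the average x2 in that group. The curve therefore falls for the x2 < 0 group and rises for the x2 > 0 group.

```r
the_model_predict <- function(model, newdata) {
  newdata$x1 * newdata$x2 + newdata$x2
}

# Hypothetical correlated data (not the vignette's data set)
set.seed(1313)
n  <- 50
x1 <- runif(n, min = -1, max = 4)
x2 <- x1 + rnorm(n, sd = 0.5)
df <- data.frame(x1 = x1, x2 = x2)
df$x3 <- factor(sign(df$x2))  # grouping variable, as in the example

# PDP of x1 within one group: set x1 = z for every row of the group and
# average the predictions over that group's x2 values
pdp_x1 <- function(data, grid) {
  sapply(grid, function(z) {
    data$x1 <- z
    mean(the_model_predict(NULL, data))
  })
}

grid <- seq(-1, 4, length.out = 11)
pdp_by_group <- lapply(split(df, df$x3), pdp_x1, grid = grid)

# Slope of each group's PDP; it equals the mean of x2 within the group
slopes <- sapply(pdp_by_group, function(p) (p[2] - p[1]) / (grid[2] - grid[1]))
print(slopes)
```

A single ungrouped PDP would average these two opposite slopes into one line. That line hides the fact that x1 pushes predictions down when x2 is negative.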
#> R version 4.4.2 (2024-10-31)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.1 LTS
#>
#> Matrix products: default
#> BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
#> LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
#>
#> locale:
#> [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
#> [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
#> [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
#> [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
#> [9] LC_ADDRESS=C LC_TELEPHONE=C
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
#>
#> time zone: Etc/UTC
#> tzcode source: system (glibc)
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggplot2_3.5.1 ranger_0.17.0 ingredients_2.3.1 DALEX_2.5.1
#> [5] rmarkdown_2.29
#>
#> loaded via a namespace (and not attached):
#> [1] vctrs_0.6.5 cli_3.6.3 knitr_1.49 rlang_1.1.4
#> [5] xfun_0.49 jsonlite_1.8.9 labeling_0.4.3 glue_1.8.0
#> [9] colorspace_2.1-1 buildtools_1.0.0 htmltools_0.5.8.1 maketools_1.3.1
#> [13] sys_3.4.3 sass_0.4.9 scales_1.3.0 grid_4.4.2
#> [17] tibble_3.2.1 evaluate_1.0.1 munsell_0.5.1 jquerylib_0.1.4
#> [21] fastmap_1.2.0 yaml_2.3.10 lifecycle_1.0.4 compiler_4.4.2
#> [25] Rcpp_1.0.13-1 pkgconfig_2.0.3 farver_2.1.2 lattice_0.22-6
#> [29] digest_0.6.37 R6_2.5.1 pillar_1.10.0 magrittr_2.0.3
#> [33] Matrix_1.7-1 bslib_0.8.0 withr_3.0.2 tools_4.4.2
#> [37] gtable_0.3.6 cachem_1.1.0