Simulated data, real problem

Simulated data

Let’s consider the following problem. The model is defined as

y = x1 * x2 + x2

But x1 and x2 are strongly correlated. How do XAI methods work for such a model?

# Predict function for the simulated model y = x1 * x2 + x2.
# `m` is never used: the model is defined entirely by this formula, so any
# object can be handed to the explainer as the "model".
# `x` is expected to carry columns x1 and x2 (e.g. a data frame).
the_model_predict <- function(m, x) {
  first <- x$x1
  second <- x$x2
  first * second + second
}

# Correlated variables: x1 is uniform on (-5, 5) and x2 is x1 plus a tiny
# uniform perturbation from [0, 0.01), so the two features are almost
# collinear.
N <- 50
set.seed(1)
x1 <- runif(N, min = -5, max = 5)
x2 <- x1 + runif(N) / 100
df <- data.frame(x1 = x1, x2 = x2)

Explainer for the model

In fact, this model is defined entirely by the predict function the_model_predict, so it does not matter what is passed as the first argument of the explain function.

library("DALEX")
# The model is fully defined by the_model_predict, so any object can serve as
# the `model` argument; the number 1 is only a placeholder. No `y` is given,
# which is why the log below warns about the missing target variable.
explain_the_model <- explain(
  model = 1,
  data = df,
  predict_function = the_model_predict
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  numeric  (  default  )
#>   -> data              :  50  rows  2  cols 
#>   -> target variable   :  not specified! (  WARNING  )
#>   -> predict function  :  the_model_predict 
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package Model of class: numeric package unrecognized , ver. Unknown , task regression (  default  ) 
#>   -> model_info        :  Model info detected regression task but 'y' is a NULL .  (  WARNING  )
#>   -> model_info        :  By deafult regressions tasks supports only numercical 'y' parameter. 
#>   -> model_info        :  Consider changing to numerical vector.
#>   -> model_info        :  Otherwise I will not be able to calculate residuals or loss function.
#>   -> predicted values  :  numerical, min =  -0.1726853 , mean =  7.70239 , max =  29.16158  
#>   -> residual function :  difference between y and yhat (  default  )
#>   A new explainer has been created!

Ceteris paribus

Use the ceteris_paribus() function to see Ceteris Paribus profiles. Clearly, this is not an additive model, as the effect of x1 depends on x2.

library("ingredients")
library("ggplot2")

# Eleven new observations lying on the diagonal x1 == x2, which is where the
# simulated (strongly correlated) data lives.
sample_rows <- data.frame(
  x1 = -5:5,
  x2 = -5:5
)

# Ceteris Paribus profiles for the sample rows; show_observations() adds the
# rows themselves to the plot.
cp_model <- ceteris_paribus(explain_the_model, new_observation = sample_rows)
plot(cp_model) +
  show_observations(cp_model) +
  ggtitle("Ceteris Paribus profiles")

Dependence profiles

Let’s try Partial Dependence profiles, Conditional Dependence profiles and Accumulated Local profiles. For the last two, we can try different smoothing factors (the span argument).

# Partial dependence (PDP), conditional dependence (CDP) and accumulated local
# (ALE) profiles for both features. The CDP and ALE smoothing span is passed
# explicitly so that the "0.25" in their labels always describes the span
# actually used, rather than relying on the package default.
pd_model <- partial_dependence(explain_the_model, variables = c("x1", "x2"))
pd_model$`_label_` <- "PDP"

cd_model <- conditional_dependence(explain_the_model,
                                   variables = c("x1", "x2"),
                                   span = 0.25)
cd_model$`_label_` <- "CDP 0.25"

ad_model <- accumulated_dependence(explain_the_model,
                                   variables = c("x1", "x2"),
                                   span = 0.25)
ad_model$`_label_` <- "ALE 0.25"

plot(ad_model, cd_model, pd_model) +
  ggtitle("Feature effects - PDP, CDP, ALE")

# Sensitivity to the smoothing factor: CDP and ALE profiles recomputed with a
# narrower (0.1) and a wider (0.5) span. Each label states the span used.
cd_model_1 <- conditional_dependence(explain_the_model,
                                     variables = c("x1", "x2"),
                                     span = 0.1)
cd_model_1$`_label_` <- "CDP 0.1"

cd_model_5 <- conditional_dependence(explain_the_model,
                                     variables = c("x1", "x2"),
                                     span = 0.5)
cd_model_5$`_label_` <- "CDP 0.5"

# span = 0.1 to match the "ALE 0.1" label (0.5 would give the same smoothing
# as ad_model_5).
ad_model_1 <- accumulated_dependence(explain_the_model,
                                     variables = c("x1", "x2"),
                                     span = 0.1)
ad_model_1$`_label_` <- "ALE 0.1"

ad_model_5 <- accumulated_dependence(explain_the_model,
                                     variables = c("x1", "x2"),
                                     span = 0.5)
ad_model_5$`_label_` <- "ALE 0.5"

# All seven profiles on one plot; ad_model, cd_model and pd_model come from
# the previous chunk.
plot(ad_model, cd_model, pd_model,
     cd_model_1, cd_model_5, ad_model_1, ad_model_5) +
  ggtitle("Feature effects - PDP, CDP, ALE")

Dependence profiles in groups

And now, let’s see how the grouping factor works.

# Grouping variable: a factor built from the sign of x2.
df[["x3"]] <- factor(sign(df[["x2"]]))
# The explainer holds its own copy of the data (R copies on modify), so it
# must be refreshed for the new column to be visible to the profile functions.
explain_the_model$data <- df

# PDP in groups
# (assumes explain_the_model$data already contains the grouping column x3)
pd_model_groups <- partial_dependence(
  explain_the_model,
  variables = c("x1", "x2"),
  groups = "x3"
)
plot(pd_model_groups) +
  ggtitle("Partial Dependence")

# ALE in groups
ad_model_groups <- accumulated_dependence(
  explain_the_model,
  variables = c("x1", "x2"),
  groups = "x3"
)
plot(ad_model_groups) +
  ggtitle("Accumulated Local")

# CDP in groups
cd_model_groups <- conditional_dependence(
  explain_the_model,
  variables = c("x1", "x2"),
  groups = "x3"
)
plot(cd_model_groups) +
  ggtitle("Conditional Dependence")

Session info

# Record the R version, platform and package versions used to render this.
utils::sessionInfo()
#> R version 4.4.2 (2024-10-31)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.1 LTS
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
#> LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so;  LAPACK version 3.12.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C              
#>  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
#>  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
#> 
#> time zone: Etc/UTC
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ggplot2_3.5.1     ranger_0.17.0     ingredients_2.3.1 DALEX_2.5.1      
#> [5] rmarkdown_2.29   
#> 
#> loaded via a namespace (and not attached):
#>  [1] vctrs_0.6.5       cli_3.6.3         knitr_1.49        rlang_1.1.4      
#>  [5] xfun_0.49         jsonlite_1.8.9    labeling_0.4.3    glue_1.8.0       
#>  [9] colorspace_2.1-1  buildtools_1.0.0  htmltools_0.5.8.1 maketools_1.3.1  
#> [13] sys_3.4.3         sass_0.4.9        scales_1.3.0      grid_4.4.2       
#> [17] tibble_3.2.1      evaluate_1.0.1    munsell_0.5.1     jquerylib_0.1.4  
#> [21] fastmap_1.2.0     yaml_2.3.10       lifecycle_1.0.4   compiler_4.4.2   
#> [25] Rcpp_1.0.13-1     pkgconfig_2.0.3   farver_2.1.2      lattice_0.22-6   
#> [29] digest_0.6.37     R6_2.5.1          pillar_1.10.0     magrittr_2.0.3   
#> [33] Matrix_1.7-1      bslib_0.8.0       withr_3.0.2       tools_4.4.2      
#> [37] gtable_0.3.6      cachem_1.1.0