Global explanations with SurvSHAP(t)

This vignette demonstrates how the model_survshap() function can be used to calculate global explanations for survival models. These explanations can be used to assess variable importance and can be plotted in several ways, each emphasizing a different aspect.

To create this explanation, we follow the standard survex workflow: we fit a model and then create an explainer for it.

library(survex)
library(survival)
library(ranger)

vet <- survival::veteran

cph <- coxph(Surv(time, status) ~ ., data = vet, x = TRUE, model = TRUE)
exp <- explain(cph, data = vet[, -c(3,4)], y = Surv(vet$time, vet$status))
#> Preparation of a new explainer is initiated 
#>   -> model label       :  coxph (  default  ) 
#>   -> data              :  137  rows  6  cols 
#>   -> target variable   :  137  values ( 128 events and 9 censored ) 
#>   -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
#>   -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
#>   -> predict function  :  predict.coxph with type = 'risk' will be used (  default  ) 
#>   -> predict survival function  :  predictSurvProb.coxph will be used (  default  ) 
#>   -> predict cumulative hazard function  :  -log(predict_survival_function) will be used (  default  ) 
#>   -> model_info        :  package survival , ver. 3.7.0 , task survival (  default  ) 
#>   A new explainer has been created!

We use the explainer and the model_survshap() function to calculate SurvSHAP(t) explanations. We can specify the observations for which we want to calculate the explanations. In this example we calculate the explanations for 16 observations from the veteran dataset (rows 1-4, 17-20, 110-113, and 126-129). Note: The background for generating SHAP values is the data field of the explainer! If you want to calculate explanations with a background that is not the training data, you need to specify the data argument manually when creating the explainer.

shap <- model_survshap(exp, vet[c(1:4, 17:20, 110:113, 126:129), -c(3,4)])
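The note about the background data can be illustrated with a minimal sketch. It assumes a random subsample of 50 rows as the background; the seed, the subsample size, and the names `background`, `exp_bg` and `shap_bg` are hypothetical choices made for illustration, not a recommendation.

```r
library(survex)
library(survival)

vet <- survival::veteran
cph <- coxph(Surv(time, status) ~ ., data = vet, x = TRUE, model = TRUE)

# Hypothetical background: a random subsample of 50 rows instead of the full training data
set.seed(42)
background <- vet[sample(nrow(vet), 50), ]

# The explainer stores `data`, which then serves as the SHAP background
exp_bg <- explain(
  cph,
  data = background[, -c(3, 4)],
  y = Surv(background$time, background$status),
  verbose = FALSE
)

# Explanations for a few observations, computed against the custom background
shap_bg <- model_survshap(exp_bg, vet[1:4, -c(3, 4)])
```

Only the explainer changes here; the model itself is still fitted on the full dataset.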

We plot these explanations using the plot.aggregated_surv_shap() function. By default, the plot has two panels. The left panel shows the overall importance of each variable. The right panel shows the time-dependent importance of each variable, calculated as the mean absolute SHAP value at each time point across all observations.

plot(shap)
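The aggregation behind the time-dependent panel can be sketched in base R on made-up numbers. The array `shap_vals` below is a toy stand-in, not the internal structure of the object returned by model_survshap(), and the overall summary shown at the end (a plain average over time) is only one possible choice; the package may aggregate differently.

```r
# Toy SHAP values: 3 observations x 4 time points x 2 variables (made-up numbers)
set.seed(1)
shap_vals <- array(
  rnorm(3 * 4 * 2),
  dim = c(3, 4, 2),
  dimnames = list(NULL, paste0("t", 1:4), c("karno", "age"))
)

# Time-dependent importance: mean absolute SHAP value across observations,
# computed separately for each time point and each variable
time_importance <- apply(abs(shap_vals), c(2, 3), mean)

# One possible overall summary (assumption): average each curve over time
overall_importance <- colMeans(time_importance)

time_importance
overall_importance
```

Taking absolute values before averaging keeps positive and negative contributions from cancelling out, which is why this measures importance rather than the direction of the effect.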

The plot.aggregated_surv_shap() function can also be used to plot the explanations for a single variable. The variable argument specifies the variable for which the explanations are plotted, and the geom argument specifies the type of plot. For geom = "profile", the plot shows the mean SHAP value (averaged across the time domain) as a function of the value of the variable.

plot(shap, variable = "karno", geom = "profile")
#> Warning in plot_shap_global_profile(x = x, ... = ..., title = title, subtitle =
#> subtitle, : `color_variable` was not specified, the first from the result will
#> be used.
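The warning above indicates that the profile plot accepts a `color_variable` argument. A minimal sketch of setting it explicitly follows, assuming the argument is forwarded through plot() and that coloring by "celltype" is a sensible choice for this data; both are assumptions, so check the plot method's documentation before relying on it.

```r
library(survex)
library(survival)

vet <- survival::veteran
cph <- coxph(Surv(time, status) ~ ., data = vet, x = TRUE, model = TRUE)
exp <- explain(cph, data = vet[, -c(3, 4)], y = Surv(vet$time, vet$status), verbose = FALSE)
shap <- model_survshap(exp, vet[c(1:4, 17:20, 110:113, 126:129), -c(3, 4)])

# Assumption: color_variable is passed on to the profile plot, as the warning suggests.
# Points are then colored by the chosen variable rather than by the first one in the result.
p <- plot(shap, variable = "karno", geom = "profile", color_variable = "celltype")
p
```

Coloring by a second variable can help spot interactions, for example whether the effect of karno differs between cell types.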

For geom = "beeswarm", a bee swarm plot is generated that shows the SHAP values for each observation. It is a good way to assess the distribution of SHAP values for each variable.

plot(shap, geom = "beeswarm")