---
title: "Package usage"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Package usage}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

This vignette summarizes the functionality of the `survex` package by comparing two survival models. A Cox Proportional Hazards model is analysed alongside a Random Survival Forest, showcasing the available explanations and highlighting differences and similarities between the two models.

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

# Model and explainer creation

For the purpose of this presentation, the `veteran` dataset from the `survival` package will be used.

The first step is to create the models that will be used for making predictions, together with their explainers. Note that for the `explain()` function to extract the data automatically, we either have to set certain parameters while creating the `coxph` model (`model = TRUE` and `x = TRUE`) or provide the data manually. We chose to set the required parameters. Creating an explainer for the random survival forest doesn't require any additional steps.

```{r setup}
library(survex)
library(survival)
set.seed(123)

vet <- survival::veteran

cph <- coxph(Surv(time, status)~., data = vet, model = TRUE, x = TRUE)
cph_exp <- explain(cph)

rsf <- randomForestSRC::rfsrc(Surv(time, status)~., data = vet)
rsf_exp <- explain(rsf)
```

However, for some models, not all data can be extracted automatically. If we want to create an explainer for a Random Survival Forest from the `ranger` package, we need to supply `data` and `y` on our own. It is important to remember that the `data` parameter should be supplied **without** the columns containing survival information.
``` {r}
library(ranger)

ranger_rsf <- ranger(Surv(time, status)~., data = vet)
# Columns 3 and 4 of `vet` are `time` and `status`, so they are dropped from `data`.
ranger_rsf_exp <- explain(ranger_rsf, data = vet[, -c(3,4)], y = Surv(vet$time, vet$status))
```

# Making predictions

From this point onward, we operate only on the explainer objects (`cph_exp` and `rsf_exp`), as they are standardized wrappers for the models. A useful feature of an explainer is the ability to make predictions (of risk scores, as well as survival and cumulative hazard functions) in a unified way, independently of the underlying model.

``` {r}
predict(cph_exp, veteran[1:2,], output_type="risk")
predict(rsf_exp, veteran[1:2,], output_type="risk")

predict(cph_exp, veteran[1:2,], output_type="survival", times=seq(1, 600, 100))
predict(rsf_exp, veteran[1:2,], output_type="survival", times=seq(1, 600, 100))

predict(cph_exp, veteran[1:2,], output_type="chf", times=seq(1, 600, 100))
predict(rsf_exp, veteran[1:2,], output_type="chf", times=seq(1, 600, 100))
```

# Measuring performance

Another helpful feature is the calculation of performance metrics for the models. For this we use the `model_performance()` function. It calculates a set of performance measures that we can plot next to each other and easily compare.

``` {r warning=FALSE}
mp_cph <- model_performance(cph_exp)
mp_rsf <- model_performance(rsf_exp)

plot(mp_cph, mp_rsf)
```

We can also plot the scalar metrics in the form of bar plots.

```{r, include=FALSE}
dev.off()
```

``` {r}
plot(mp_cph, mp_rsf, metrics_type="scalar")
```

From this comparison, we see that the Random Survival Forest model is better in the Brier score metric (lower is better), the Cumulative/Dynamic AUC metric (higher is better), and their integrated versions, as well as in the concordance index. Therefore, it is a better candidate for making predictions but lacks interpretability compared to the Proportional Hazards model.

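As a rough cross-check of the concordance index for the proportional hazards model, the `concordance()` function from the `survival` package can be applied directly to a fitted `coxph` object. The sketch below refits the same model so that it runs on its own. It is evaluated on the training data, and it is only an assumption that the value is comparable to the one shown by `model_performance()`, since the two implementations may differ in details.

```r
library(survival)

# Refit the same proportional hazards model as above so the sketch is self-contained.
vet <- survival::veteran
cph <- coxph(Surv(time, status) ~ ., data = vet, model = TRUE, x = TRUE)

# Harrell's concordance index: 0.5 corresponds to random ordering, 1 to perfect ordering.
c_index <- unname(concordance(cph)$concordance)
c_index
```

A value well above 0.5 indicates that the model's risk ordering agrees with the observed ordering of survival times more often than chance would.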
# Global explanations

## Variable importance

Next, we check how each variable influences the models' predictions on a global level. For this purpose, we use the `model_parts()` function. It calculates [permutational variable importance](https://ema.drwhy.ai/featureImportance.html), with the difference that the loss function is time-dependent (by default it is `loss_brier_score()`), so the influence of each variable can be different at each considered time point.

```{r, include=FALSE}
dev.off()
```

``` {r}
model_parts_rsf <- model_parts(rsf_exp)
model_parts_cph <- model_parts(cph_exp)

plot(model_parts_cph,model_parts_rsf)
```

For both models, permuting the `karno` variable leads to the highest increase in the loss function, with `celltype` coming second. These two variables are the most important for the models' predictions.

We can use another loss function to check that this observation is consistent. Let's use `loss_one_minus_cd_auc()`, and let's also change the plot type to show the difference between the loss after a given variable's permutation and the loss of the full model with all variables. This means that the values on the y-axis represent only the change in the loss function after permuting each variable.

```{r, include=FALSE}
dev.off()
```

``` {r}
model_parts_rsf_auc <- model_parts(rsf_exp, loss_function=loss_one_minus_cd_auc, type="difference")
model_parts_cph_auc <- model_parts(cph_exp, loss_function=loss_one_minus_cd_auc, type="difference")

# NOTE: this may take a long time, so a progress bar is available. To enable it, use:
# progressr::with_progress(model_parts(rsf_exp, loss_function=loss_one_minus_cd_auc, type="difference"))

plot(model_parts_cph_auc,model_parts_rsf_auc)
```

We observe that the results are consistent with `loss_brier_score()`. That's good news: it is a simple sanity check that these variables are the most important for these models.

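The idea behind permutational importance with `type="difference"` can be sketched by hand. The example below is not how `model_parts()` is implemented: as a simplifying assumption it uses a single scalar loss (one minus the concordance index from the `survival` package) instead of a time-dependent loss, and the helper `loss()` and the number of repetitions are hypothetical choices made only for illustration.

```r
library(survival)

set.seed(123)
vet <- survival::veteran
cph <- coxph(Surv(time, status) ~ ., data = vet)

# Scalar loss used only for this illustration: 1 - concordance index (lower is better).
loss <- function(model, data) {
  data$risk <- predict(model, newdata = data, type = "lp")
  # reverse = TRUE because a higher risk score should mean a shorter survival time.
  fit <- concordance(Surv(time, status) ~ risk, data = data, reverse = TRUE)
  1 - unname(fit$concordance)
}

full_loss <- loss(cph, vet)

# For each variable: shuffle its column, re-evaluate the loss, and record the increase.
variables <- setdiff(names(vet), c("time", "status"))
importance <- sapply(variables, function(v) {
  mean(replicate(10, {
    permuted <- vet
    permuted[[v]] <- sample(permuted[[v]])
    loss(cph, permuted) - full_loss
  }))
})

sort(importance, decreasing = TRUE)
```

Larger values mean that breaking the link between a variable and the outcome hurts the model more, which is the same reading as for the plots above.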
**Important note (new functionality):** We can also measure the importance of variables using global SurvSHAP(t) explanations. Details on how to do so are presented in this vignette: `vignette("global-survshap")`.

## Partial dependence

**Important note (new functionality):** A more detailed description of the partial dependence explanation, with added functionality, is presented in this vignette: `vignette("pdp")`.

The next type of global explanation this package provides is partial dependence plots. They are calculated using the `model_profile()` function. These plots show how setting one variable to a different value would, on average, influence the model's prediction. Again, this is an extension of [partial dependence](https://ema.drwhy.ai/partialDependenceProfiles.html) known from regression and classification tasks, applied to survival models by taking the time dimension into account.

Note that we need to set the `categorical_variables` parameter in order to avoid nonsensical values, such as a treatment value of 1.5 (`trt` is coded as 1 or 2). All factors are automatically detected as categorical, but if you want to treat a numerical variable as a categorical one, you need to set it here.

```{r, include=FALSE}
dev.off()
```

``` {r, fig.height=18}
model_profile_cph <- model_profile(cph_exp, categorical_variables=c("trt", "prior"))
plot(model_profile_cph, facet_ncol = 1)
```

From the plot, we see that for the proportional hazards model, the `prior` and `diagtime` variables are not very important. The plotted bands are very thin, almost overlapping, so the overall prediction will be similar no matter what value these variables take. On the other hand, the `karno` variable has a very wide band, which means that the predicted survival function changes substantially across its values. We can also see that its lower values indicate a lower chance of survival (the survival function decreases more quickly).

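To make the averaging behind a partial dependence profile concrete, here is a minimal sketch that computes it by hand for `karno` under the proportional hazards model, at a single time point. It uses only the `survival` package and is not the `model_profile()` implementation. The grid `karno_grid` and the time point `t0` are hypothetical values chosen for the illustration.

```r
library(survival)

vet <- survival::veteran
cph <- coxph(Surv(time, status) ~ ., data = vet)

# Hypothetical evaluation settings chosen for this illustration.
karno_grid <- c(30, 50, 70, 90)
t0 <- 100  # time point (in days) at which survival is evaluated

# For each grid value: overwrite karno in every row, predict survival, and average.
pd_karno <- sapply(karno_grid, function(k) {
  modified <- vet
  modified$karno <- k
  curves <- survfit(cph, newdata = modified)  # one survival curve per row
  mean(summary(curves, times = t0)$surv)      # average survival probability at t0
})

data.frame(karno = karno_grid, mean_survival_at_t0 = pd_karno)
```

The averaged survival probability grows with `karno`, which matches the reading of the plot: lower values of this variable indicate a lower chance of survival.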
We can also plot the same information for the random survival forest. This time, let's change the way numerical variables are plotted: instead of representing the values of the variable by colors and the survival function values on the y-axis, let's swap them.

```{r, include=FALSE}
dev.off()
```

``` {r, fig.height=18}
model_profile_rsf <- model_profile(rsf_exp, categorical_variables=c("trt", "prior"))
plot(model_profile_rsf, facet_ncol = 1, numerical_plot_type = "contours")
```

This type of plot also gives us valuable insight that is easy to overlook in the other type. For example, we see a sharp drop in survival function values around `diagtime=25`. We also observe that the strong influence of the `karno` variable is consistent across the proportional hazards model and the random survival forest.

# Local explanations

## Local variable attributions

Another kind of functionality provided by this package is local explanations. The `predict_parts()` function can be used to assess the importance of variables while making predictions for a selected observation. This can be done with two methods, SurvSHAP(t) and SurvLIME.

### SurvSHAP(t)

[SurvSHAP(t)](https://www.sciencedirect.com/science/article/pii/S0950705122013302) explanations are an extension of SHAP values for survival models. They show a breakdown of the prediction into the contributions of individual variables.

```{r, include=FALSE}
dev.off()
```

``` {r}
predict_parts_cph_32 <- predict_parts(cph_exp, veteran[32,])
predict_parts_rsf_32 <- predict_parts(rsf_exp, veteran[32,])
plot(predict_parts_cph_32, predict_parts_rsf_32)
```

```{r, include=FALSE}
dev.off()
```

```{r}
predict_parts_cph_12 <- predict_parts(cph_exp, veteran[12,])
predict_parts_rsf_12 <- predict_parts(rsf_exp, veteran[12,])
plot(predict_parts_cph_12, predict_parts_rsf_12)
```

On the first plot, we see that for observation 32 from the `veteran` data set, the value of the `karno` variable improves the chances of survival of this individual.
In contrast, the value of the `celltype` variable decreases them. This is true for both explained models. On the second plot, we can see that for observation 12 the situation is flipped: `celltype` increases the chances of survival while `karno` decreases them. Interestingly, for the random survival forest, one of the most influential variables is `diagtime`, which the proportional hazards model almost ignores.

### SurvLIME

A different way of attributing variable importance is provided by the [SurvLIME](https://www.sciencedirect.com/science/article/pii/S0950705120304044?casa_token=3ikZP0SAYBQAAAAA:Tnez_adhQq0zubZQVPdg6EMZd9JCIfDTrLqekkfQoMBXgvzrmaSTl9fDFrey7fL4S21V7uxZ) method. It works by finding a surrogate proportional hazards model that approximates the survival model in the local area around an observation of interest. Variable importance is then attributed using the coefficients of the surrogate model.

```{r, include=FALSE}
dev.off()
```

``` {r}
predict_parts_cph_12_lime <- predict_parts(cph_exp, veteran[12,], type="survlime")
predict_parts_rsf_12_lime <- predict_parts(rsf_exp, veteran[12,], type="survlime")
plot(predict_parts_cph_12_lime, type="local_importance")
```

```{r, include=FALSE}
dev.off()
```

```{r}
plot(predict_parts_rsf_12_lime, type="local_importance")
```

The left part of the plot shows which variables are most important and whether their values increase or lower the chances of survival, whereas the right part shows the black-box model prediction together with the one from the surrogate model. This is useful information because the closer these functions are, the more accurate the explanation can be.

## Ceteris paribus

Another explanation technique provided by this package is [ceteris paribus profiles](https://ema.drwhy.ai/ceterisParibus.html). They show how the prediction changes when we change the value of one variable at a time. We can think of them as the equivalent of partial dependence plots but applied to a single observation.
The `predict_profile()` function is used to make these explanations.

```{r, include=FALSE}
dev.off()
```

``` {r, fig.height=18}
predict_profile_cph_32 <- predict_profile(cph_exp, veteran[32,], categorical_variables=c("trt", "prior"))
plot(predict_profile_cph_32, facet_ncol=1)
```

```{r, include=FALSE}
dev.off()
```

```{r, fig.height=18}
predict_profile_rsf_32 <- predict_profile(rsf_exp, veteran[32,], categorical_variables=c("trt", "prior"))
plot(predict_profile_rsf_32, facet_ncol=1)
```

These plots also give a lot of valuable insight. For example, we see that the `prior` variable does not influence the predictions very much, as the lines representing survival functions for its different values almost overlap. We also observe that the `celltype` values of `"small"` and `"adeno"`, as well as `"large"` and `"squamous"` result in almost the same prediction for this observation.

It can be seen that the most important variable for this observation is `karno`, as the differences in the survival function are the greatest, with low values indicating lower chances of survival. In contrast, high values of the `diagtime` variable seem to indicate lower chances of survival.
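The mechanics of a ceteris paribus profile can also be reproduced by hand for the proportional hazards model. The sketch below is only an illustration built on the `survival` package, not the `predict_profile()` implementation. The grid `karno_grid` and the evaluation `times` are hypothetical choices. Observation 32 is copied once per grid value and only `karno` is changed, while all other variables are held fixed.

```r
library(survival)

vet <- survival::veteran
cph <- coxph(Surv(time, status) ~ ., data = vet)

# Copy observation 32 once per grid value and change only karno.
karno_grid <- c(30, 50, 70, 90)  # hypothetical grid for this illustration
profile_data <- vet[rep(32, length(karno_grid)), ]
profile_data$karno <- karno_grid

# One predicted survival curve per modified copy, read off at a few time points.
times <- c(50, 100, 200)
curves <- survfit(cph, newdata = profile_data)
# Rows are time points, columns are grid values.
cp <- matrix(summary(curves, times = times)$surv, nrow = length(times))
dimnames(cp) <- list(paste0("t=", times), paste0("karno=", karno_grid))
round(cp, 3)
```

Reading across a row shows how the predicted survival probability for this single patient changes with `karno` at a fixed time, which is the information each colored line in the plots above encodes.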