Title: | Explainable Machine Learning in Survival Analysis |
---|---|
Description: | Survival analysis models are commonly used in medicine and other areas. Many of them are too complex to be interpreted by humans. Exploration and explanation are needed, but standard methods do not give a broad enough picture. 'survex' provides easy-to-apply methods for explaining survival models, both complex black-boxes and simpler statistical models. They include methods specific to survival analysis, such as SurvSHAP(t) introduced in Krzyzinski et al. (2023) <doi:10.1016/j.knosys.2022.110234> and SurvLIME described in Kovalev et al. (2020) <doi:10.1016/j.knosys.2020.106164>, as well as extensions of existing ones described in Biecek et al. (2021) <doi:10.1201/9780429027192>. |
Authors: | Mikołaj Spytek [aut, cre], Mateusz Krzyziński [aut], Sophie Langbein [aut], Hubert Baniecki [aut], Lorenz A. Kapsner [ctb], Przemyslaw Biecek [aut] |
Maintainer: | Mikołaj Spytek <[email protected]> |
License: | GPL (>= 3) |
Version: | 1.2.0.9001 |
Built: | 2024-11-08 05:17:22 UTC |
Source: | https://github.com/modeloriented/survex |
A function for calculating the Brier score for a survival model.
brier_score(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
loss_brier_score(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
y_true |
a |
risk |
ignored, left for compatibility with other metrics |
surv |
a matrix containing the predicted survival functions for the considered observations; each row represents a single observation and each column a single time point |
times |
a vector of time points at which the survival function was evaluated |
The Brier score is used to evaluate the performance of a survival model. It is based on the squared difference between the predicted survival probability and the observed survival status at a given time point, with inverse probability of censoring weights accounting for censored observations [2].
numeric from 0 to 1, lower scores are better (a Brier score of 0.25 corresponds to a model that always returns 0.5 as the predicted survival function)
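The weighting can be made concrete with a minimal base-R sketch of the inverse-probability-of-censoring-weighted Brier score of Graf et al. (1999). This is not the survex implementation. It assumes the following:

- Observed times and event indicators are plain vectors rather than a `Surv` object.
- The censoring distribution is estimated with a Kaplan-Meier estimator.
- The helper names `km_censoring()` and `brier_sketch()` are hypothetical.

Tie handling in particular may differ from the package.

```r
# Kaplan-Meier estimate of the censoring survival function G(t):
# censorings (status == 0) are treated as the "events".
km_censoring <- function(time, status) {
  ut <- sort(unique(time))
  g <- cumprod(vapply(ut, function(u) {
    1 - sum(time == u & status == 0) / sum(time >= u)
  }, numeric(1)))
  # Step function; left = TRUE returns the left limit G(t-)
  function(t, left = FALSE) {
    idx <- findInterval(t, ut, left.open = left)
    c(1, g)[idx + 1]
  }
}

# Hypothetical IPCW Brier score at each time point in `times`.
# `surv` has one row per observation and one column per time point.
brier_sketch <- function(time, status, surv, times) {
  G <- km_censoring(time, status)
  g_obs <- G(time, left = TRUE)  # G(T_i-) for each observation
  vapply(seq_along(times), function(j) {
    t <- times[j]
    s <- surv[, j]
    died <- time <= t & status == 1  # event observed by time t
    alive <- time > t                # still event-free after time t
    w <- numeric(length(time))       # censored before t: weight 0
    w[died] <- 1 / g_obs[died]
    w[alive] <- 1 / G(t)
    mean(ifelse(died, s^2, (1 - s)^2) * w)
  }, numeric(1))
}
```

With a constant prediction of 0.5, every non-zero term equals 0.25 before weighting, so `brier_sketch()` returns 0.25, consistent with the value described above.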
[1] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3.
[2] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
y <- cph_exp$y
times <- cph_exp$times
surv <- cph_exp$predict_survival_function(cph, cph_exp$data, times)
brier_score(y, surv = surv, times = times)
loss_brier_score(y, surv = surv, times = times)
A function to compute Harrell's concordance index of a survival model.
c_index(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
y_true |
a |
risk |
a numeric vector of risk scores corresponding to each observation |
surv |
ignored, left for compatibility with other metrics |
times |
ignored, left for compatibility with other metrics |
numeric from 0 to 1, higher values indicate better performance
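Harrell's C-index counts, among comparable pairs, how often the observation that fails first has the higher risk score. A minimal base-R sketch using one common convention follows. A pair (i, j) is comparable when i has an observed event and T_i < T_j, and ties in the risk score count as 1/2. The function name `c_index_sketch()` is hypothetical, and survex's handling of tied times may differ.

```r
# Hypothetical Harrell's C-index: share of comparable pairs in which the
# observation with the earlier observed event has the higher risk score.
c_index_sketch <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next  # only observed events can anchor a pair
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # tied risk scores
        }
      }
    }
  }
  concordant / comparable
}
```

A risk score that perfectly reverses the order of event times gives 1, the opposite order gives 0, and a constant score gives 0.5.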
[1] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152.
library(survival)
library(survex)
rotterdam <- survival::rotterdam
rotterdam$year <- NULL
cox_rotterdam_rec <- coxph(Surv(rtime, recur) ~ .,
  data = rotterdam,
  model = TRUE, x = TRUE, y = TRUE
)
coxph_explainer <- explain(cox_rotterdam_rec)
risk <- coxph_explainer$predict_function(coxph_explainer$model, coxph_explainer$data)
c_index(y_true = coxph_explainer$y, risk = risk)
This function calculates the Cumulative/Dynamic (C/D) AUC metric for a survival model, using the estimator proposed by Uno et al. [1] and Hung and Chiang [2].
cd_auc(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
y_true |
a |
risk |
ignored, left for compatibility with other metrics |
surv |
a matrix containing the predicted survival functions for the considered observations; each row represents a single observation and each column a single time point |
times |
a vector of time points at which the survival function was evaluated |
C/D AUC is an extension of the AUC metric known from classification models. Its values represent the model's performance at specific time points. It can be integrated over the considered time range.
a numeric vector of the same length as the times vector; each value (in the range from 0 to 1) represents the AUC metric at a specific time point, with higher values indicating better performance.
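At a fixed time t, the C/D AUC compares "cases" (events observed by t) with "controls" (observations still event-free after t). Each case contributes with an inverse probability of censoring weight; the controls' common weight cancels out. The following base-R sketch makes these assumptions:

- The function name `cd_auc_at()` is hypothetical.
- The weights `w` (1 / G(T_i-)) are supplied by the caller and default to 1, which corresponds to no censoring.
- The risk score at t is given directly, for example 1 - S(t).

How survex derives risk from the `surv` matrix is not shown here.

```r
# Hypothetical C/D AUC at a single time t with case weights `w`.
cd_auc_at <- function(time, status, risk, t, w = rep(1, length(time))) {
  cases <- which(time <= t & status == 1)  # event observed by t
  controls <- which(time > t)              # event-free after t
  if (length(cases) == 0 || length(controls) == 0) return(NA_real_)
  num <- 0
  den <- 0
  for (i in cases) {
    # a case is ranked correctly when its risk exceeds a control's;
    # ties in risk scores count as 1/2
    hits <- sum(risk[i] > risk[controls]) + 0.5 * sum(risk[i] == risk[controls])
    num <- num + w[i] * hits
    den <- den + w[i] * length(controls)
  }
  num / den
}
```

Evaluating `cd_auc_at()` over a grid of times yields a vector comparable in shape to the value returned by `cd_auc()`.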
[1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537.
[2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679.
loss_one_minus_cd_auc()
integrated_cd_auc()
brier_score()
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
y <- cph_exp$y
times <- cph_exp$times
surv <- cph_exp$predict_survival_function(cph, cph_exp$data, times)
cd_auc(y, surv = surv, times = times)
Helper function to transform a cumulative hazard function (CHF) into a survival function.
cumulative_hazard_to_survival(hazard_functions)
hazard_functions |
matrix or vector, with each row representing a cumulative hazard function |
A matrix or vector transformed to the form of a survival function.
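The transformation rests on the identity S(t) = exp(-H(t)) between the survival function S and the cumulative hazard H. A one-line base-R sketch follows; `chf_to_surv_sketch()` is a hypothetical name, not the survex source. Because `exp()` is applied elementwise, the same code handles a single vector or a matrix with one cumulative hazard function per row.

```r
# Hypothetical helper: S(t) = exp(-H(t)), applied elementwise, so the
# shape (vector or matrix) of the input is preserved.
chf_to_surv_sketch <- function(hazard_functions) {
  exp(-hazard_functions)
}

chf_to_surv_sketch(c(0, log(2)))                # survival values 1 and 0.5
chf_to_surv_sketch(matrix(1:6, ncol = 3))       # 2 x 3 matrix of survival values
```

Inverting the identity, -log(S(t)), gives back the cumulative hazard. The `chf_pred` helper in the explain_survival() example relies on that inverse.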
library(survex)
vec <- c(1, 2, 3, 4, 5)
matr <- matrix(c(1, 2, 3, 2, 4, 6), ncol = 3)
cumulative_hazard_to_survival(vec)
cumulative_hazard_to_survival(matr)
Black-box models have vastly different structures. explain_survival() returns an explainer object that can be further processed to create prediction explanations and their visualizations. This function is used to manually create explainers for models not covered by the survex package. For selected models, the information can be extracted automatically: call the explain() function for survival models from the mlr3proba, censored, randomForestSRC, ranger and survival packages, and for any other model with a pec::predictSurvProb() method.
explain_survival(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL,
  times = NULL,
  times_generation = "survival_quantiles",
  predict_survival_function = NULL,
  predict_cumulative_hazard_function = NULL
)
explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)
## Default S3 method:
explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)
model |
object - a survival model to be explained |
data |
data.frame - data which will be used to calculate the explanations. If not provided, then it will be extracted from the model if possible. It should not contain the target columns. NOTE: If the target variable is present in the |
y |
|
predict_function |
function taking 2 arguments - |
predict_function_target_column |
unused, left for compatibility with DALEX |
residual_function |
unused, left for compatibility with DALEX |
weights |
unused, left for compatibility with DALEX |
... |
additional arguments, passed to |
label |
character - the name of the model. Used to distinguish models in visualizations with multiple explainers. By default, it is extracted from the 'class' attribute of the model if possible. |
verbose |
logical, if TRUE (default) then diagnostic messages will be printed |
colorize |
logical, if TRUE then WARNINGS, ERRORS and NOTES are colorized. Works only in the R console. By default, it is FALSE while knitting and TRUE otherwise. |
model_info |
a named list ( |
type |
type of a model, by default |
times |
numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations |
times_generation |
either |
predict_survival_function |
function taking 3 arguments |
predict_cumulative_hazard_function |
function taking 3 arguments |
It is a list containing the following elements:
model - the explained model.
data - the dataset used for training.
y - response for observations from data.
residuals - calculated residuals.
predict_function - function that may be used for model predictions; it shall return a single numerical value for each observation.
residual_function - function that returns residuals; it shall return a single numerical value for each observation.
class - class/classes of a model.
label - label of the explainer.
model_info - named list containing basic information about the model, like package, version of package and type.
times - a vector of times that are used by default for evaluation of the survival function and cumulative hazard function.
predict_survival_function - function that is used for model predictions in the form of a survival function.
predict_cumulative_hazard_function - function that is used for model predictions in the form of a cumulative hazard function.
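The prediction functions in this list follow the signatures used in the explain_survival() example below:

- predict_function(model, newdata) returns one risk value per observation.
- predict_survival_function and predict_cumulative_hazard_function take (model, newdata, times) and return a matrix with one row per observation and one column per time point.

As an illustration only, here is a base-R sketch for a hypothetical constant-hazard (exponential) model stored as a plain list. The names `toy_model`, `toy_predict`, `toy_chf` and `toy_surv`, and the predictor column `x`, are all invented for this example.

```r
# Hypothetical exponential model: hazard_i = exp(beta * x_i), constant in time.
toy_model <- list(beta = 0.5)

# Risk score: one number per observation (higher means higher risk).
toy_predict <- function(model, newdata) {
  exp(model$beta * newdata$x)
}

# Cumulative hazard H_i(t) = hazard_i * t as an n x length(times) matrix.
toy_chf <- function(model, newdata, times) {
  outer(toy_predict(model, newdata), times)
}

# Survival function S_i(t) = exp(-H_i(t)), same shape as the CHF matrix.
toy_surv <- function(model, newdata, times) {
  exp(-toy_chf(model, newdata, times))
}

toy_surv(toy_model, data.frame(x = c(0, 1, 2)), times = c(1, 5, 10))
```

Functions of these shapes are what the predict_function, predict_survival_function and predict_cumulative_hazard_function arguments of explain_survival() take. For such a custom model, data and y would still have to be supplied explicitly.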
library(survival)
library(survex)
# Models supported by explain(): information is extracted automatically
cph <- survival::coxph(survival::Surv(time, status) ~ .,
  data = veteran,
  model = TRUE, x = TRUE
)
cph_exp <- explain(cph)
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ .,
  data = veteran,
  respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)
# veteran[, -c(3, 4)] drops the time and status columns
rsf_ranger_exp <- explain(rsf_ranger,
  data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
rsf_src_exp <- explain(rsf_src)
library(censored, quietly = TRUE)
bt <- parsnip::boost_tree() %>%
  parsnip::set_engine("mboost") %>%
  parsnip::set_mode("censored regression") %>%
  generics::fit(survival::Surv(time, status) ~ ., data = veteran)
bt_exp <- explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status))
###### explain_survival() ######
# Manual explainer: all prediction functions are supplied explicitly
cph <- coxph(Surv(time, status) ~ ., data = veteran)
veteran_data <- veteran[, -c(3, 4)]
veteran_y <- Surv(veteran$time, veteran$status)
risk_pred <- function(model, newdata) predict(model, newdata, type = "risk")
surv_pred <- function(model, newdata, times) pec::predictSurvProb(model, newdata, times)
chf_pred <- function(model, newdata, times) -log(surv_pred(model, newdata, times))
manual_cph_explainer <- explain_survival(
  model = cph,
  data = veteran_data,
  y = veteran_y,
  predict_function = risk_pred,
  predict_survival_function = surv_pred,
  predict_cumulative_hazard_function = chf_pred,
  label = "manual coxph"
)
Helper function to extract a local SurvSHAP(t) explanation from a global one. Can be useful for creating SurvSHAP(t) plots for single observations.
extract_predict_survshap(aggregated_survshap, index)
aggregated_survshap |
an object of class |
index |
a numeric value, position of an observation to be extracted in the result of global explanation |
An object of classes c("predict_parts_survival", "surv_shap"). It is a list with the element result containing the results of the explanation.
library(survex)
veteran <- survival::veteran
rsf_ranger <- ranger::ranger(
  survival::Surv(time, status) ~ .,
  data = veteran,
  respect.unordered.factors = TRUE,
  num.trees = 100,
  mtry = 3,
  max.depth = 5
)
rsf_ranger_exp <- explain(
  rsf_ranger,
  data = veteran[, -c(3, 4)],
  y = survival::Surv(veteran$time, veteran$status),
  verbose = FALSE
)
ranger_global_survshap <- model_survshap(
  explainer = rsf_ranger_exp,
  new_observation = veteran[
    c(1:4, 17:20, 110:113, 126:129),
    !colnames(veteran) %in% c("time", "status")
  ]
)
local_survshap_1 <- extract_predict_survshap(ranger_global_survshap, index = 1)
plot(local_survshap_1)
This function calculates the integrated Brier score metric for a survival model.
integrated_brier_score(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
loss_integrated_brier_score(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
y_true |
a |
risk |
ignored, left for compatibility with other metrics |
surv |
a matrix containing the predicted survival functions for the considered observations; each row represents a single observation and each column a single time point |
times |
a vector of time points at which the survival function was evaluated |
It is useful for assessing how a model performs as a whole rather than at specific time points, for example to make comparisons between models easier. This function calculates the integral of the Brier score numerically using the trapezoidal rule.
numeric from 0 to 1, lower values indicate better performance
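The trapezoidal rule sums, over consecutive time points, the interval length times the average of the metric at its two ends. The base-R sketch below divides the area by the length of the time range so that the result stays on the metric's own 0 to 1 scale. That normalization is an assumption about how the value is kept in [0, 1], and `integrate_metric_sketch()` is a hypothetical name.

```r
# Hypothetical trapezoidal integral of a time-dependent metric, divided by
# the length of the time range (assumed normalization).
integrate_metric_sketch <- function(times, values) {
  stopifnot(length(times) == length(values), length(times) >= 2, !is.unsorted(times))
  # area of each trapezoid: width * mean of the two heights
  area <- sum(diff(times) * (head(values, -1) + tail(values, -1)) / 2)
  area / (max(times) - min(times))
}
```

Under this normalization, a metric that is constant over time integrates to that constant, for example 0.25 for a Brier score of 0.25 at every time point.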
[1] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3.
[2] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545.
brier_score()
integrated_cd_auc()
loss_one_minus_integrated_cd_auc()
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
y <- cph_exp$y
times <- cph_exp$times
surv <- cph_exp$predict_survival_function(cph, cph_exp$data, times)
# calculating directly
integrated_brier_score(y, surv = surv, times = times)
This function calculates the integrated Cumulative/Dynamic AUC metric for a survival model.
integrated_cd_auc(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
y_true |
a |
risk |
ignored, left for compatibility with other metrics |
surv |
a matrix containing the predicted survival functions for the considered observations; each row represents a single observation and each column a single time point |
times |
a vector of time points at which the survival function was evaluated |
It is useful for assessing how a model performs as a whole rather than at specific time points, for example to make comparisons between models easier. This function calculates the integral of the C/D AUC metric numerically using the trapezoidal rule.
numeric from 0 to 1, higher values indicate better performance
[1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537.
[2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679.
cd_auc()
loss_one_minus_cd_auc()
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
y <- cph_exp$y
times <- cph_exp$times
surv <- cph_exp$predict_survival_function(cph, cph_exp$data, times)
integrated_cd_auc(y, surv = surv, times = times)
This function allows standardized measures from the mlr3proba package to be used with survex.
loss_adapt_mlr3proba(measure, reverse = FALSE, ...)
measure |
|
reverse |
|
... |
|
a function with standardized parameters (y_true, risk, surv, times) that can be used to calculate loss
if (FALSE) {
  # requires the mlr3 and mlr3proba packages
  library(mlr3)
  library(mlr3proba)
  library(survex)
  measure <- msr("surv.calib_beta")
  mlr_measure <- loss_adapt_mlr3proba(measure)
}
This function creates a function that calculates an integrated metric from a time-dependent metric. The data can be cut off at a chosen quantile, and the integrated metric can be normalized by the maximum time or weighted by the marginal survival function [1].
loss_integrate(loss_function, ..., normalization = NULL, max_quantile = 1)
loss_function |
|
... |
|
normalization |
|
max_quantile |
|
a function that can be used to calculate metrics (with parameters y_true, risk, surv, and times)
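The function-factory pattern behind an integrating wrapper can be sketched in base R. The factory takes a time-dependent metric and returns a new function with the same standardized parameters. That function drops time points beyond the `max_quantile` quantile of the observed times and integrates the rest with the trapezoidal rule.

This sketch makes several assumptions. `make_integrated_loss()` and `linear_loss()` are hypothetical names. The cut-off is taken from the first column of `y_true`, which holds the times for a `Surv` object or a plain matrix. Only the range-length normalization is shown; the maximum-time and marginal-survival weightings of loss_integrate() are not reproduced.

```r
# Hypothetical factory: turns a time-dependent metric into an integrated one.
make_integrated_loss <- function(loss_function, max_quantile = 1) {
  force(loss_function)
  force(max_quantile)
  function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {
    # keep only time points up to the chosen quantile of observed times
    keep <- times <= quantile(y_true[, 1], max_quantile)
    values <- loss_function(y_true = y_true, risk = risk, surv = surv, times = times)[keep]
    t <- times[keep]
    # trapezoidal rule, normalized by the length of the kept time range
    area <- sum(diff(t) * (head(values, -1) + tail(values, -1)) / 2)
    area / (max(t) - min(t))
  }
}

# Toy time-dependent "metric" that simply equals the time itself.
linear_loss <- function(y_true, risk, surv, times) times
y <- cbind(time = c(1, 2, 3, 4, 10), status = 1)
make_integrated_loss(linear_loss, max_quantile = 0.8)(y_true = y, times = c(1, 2, 3, 4, 5, 6, 8))
```

Returning a closure keeps the standardized (y_true, risk, surv, times) interface, so the integrated metric can be used wherever the original metric could.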
[1] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545.
This function subtracts the C-index metric from one to obtain a loss function whose lower values indicate better model performance (useful for permutation-based feature importance).
loss_one_minus_c_index(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
y_true |
a |
risk |
a numeric vector of risk scores corresponding to each observation |
surv |
ignored, left for compatibility with other metrics |
times |
ignored, left for compatibility with other metrics |
numeric from 0 to 1, lower values indicate better performance
[1] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152.
library(survival)
library(survex)
rotterdam <- survival::rotterdam
rotterdam$year <- NULL
cox_rotterdam_rec <- coxph(Surv(rtime, recur) ~ .,
  data = rotterdam,
  model = TRUE, x = TRUE, y = TRUE
)
coxph_explainer <- explain(cox_rotterdam_rec)
risk <- coxph_explainer$predict_function(coxph_explainer$model, coxph_explainer$data)
loss_one_minus_c_index(y_true = coxph_explainer$y, risk = risk)
This function subtracts the C/D AUC metric from one to obtain a loss function whose lower values indicate better model performance (useful for permutation-based feature importance).
loss_one_minus_cd_auc(y_true = NULL, risk = NULL, surv = NULL, times = NULL)
y_true |
a |
risk |
ignored, left for compatibility with other metrics |
surv |
a matrix containing the predicted survival functions for the considered observations; each row represents a single observation and each column a single time point |
times |
a vector of time points at which the survival function was evaluated |
a numeric vector of the same length as the times vector; each value (in the range from 0 to 1) represents 1 - AUC at a specific time point, with lower values indicating better performance.
[1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537.
[2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
y <- cph_exp$y
times <- cph_exp$times
surv <- cph_exp$predict_survival_function(cph, cph_exp$data, times)
loss_one_minus_cd_auc(y, surv = surv, times = times)
This function subtracts the integrated C/D AUC metric from one to obtain a loss function whose lower values indicate better model performance (useful for permutation-based feature importance).
loss_one_minus_integrated_cd_auc( y_true = NULL, risk = NULL, surv = NULL, times = NULL )
y_true |
a |
risk |
ignored, left for compatibility with other metrics |
surv |
a matrix containing the predicted survival functions for the considered observations; each row represents a single observation and each column a single time point |
times |
a vector of time points at which the survival function was evaluated |
numeric from 0 to 1, lower values indicate better performance
[1] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537.
[2] Hung, Hung, and Chin‐Tsang Chiang. "Optimal composite markers for time‐dependent receiver operating characteristic curves with censored survival data." Scandinavian Journal of Statistics 37.4 (2010): 664-679.
integrated_cd_auc()
cd_auc()
loss_one_minus_cd_auc()
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
y <- cph_exp$y
times <- cph_exp$times
surv <- cph_exp$predict_survival_function(cph, cph_exp$data, times)
# calculating directly
loss_one_minus_integrated_cd_auc(y, surv = surv, times = times)
This function calculates martingale and deviance residuals.
model_diagnostics(explainer)
## S3 method for class 'surv_explainer'
model_diagnostics(explainer)
explainer |
an explainer object - model preprocessed by the |
An object of class c("model_diagnostics_survival"). It is a list with the explanations in the result element.
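All three residuals can be computed from the event indicator and from H_i(T_i), each observation's predicted cumulative hazard at its own observed time:

- The Cox-Snell residual is H_i(T_i).
- The martingale residual is delta_i - H_i(T_i).
- The deviance residual is sign(r) * sqrt(-2 * (r + delta_i * log(delta_i - r))), where r is the martingale residual.

The base-R sketch below assumes the H_i(T_i) values are already available as a vector. `survival_residuals_sketch()` is a hypothetical name, not survex code. Cox-Snell residuals are included only because the example below plots them.

```r
# Hypothetical residuals from the event indicator and H_i(T_i).
survival_residuals_sketch <- function(status, chf_at_own_time) {
  cox_snell <- chf_at_own_time
  martingale <- status - cox_snell
  # status * log(status - martingale) is 0 for censored observations
  log_term <- ifelse(status == 1, log(status - martingale), 0)
  deviance <- sign(martingale) * sqrt(-2 * (martingale + status * log_term))
  data.frame(martingale = martingale, deviance = deviance, cox_snell = cox_snell)
}

survival_residuals_sketch(status = c(1, 0), chf_at_own_time = c(1, 0.5))
```

The deviance transformation makes the heavily skewed martingale residuals more symmetric around zero while keeping their sign.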
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_ranger <- ranger::ranger(Surv(time, status) ~ .,
  data = veteran,
  respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)
cph_exp <- explain(cph)
rsf_ranger_exp <- explain(rsf_ranger,
  data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)
cph_residuals <- model_diagnostics(cph_exp)
rsf_residuals <- model_diagnostics(rsf_ranger_exp)
head(cph_residuals$result)
plot(cph_residuals, rsf_residuals, xvariable = "age")
plot(cph_residuals, rsf_residuals, plot_type = "Cox-Snell")
This function calculates variable importance as the change in the loss function after the values of a variable are permuted.
model_parts(explainer, ...)
## S3 method for class 'surv_explainer'
model_parts(
  explainer,
  loss_function = survex::loss_brier_score,
  ...,
  type = "difference",
  output_type = "survival",
  N = 1000
)
explainer |
an explainer object - model preprocessed by the |
... |
Arguments passed on to
|
loss_function |
a function that will be used to assess variable importance, by default |
type |
a character vector, if |
output_type |
either |
N |
number of observations that should be sampled for calculation of variable importance. If |
Note: This function can be run within progressr::with_progress() to display a progress bar, as the execution can take long, especially on large datasets.
An object of class c("model_parts_survival", "surv_feature_importance"). It is a list with the explanations in the result element.
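The "difference" variant of permutation importance works as follows. Each variable is shuffled in turn, the model is re-evaluated, and the baseline loss is subtracted from the average permuted loss. The base-R sketch below makes these assumptions:

- It uses a scalar loss for simplicity, whereas with output_type = "survival" survex evaluates a time-dependent loss.
- `permutation_importance_sketch()` and its arguments are hypothetical.
- `predict_fun` takes a data frame and `loss_fun` takes (y, predictions).

```r
# Hypothetical permutation importance, type = "difference":
# mean loss after shuffling one column minus the loss on the original data.
permutation_importance_sketch <- function(predict_fun, loss_fun, data, y, B = 10) {
  baseline <- loss_fun(y, predict_fun(data))
  vapply(names(data), function(v) {
    losses <- replicate(B, {
      shuffled <- data
      shuffled[[v]] <- data[[v]][sample.int(nrow(data))]  # break the link to y
      loss_fun(y, predict_fun(shuffled))
    })
    mean(losses) - baseline
  }, numeric(1))
}

set.seed(1)
d <- data.frame(x1 = 1:20, x2 = rnorm(20))
permutation_importance_sketch(function(d) d$x1, function(y, p) mean((y - p)^2), d, y = d$x1)
```

A variable the model never uses gets an importance of exactly zero, because shuffling it leaves the predictions unchanged.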
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_ranger <- ranger::ranger(Surv(time, status) ~ .,
  data = veteran,
  respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)
cph_exp <- explain(cph)
rsf_ranger_exp <- explain(rsf_ranger,
  data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)
cph_model_parts_brier <- model_parts(cph_exp)
print(head(cph_model_parts_brier$result))
plot(cph_model_parts_brier)
rsf_ranger_model_parts <- model_parts(rsf_ranger_exp)
print(head(rsf_ranger_model_parts$result))
plot(cph_model_parts_brier, rsf_ranger_model_parts)
This function calculates metrics for survival models. The metrics calculated are C/D AUC, Brier score, their integrated versions, and the concordance index. It can also calculate ROC curves at selected time points.
model_performance(explainer, ...)
## S3 method for class 'surv_explainer'
model_performance(
  explainer,
  ...,
  type = "metrics",
  metrics = c(`C-index` = c_index, `Integrated C/D AUC` = integrated_cd_auc,
    `Brier score` = brier_score, `Integrated Brier score` = integrated_brier_score,
    `C/D AUC` = cd_auc),
  times = NULL
)
explainer |
an explainer object - model preprocessed by the |
... |
other parameters, currently ignored |
type |
character, either |
metrics |
a named vector containing the metrics to be calculated. The values should be standardized loss functions. The functions can be supplied manually, but they must have these named parameters ( |
times |
a numeric vector of times. If |
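Custom metrics are ordinary R functions. The sketch below assumes they must accept the same named parameters as the brier_score() signature documented above (y_true, risk, surv, times), with y_true a Surv object. It computes Harrell's concordance index from the risk scores only. The function name harrell_c_sketch is hypothetical, tied event times are not treated specially, and this is not survex's own c_index().

```r
library(survival)

# Hypothetical custom metric using the (y_true, risk, surv, times) signature.
# Only y_true (a Surv object) and risk are used; surv and times are ignored.
harrell_c_sketch <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {
  time <- y_true[, "time"]
  status <- y_true[, "status"]
  concordant <- 0
  comparable <- 0
  for (i in seq_along(time)) {
    if (status[i] != 1) next  # only an observed event can anchor a comparable pair
    for (j in seq_along(time)) {
      if (time[i] < time[j]) {
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1    # earlier event has the higher risk
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # tied risk scores count as half
        }
      }
    }
  }
  concordant / comparable
}

y <- Surv(c(2, 4, 6, 8), c(1, 1, 0, 1))
harrell_c_sketch(y_true = y, risk = c(0.9, 0.6, 0.7, 0.1))
# 4 of the 5 comparable pairs are concordant, so the result is 0.8
```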
An object of class "model_performance_survival". It is a list of metric values calculated for the model. It contains:
Harrell's concordance index [1]
Brier score [2, 3]
C/D AUC using the estimator proposed by Uno et al. [4]
Integral of the Brier score
Integral of the C/D AUC
[1] Harrell, F.E., Jr., et al. "Regression modelling strategies for improved prognostic prediction." Statistics in Medicine 3.2 (1984): 143-152.
[2] Brier, Glenn W. "Verification of forecasts expressed in terms of probability." Monthly Weather Review 78.1 (1950): 1-3.
[3] Graf, Erika, et al. "Assessment and comparison of prognostic classification schemes for survival data." Statistics in Medicine 18.17‐18 (1999): 2529-2545.
[4] Uno, Hajime, et al. "Evaluating prediction rules for t-year survivors with censored regression models." Journal of the American Statistical Association 102.478 (2007): 527-537.
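Following Graf et al. [3], the Brier score at time t can be written with inverse probability of censoring weights. Let G be the Kaplan-Meier estimate of the censoring survival function. Then BS(t) = (1/n) * sum_i [ S(t|x_i)^2 * I(T_i <= t, delta_i = 1) / G(T_i) + (1 - S(t|x_i))^2 * I(T_i > t) / G(t) ], and observations censored before t contribute zero.

The sketch below implements this formula for a surv matrix laid out as described for brier_score() (rows are observations, columns are time points). It is a simplified illustration, not survex's brier_score(). The name brier_sketch is hypothetical. G is evaluated at T_i; the left limit G(T_i-) used by some implementations differs only when censoring times coincide with event times.

```r
library(survival)

# Hypothetical helper: IPCW Brier score at each time in `times`.
# surv: matrix of predicted survival probabilities,
#       rows = observations, columns = time points in `times`.
brier_sketch <- function(y_true, surv, times) {
  time <- y_true[, "time"]
  status <- y_true[, "status"]
  # Kaplan-Meier estimate of the censoring survival function G(t)
  cens_fit <- survfit(Surv(time, 1 - status) ~ 1)
  G <- stepfun(cens_fit$time, c(1, cens_fit$surv))  # right-continuous step function
  sapply(seq_along(times), function(k) {
    t <- times[k]
    s <- surv[, k]
    died_before <- time <= t & status == 1  # event observed by time t
    alive_after <- time > t                 # still event-free after t
    contrib <- numeric(length(time))        # censored before t: contributes 0
    contrib[died_before] <- s[died_before]^2 / G(time[died_before])
    contrib[alive_after] <- (1 - s[alive_after])^2 / G(t)
    mean(contrib)
  })
}

y <- Surv(c(2, 4, 6, 8), c(1, 1, 0, 1))
surv <- matrix(c(0.2, 0.4, 0.7, 0.9,   # predictions at t = 5
                 0.1, 0.3, 0.6, 0.8),  # predictions at t = 7
               ncol = 2)
brier_sketch(y, surv, times = c(5, 7))
# 0.075 0.045
```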
library(survival)
library(survex)
# three models fitted on the same data
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_ranger <- ranger::ranger(Surv(time, status) ~ .,
    data = veteran,
    respect.unordered.factors = TRUE,
    num.trees = 100,
    mtry = 3,
    max.depth = 5
)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
cph_exp <- explain(cph)
rsf_ranger_exp <- explain(rsf_ranger,
    data = veteran[, -c(3, 4)],
    y = Surv(veteran$time, veteran$status)
)
rsf_src_exp <- explain(rsf_src)
cph_model_performance <- model_performance(cph_exp)
rsf_ranger_model_performance <- model_performance(rsf_ranger_exp)
rsf_src_model_performance <- model_performance(rsf_src_exp)
print(cph_model_performance)
# compare the scalar metrics, then the time-dependent metrics
plot(rsf_ranger_model_performance, cph_model_performance, rsf_src_model_performance,
    metrics_type = "scalar"
)
plot(rsf_ranger_model_performance, cph_model_performance, rsf_src_model_performance)
# ROC curves at selected time points
cph_model_performance_roc <- model_performance(cph_exp, type = "roc", times = c(100, 250, 500))
plot(cph_model_performance_roc)
This function calculates dataset-level explanations that help explore the model response as a function of selected variables. The explanations are extensions of Partial Dependence Profiles or Accumulated Local Effects that include the time dimension.
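As a rough illustration of the idea (not survex's implementation), a time-dependent partial dependence profile for one variable works as follows: fix that variable at each grid value for every observation, predict the survival functions, and average them. The sketch below does this for age in a Cox model on veteran at three arbitrarily chosen times. It uses all rows rather than a sample of N observations, and predict_surv is a hypothetical helper.

```r
library(survival)

cph <- coxph(Surv(time, status) ~ ., data = veteran)
X <- veteran[, setdiff(names(veteran), c("time", "status"))]

# Hypothetical helper: predicted S(t | x) for every row of newdata
# (rows = times, columns = observations); times must not precede the first observed time.
predict_surv <- function(model, newdata, times) {
  sf <- survfit(model, newdata = newdata)
  sf$surv[findInterval(times, sf$time), , drop = FALSE]
}

age_grid <- quantile(X$age, probs = c(0.1, 0.5, 0.9), names = FALSE)
eval_times <- c(50, 100, 200)

# For each grid value: set age for all observations, predict, then average
pdp <- sapply(age_grid, function(a) {
  X_mod <- X
  X_mod$age <- a
  rowMeans(predict_surv(cph, X_mod, eval_times))
})
dimnames(pdp) <- list(time = eval_times, age = age_grid)
pdp
```

Each column of pdp is an averaged survival curve for one age value, so values decrease down each column.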
model_profile(
  explainer,
  variables = NULL,
  N = 100,
  ...,
  groups = NULL,
  k = NULL,
  type = "partial",
  center = FALSE,
  output_type = "survival"
)
## S3 method for class 'surv_explainer'
model_profile(
  explainer,
  variables = NULL,
  N = 100,
  ...,
  categorical_variables = NULL,
  grid_points = 51,
  variable_splits_type = "uniform",
  groups = NULL,
  k = NULL,
  center = FALSE,
  type = "partial",
  output_type = "survival"
)
explainer |
an explainer object - model preprocessed by the |
variables |
character, a vector of names of variables to be explained |
N |
number of observations used for the calculation of aggregated profiles. By default |
... |
other parameters passed to |
groups |
if |
k |
passed to |
type |
the type of variable profile, |
center |
logical, should profiles be centered around the average prediction |
output_type |
either |
categorical_variables |
character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the |
grid_points |
maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. It is passed to an internal function. By default |
variable_splits_type |
character, decides how variable grids should be calculated. Use |
An object of class model_profile_survival. It is a list with the element result containing the results of the calculation.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
cph_exp <- explain(cph)
rsf_src_exp <- explain(rsf_src)
# partial dependence (the default type) of the survival function on age
cph_model_profile <- model_profile(cph_exp,
    output_type = "survival",
    variables = c("age")
)
head(cph_model_profile$result)
plot(cph_model_profile)
# accumulated local effects for a numerical and a categorical variable
rsf_model_profile <- model_profile(rsf_src_exp,
    output_type = "survival",
    variables = c("age", "celltype"),
    type = "accumulated"
)
head(rsf_model_profile$result)
plot(rsf_model_profile, variables = c("age", "celltype"), numerical_plot_type = "contours")
This function calculates explanations on a dataset level that help explore model response as a function of selected pairs of variables. The explanations are calculated as an extension of Partial Dependence Profiles or Accumulated Local Effects with the inclusion of the time dimension.
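A two-dimensional partial dependence profile at a single time point can be sketched by fixing both variables of a pair at each combination of grid values and averaging the predicted survival probabilities. The example below makes several assumptions: a Cox model on veteran, the pair (age, karno), a 3 x 3 quantile grid and t = 100. mean_surv_at is a hypothetical helper, and survex's grids, sampling and centering may differ.

```r
library(survival)

cph <- coxph(Surv(time, status) ~ ., data = veteran)
X <- veteran[, setdiff(names(veteran), c("time", "status"))]
t_star <- 100  # single time point at which the 2D profile is evaluated

# Hypothetical helper: average predicted survival at time t over all rows of newdata
mean_surv_at <- function(model, newdata, t) {
  sf <- survfit(model, newdata = newdata)
  mean(sf$surv[findInterval(t, sf$time), ])
}

grid <- expand.grid(
  age = quantile(X$age, probs = c(0.1, 0.5, 0.9), names = FALSE),
  karno = quantile(X$karno, probs = c(0.1, 0.5, 0.9), names = FALSE)
)

# For each (age, karno) pair: overwrite both columns, predict, average
grid$pd <- mapply(function(a, k) {
  X_mod <- X
  X_mod$age <- a
  X_mod$karno <- k
  mean_surv_at(cph, X_mod, t_star)
}, grid$age, grid$karno)
grid
```

In this model a higher Karnofsky score is associated with better survival, so pd should grow with karno at a fixed age.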
model_profile_2d(
  explainer,
  variables = NULL,
  N = 100,
  categorical_variables = NULL,
  grid_points = 25,
  center = FALSE,
  variable_splits_type = "uniform",
  type = "partial",
  output_type = "survival"
)
## S3 method for class 'surv_explainer'
model_profile_2d(
  explainer,
  variables = NULL,
  N = 100,
  categorical_variables = NULL,
  grid_points = 25,
  center = FALSE,
  variable_splits_type = "uniform",
  type = "partial",
  output_type = "survival"
)
explainer |
an explainer object - model preprocessed by the |
variables |
list of character vectors of length 2, names of pairs of variables to be explained |
N |
number of observations used for the calculation of aggregated profiles. By default |
categorical_variables |
character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the |
grid_points |
maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. It is passed to an internal function. By default |
center |
logical, should profiles be centered around the average prediction |
variable_splits_type |
character, decides how variable grids should be calculated. Use |
type |
the type of variable profile, |
output_type |
either |
An object of class model_profile_2d_survival. It is a list with the element result containing the results of the calculation.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
# 2D partial dependence for the pair (age, celltype)
cph_model_profile_2d <- model_profile_2d(cph_exp,
    variables = list(c("age", "celltype"))
)
head(cph_model_profile_2d$result)
plot(cph_model_profile_2d)
# 2D accumulated local effects for the pair (age, karno)
cph_model_profile_2d_ale <- model_profile_2d(cph_exp,
    variables = list(c("age", "karno")),
    type = "accumulated"
)
head(cph_model_profile_2d_ale$result)
plot(cph_model_profile_2d_ale)
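This function computes global SHAP values for a survival model, that is, time-dependent SurvSHAP(t) explanations calculated for a set of observations.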
model_survshap(explainer, ...)
## S3 method for class 'surv_explainer'
model_survshap(
  explainer,
  new_observation = NULL,
  y_true = NULL,
  N = NULL,
  calculation_method = "kernelshap",
  aggregation_method = "integral",
  output_type = "survival",
  ...
)
explainer |
an explainer object - model preprocessed by the |
... |
additional parameters, passed to internal functions |
new_observation |
new observations for which predictions need to be explained |
y_true |
a two element numeric vector or matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting |
N |
a positive integer, number of observations used as the background data |
calculation_method |
a character, either |
aggregation_method |
a character, either |
output_type |
a character, either |
If y_true is specified, new_observation must also be specified.
If new_observation is provided, global SHAP values are computed for those observations; otherwise, they are computed for the data the explainer was created with.
An object of class aggregated_surv_shap containing the computed global SHAP values.
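One way to turn time-dependent SurvSHAP(t) values into a single global importance per variable is a two-step aggregation: integrate their absolute values over time for each observation, then average across observations. The sketch assumes an array of SurvSHAP(t) values indexed [observation, time, variable] and uses the trapezoidal rule. survex's aggregation_method = "integral" may normalize or weight differently, so treat this as an illustration only. The names trapezoid and aggregate_survshap_sketch are hypothetical, and the SHAP values are toy numbers.

```r
# Hypothetical helper: trapezoidal rule for the integral of y over the grid x
trapezoid <- function(x, y) {
  sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)
}

# Hypothetical aggregation: shap is an array [observation, time, variable]
aggregate_survshap_sketch <- function(shap, times) {
  # integral of |phi(t)| for every observation-variable pair
  per_obs <- apply(shap, c(1, 3), function(phi) trapezoid(times, abs(phi)))
  colMeans(per_obs)  # average over observations: one value per variable
}

times <- c(0, 1, 2)
shap <- array(0, dim = c(2, 3, 2), dimnames = list(NULL, NULL, c("age", "karno")))
shap[1, , "age"] <- c(1, 1, 1)
shap[2, , "age"] <- c(-1, -1, -1)
shap[1, , "karno"] <- c(0, 2, 0)
aggregate_survshap_sketch(shap, times)
# age = 2, karno = 1
```

Taking absolute values before integrating keeps positive and negative contributions at different times from cancelling out.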
library(survex)
veteran <- survival::veteran
# random survival forest fitted with ranger
rsf_ranger <- ranger::ranger(
    survival::Surv(time, status) ~ .,
    data = veteran,
    respect.unordered.factors = TRUE,
    num.trees = 100,
    mtry = 3,
    max.depth = 5
)
rsf_ranger_exp <- explain(
    rsf_ranger,
    data = veteran[, -c(3, 4)],
    y = survival::Surv(veteran$time, veteran$status),
    verbose = FALSE
)
# global SurvSHAP(t) for 16 selected observations
ranger_global_survshap <- model_survshap(
    explainer = rsf_ranger_exp,
    new_observation = veteran[
        c(1:4, 17:20, 110:113, 126:129),
        !colnames(veteran) %in% c("time", "status")
    ],
    y_true = survival::Surv(
        veteran$time[c(1:4, 17:20, 110:113, 126:129)],
        veteran$status[c(1:4, 17:20, 110:113, 126:129)]
    ),
    aggregation_method = "integral",
    calculation_method = "kernelshap"
)
plot(ranger_global_survshap)
plot(ranger_global_survshap, geom = "beeswarm")
plot(ranger_global_survshap, geom = "profile", color_variable = "karno")
This function plots objects of class aggregated_surv_shap, that is, aggregated time-dependent explanations of survival models created using the model_survshap() function.
## S3 method for class 'aggregated_surv_shap'
plot(
  x,
  geom = "importance",
  ...,
  title = "default",
  subtitle = "default",
  max_vars = 7,
  colors = NULL
)
x |
an object of class |
geom |
character, one of |
... |
additional parameters passed to internal functions |
title |
character, title of the plot |
subtitle |
character, subtitle of the plot, |
max_vars |
maximum number of variables to be plotted (least important variables are ignored), by default 7 |
colors |
character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). If |
An object of the class ggplot.
plot.aggregated_surv_shap(geom = "importance")
rug
- character, one of "all", "events", "censors", "none" or NULL. Which times to mark on the x axis in geom_rug().
rug_colors
- character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times.
xlab_left, ylab_right
- axis labels for left and right plots (due to different aggregation possibilities)
plot.aggregated_surv_shap(geom = "beeswarm")
no additional parameters
plot.aggregated_surv_shap(geom = "profile")
variable
- variable for which the profile is to be plotted, by default the first variable in the result data
color_variable
- variable used to denote the color, by default equal to variable
plot.aggregated_surv_shap(geom = "curves")
variable
- variable for which SurvSHAP(t) curves are to be plotted, by default the first variable in the result data
boxplot
- whether to plot a functional boxplot with marked outliers, or all curves colored by variable value
coef
- length of the functional boxplot's whiskers as a multiple of the IQR, by default 1.5
library(survex)
veteran <- survival::veteran
# random survival forest fitted with ranger
rsf_ranger <- ranger::ranger(
    survival::Surv(time, status) ~ .,
    data = veteran,
    respect.unordered.factors = TRUE,
    num.trees = 100,
    mtry = 3,
    max.depth = 5
)
rsf_ranger_exp <- explain(
    rsf_ranger,
    data = veteran[, -c(3, 4)],
    y = survival::Surv(veteran$time, veteran$status),
    verbose = FALSE
)
# global SurvSHAP(t) for 16 selected observations
ranger_global_survshap <- model_survshap(
    explainer = rsf_ranger_exp,
    new_observation = veteran[
        c(1:4, 17:20, 110:113, 126:129),
        !colnames(veteran) %in% c("time", "status")
    ],
    y_true = survival::Surv(
        veteran$time[c(1:4, 17:20, 110:113, 126:129)],
        veteran$status[c(1:4, 17:20, 110:113, 126:129)]
    ),
    aggregation_method = "integral",
    calculation_method = "kernelshap"
)
# one call per available geom
plot(ranger_global_survshap)
plot(ranger_global_survshap, geom = "beeswarm")
plot(ranger_global_survshap, geom = "profile", variable = "age", color_variable = "karno")
plot(ranger_global_survshap, geom = "curves", variable = "age")
plot(ranger_global_survshap, geom = "curves", variable = "age", boxplot = TRUE)
This function plots objects of class "model_diagnostics_survival" created using the model_diagnostics() function.
## S3 method for class 'model_diagnostics_survival'
plot(
  x,
  ...,
  plot_type = "deviance",
  xvariable = "index",
  smooth = as.logical(xvariable != "index"),
  facet_ncol = NULL,
  title = "Model diagnostics",
  subtitle = "default",
  colors = NULL
)
x |
an object of class |
... |
additional objects of class |
plot_type |
character, either |
xvariable |
character, name of the variable to be plotted on the x-axis (can be any column from the |
smooth |
logical, whether a smooth line should be added. Only used when |
facet_ncol |
number of columns for arranging subplots |
title |
character, title of the plot |
subtitle |
character, subtitle of the plot, |
colors |
character vector containing the colors to be used for plotting (containing either hex codes "#FF69B4", or names "blue"). |
An object of the class ggplot.
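This sketch assumes the plot types refer to the standard martingale and deviance residuals of survival analysis. Under that assumption, their relation can be checked with the survival package alone for a Cox model. The martingale residual m is the event indicator minus the estimated cumulative hazard at the observed time. The deviance residual is sign(m) * sqrt(-2 * (m + delta * log(delta - m))), which rescales m to be more symmetric around zero. The sketch uses survival's residuals() rather than anything from survex.

```r
library(survival)

cph <- coxph(Surv(time, status) ~ ., data = veteran)
status <- veteran$status

# Martingale residual: event indicator minus estimated cumulative hazard at T_i
mart <- residuals(cph, type = "martingale")

# Deviance residual computed from the martingale residual
dev_manual <- sign(mart) *
  sqrt(-2 * (mart + ifelse(status == 0, 0, status * log(status - mart))))

dev_builtin <- residuals(cph, type = "deviance")
all.equal(unname(dev_manual), unname(dev_builtin))
```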
This function is a wrapper for plotting model_parts objects created for survival models using the model_parts() function.
## S3 method for class 'model_parts_survival'
plot(x, ...)
x |
an object of class |
... |
additional parameters passed to the |
An object of the class ggplot.
title
- character, title of the plot
subtitle
- character, subtitle of the plot; if NULL, automatically generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
max_vars
- maximum number of variables to be plotted (least important variables are ignored)
colors
- character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
rug
- character, one of "all", "events", "censors", "none" or NULL. Which times to mark on the x axis in geom_rug().
rug_colors
- character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times.
Other functions for plotting 'model_parts_survival' objects:
plot.surv_feature_importance()
library(survival)
library(survex)
model <- coxph(Surv(time, status) ~ ., data = veteran, x = TRUE, model = TRUE, y = TRUE)
explainer <- explain(model)
mp <- model_parts(explainer)
plot(mp)
This function is a wrapper for plotting model_performance objects created for survival models using the model_performance() function.
## S3 method for class 'model_performance_survival'
plot(x, ...)
x |
an object of class |
... |
additional parameters passed to the |
An object of the class ggplot.
plot.surv_model_performance
x
- an object of class "surv_model_performance" to be plotted
...
- additional objects of class "surv_model_performance" to be plotted together
metrics
- character, names of metrics to be plotted (a subset of "C/D AUC", "Brier score" for metrics_type %in% c("time_dependent", "functional"), or a subset of "C-index", "Integrated Brier score", "Integrated C/D AUC" for metrics_type == "scalar"); by default (NULL) all metrics of the given type are plotted
metrics_type
- character, either one of c("time_dependent", "functional") for functional metrics or "scalar" for scalar metrics
title
- character, title of the plot
subtitle
- character, subtitle of the plot; 'default' automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
facet_ncol
- number of columns for arranging subplots
colors
- character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
rug
- character, one of "all", "events", "censors", "none" or NULL. Which times to mark on the x axis in geom_rug().
rug_colors
- character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times.
plot.surv_model_performance_rocs
x
- an object of class "surv_model_performance_rocs" to be plotted
...
- additional objects of class "surv_model_performance_rocs" to be plotted together
title
- character, title of the plot
subtitle
- character, subtitle of the plot; 'default' automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
colors
- character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
facet_ncol
- number of columns for arranging subplots
Other functions for plotting 'model_performance_survival' objects:
plot.surv_model_performance_rocs(), plot.surv_model_performance()
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
# time-dependent metrics (C/D AUC and Brier score) plotted over time
m_perf <- model_performance(exp)
plot(m_perf, metrics_type = "functional")
# ROC curves at two selected time points
m_perf_roc <- model_performance(exp, type = "roc", times = c(100, 300))
plot(m_perf_roc)
This function plots objects of class "model_profile_2d_survival" created using the model_profile_2d() function.
## S3 method for class 'model_profile_2d_survival'
plot(
  x,
  ...,
  variables = NULL,
  times = NULL,
  marginalize_over_time = FALSE,
  facet_ncol = NULL,
  title = "default",
  subtitle = "default",
  colors = NULL
)
x |
an object of class |
... |
additional objects of class |
variables |
list of character vectors of length 2, names of pairs of variables to be plotted |
times |
numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If |
marginalize_over_time |
logical, if |
facet_ncol |
number of columns for arranging subplots |
title |
character, title of the plot. |
subtitle |
character, subtitle of the plot, |
colors |
character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") |
A collection of ggplot objects arranged with the patchwork package.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
cph_model_profile_2d <- model_profile_2d(cph_exp,
    variables = list(
        c("age", "celltype"),
        c("age", "karno")
    )
)
head(cph_model_profile_2d$result)
# plot one pair at a single time point taken from the explainer's times
plot(cph_model_profile_2d, variables = list(c("age", "celltype")), times = cph_exp$times[20])
cph_model_profile_2d_ale <- model_profile_2d(cph_exp,
    variables = list(c("age", "karno")),
    type = "accumulated"
)
head(cph_model_profile_2d_ale$result)
# marginalize the profile over the two selected time points
plot(cph_model_profile_2d_ale, times = cph_exp$times[c(10, 20)], marginalize_over_time = TRUE)
This function plots objects of class "model_profile_survival" created using the model_profile() function.
## S3 method for class 'model_profile_survival'
plot(
  x,
  ...,
  geom = "time",
  variables = NULL,
  variable_type = NULL,
  facet_ncol = NULL,
  numerical_plot_type = "lines",
  times = NULL,
  marginalize_over_time = FALSE,
  plot_type = NULL,
  title = "default",
  subtitle = "default",
  colors = NULL,
  rug = "all",
  rug_colors = c("#dd0000", "#222222")
)
x |
an object of class |
... |
additional objects of class |
geom |
character, either |
variables |
character, names of the variables to be plotted. When |
variable_type |
character, either |
facet_ncol |
number of columns for arranging subplots. Only used when |
numerical_plot_type |
character, either |
times |
numeric vector, times for which the profile should be plotted, the times must be present in the 'times' field of the explainer. If |
marginalize_over_time |
logical, if |
plot_type |
character, one of |
title |
character, title of the plot |
subtitle |
character, subtitle of the plot, |
colors |
character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue"). |
rug |
character, one of |
rug_colors |
character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. |
A collection of ggplot objects arranged with the patchwork package.
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
m_prof <- model_profile(exp, categorical_variables = "trt")
# profiles plotted over time (the default geom)
plot(m_prof)
plot(m_prof, numerical_plot_type = "contours")
plot(m_prof, variables = c("trt", "age"), facet_ncol = 1)
# profiles plotted against the variable value, with PDP and ICE curves
plot(m_prof, geom = "variable", variables = "karno", plot_type = "pdp+ice")
plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], variables = "karno", plot_type = "pdp+ice")
plot(m_prof, geom = "variable", times = exp$times[c(5, 10)], variables = "trt", plot_type = "pdp+ice")
This function plots objects of class "predict_parts_survival" (local explanations for survival models) created using the predict_parts() function.
## S3 method for class 'predict_parts_survival'
plot(x, ...)
x |
an object of class |
... |
additional parameters passed to the |
An object of the class ggplot.
plot.surv_shap
x
- an object of class "surv_shap" to be plotted
...
- additional objects of class surv_shap to be plotted together
title
- character, title of the plot
subtitle
- character, subtitle of the plot; 'default' automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
max_vars
- maximum number of variables to be plotted (least important variables are ignored)
colors
- character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
rug
- character, one of "all", "events", "censors", "none" or NULL. Which times to mark on the x axis in geom_rug().
rug_colors
- character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times.
plot.surv_lime
x
- an object of class "surv_lime" to be plotted
type
- character, either "coefficients" or "local_importance", selects the type of plot
show_survival_function
- logical, whether the survival functions of the explanations should be plotted next to the bar plot
...
- other parameters currently ignored
title
- character, title of the plot
subtitle
- character, subtitle of the plot; 'default' automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
max_vars
- maximum number of variables to be plotted (least important variables are ignored)
colors
- character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
Other functions for plotting 'predict_parts_survival' objects:
plot.surv_lime(), plot.surv_shap()
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
# local explanations for the first observation (time and status columns removed)
p_parts_shap <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survshap")
plot(p_parts_shap)
p_parts_lime <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survlime")
plot(p_parts_lime)
This function plots objects of class "predict_profile_survival" created using the predict_profile() function.
## S3 method for class 'predict_profile_survival'
plot(
  x,
  ...,
  geom = "time",
  variables = NULL,
  variable_type = NULL,
  facet_ncol = NULL,
  numerical_plot_type = "lines",
  times = NULL,
  marginalize_over_time = FALSE,
  title = "default",
  subtitle = "default",
  colors = NULL,
  rug = "all",
  rug_colors = c("#dd0000", "#222222")
)
| Argument | Description |
|---|---|
| x | an object of class |
| ... | additional objects of class |
| geom | character, either |
| variables | character, names of the variables to be plotted. When |
| variable_type | character, either |
| facet_ncol | number of columns for arranging subplots. Only used when |
| numerical_plot_type | character, either |
| times | numeric vector, times for which the profile should be plotted; the times must be present in the 'times' field of the explainer. If |
| marginalize_over_time | logical, if |
| title | character, title of the plot |
| subtitle | character, subtitle of the plot, |
| colors | character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") |
| rug | character, one of |
| rug_colors | character vector containing two colors (containing either hex codes |
A collection of ggplot objects arranged with the patchwork package.
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
p_profile <- predict_profile(exp, veteran[1, -c(3, 4)])
plot(p_profile)
p_profile_with_cat <- predict_profile(
  exp, veteran[1, -c(3, 4)],
  categorical_variables = c("trt", "prior")
)
plot(p_profile_with_cat)
This function plots feature importance objects created for survival models using the model_parts() function with a time-dependent metric, that is, loss_one_minus_cd_auc() or loss_brier_score().
## S3 method for class 'surv_feature_importance'
plot(x, ..., title = "Time-dependent feature importance", subtitle = "default",
  max_vars = 7, colors = NULL, rug = "all", rug_colors = c("#dd0000", "#222222"))
| Argument | Description |
|---|---|
| x | an object of class |
| ... | additional objects of class |
| title | character, title of the plot |
| subtitle | character, subtitle of the plot, |
| max_vars | maximum number of variables to be plotted (least important variables are ignored) |
| colors | character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") |
| rug | character, one of |
| rug_colors | character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. |
An object of the class ggplot.
Other functions for plotting 'model_parts_survival' objects: plot.model_parts_survival()
library(survival)
library(survex)
model <- coxph(Surv(time, status) ~ ., data = veteran, x = TRUE, model = TRUE, y = TRUE)
model_rf <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
explainer <- explain(model)
explainer_rf <- explain(model_rf)
mp <- model_parts(explainer)
mp_rf <- model_parts(explainer_rf)
plot(mp, mp_rf)
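The idea behind these plots is time-dependent permutation importance. Each variable is shuffled in turn, and the change in a time-dependent loss relative to the unshuffled data is recorded at every time point. The base-R sketch below illustrates that idea. It is not survex's implementation: the simulated data, the toy predictor predict_surv() and the helper brier_no_censoring() are hypothetical. The Brier score here also assumes every event is observed (no censoring), whereas loss_brier_score() accounts for censored observations.

```r
# Sketch of time-dependent permutation importance (not survex's code).
set.seed(123)
n <- 1000
data <- data.frame(x1 = rnorm(n), x2 = rnorm(n))

# Simulated exponential event times: x1 affects the hazard, x2 does not.
# For simplicity every event is observed (no censoring).
event_time <- rexp(n, rate = exp(0.8 * data$x1))

# Hypothetical "model": returns the true survival curves S(t | x) at `times`
predict_surv <- function(newdata, times) {
  exp(-outer(exp(0.8 * newdata$x1), times))
}

# Brier score at each time point, valid only when nothing is censored
brier_no_censoring <- function(surv, event_time, times) {
  still_event_free <- outer(event_time, times, ">")  # TRUE if no event by t
  colMeans((still_event_free - surv)^2)
}

times <- c(0.5, 1, 2)
full_loss <- brier_no_censoring(predict_surv(data, times), event_time, times)

# Shuffle one variable at a time and record the loss increase per time point
importance <- sapply(c("x1", "x2"), function(variable) {
  permuted <- data
  permuted[[variable]] <- sample(permuted[[variable]])
  brier_no_censoring(predict_surv(permuted, times), event_time, times) - full_loss
})
rownames(importance) <- paste0("t=", times)
print(round(importance, 4))
```

Because predict_surv() never uses x2, shuffling it leaves the loss unchanged. Shuffling x1 raises the loss at every time point.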
This function plots objects of class surv_lime (LIME explanations of survival models) created using the predict_parts(..., type = "survlime") function.
## S3 method for class 'surv_lime'
plot(x, type = "local_importance", show_survival_function = TRUE, ...,
  title = "SurvLIME", subtitle = "default", max_vars = 7, colors = NULL)
| Argument | Description |
|---|---|
| x | an object of class |
| type | character, either "coefficients" or "local_importance" (default), selects the type of plot |
| show_survival_function | logical, if the survival function of the explanations should be plotted next to the barplot |
| ... | other parameters currently ignored |
| title | character, title of the plot |
| subtitle | character, subtitle of the plot, |
| max_vars | maximum number of variables to be plotted (least important variables are ignored) |
| colors | character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") |
An object of the class ggplot.
Other functions for plotting 'predict_parts_survival' objects: plot.predict_parts_survival(), plot.surv_shap()
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
p_parts_lime <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survlime")
plot(p_parts_lime)
This function plots objects of class "surv_model_performance" (visualizations of metrics of different models) created using the model_performance(..., type = "metrics") function.
## S3 method for class 'surv_model_performance'
plot(x, ..., metrics = NULL, metrics_type = "time_dependent",
  title = "Model performance", subtitle = "default", facet_ncol = NULL,
  colors = NULL, rug = "all", rug_colors = c("#dd0000", "#222222"))
| Argument | Description |
|---|---|
| x | an object of class |
| ... | additional objects of class |
| metrics | character, names of metrics to be plotted (subset of "C/D AUC", "Brier score" for |
| metrics_type | character, either one of |
| title | character, title of the plot |
| subtitle | character, subtitle of the plot, |
| facet_ncol | number of columns for arranging subplots |
| colors | character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") |
| rug | character, one of |
| rug_colors | character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. |
An object of the class ggplot.
Other functions for plotting 'model_performance_survival' objects: plot.model_performance_survival(), plot.surv_model_performance_rocs()
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
m_perf <- model_performance(exp)
plot(m_perf)
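Of the metrics shown in these plots, the Brier score is the easiest to compute by hand. The base-R sketch below computes an inverse-probability-of-censoring-weighted (IPCW) Brier score at a single time point. The censoring distribution is estimated with a Kaplan-Meier estimator in which censorings play the role of events. The helpers censoring_km() and brier_ipcw() and the toy data are hypothetical, and survex may differ in details such as tie handling.

```r
# IPCW Brier score at time t (sketch):
# BS(t) = 1/n * sum_i [ I(T_i <= t, status_i = 1) * S(t | x_i)^2 / G(T_i-)
#                     + I(T_i > t) * (1 - S(t | x_i))^2 / G(t) ]

# Kaplan-Meier estimate of the censoring survival G(s) = P(C > s).
# With left = TRUE the left limit G(s-) is returned, i.e. only
# censorings strictly before s are used.
censoring_km <- function(time, status, s, left = FALSE) {
  cens_times <- sort(unique(time[status == 0]))
  vapply(s, function(si) {
    used <- if (left) cens_times[cens_times < si] else cens_times[cens_times <= si]
    factors <- vapply(used, function(u) {
      1 - sum(time == u & status == 0) / sum(time >= u)
    }, numeric(1))
    prod(factors)  # equals 1 when no censoring time is used
  }, numeric(1))
}

brier_ipcw <- function(time, status, surv_at_t, eval_time) {
  had_event <- time <= eval_time & status == 1  # event observed by eval_time
  survived  <- time > eval_time                 # known to be event-free
  g_event <- censoring_km(time, status, time, left = TRUE)
  g_t     <- censoring_km(time, status, eval_time)
  contrib <- ifelse(had_event, surv_at_t^2 / g_event,
             ifelse(survived, (1 - surv_at_t)^2 / g_t, 0))
  mean(contrib)  # censored before eval_time: contributes 0 but counts in n
}

time      <- c(1, 2, 3, 4)
status    <- c(1, 0, 1, 1)
surv_at_t <- c(0.2, 0.5, 0.7, 0.9)
print(brier_ipcw(time, status, surv_at_t, eval_time = 2.5))
```

Without censoring all weights equal 1, and the score reduces to the plain mean squared difference between predicted survival and the observed event-free status.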
This function plots objects of class "surv_model_performance_rocs" (ROC curves at specific time points for survival models) created using the model_performance(..., type = "roc") function.
## S3 method for class 'surv_model_performance_rocs'
plot(x, ..., title = "ROC curves for selected time points",
  subtitle = "default", auc = TRUE, colors = NULL, facet_ncol = NULL)
| Argument | Description |
|---|---|
| x | an object of class |
| ... | additional objects of class |
| title | character, title of the plot |
| subtitle | character, subtitle of the plot, |
| auc | boolean, whether the AUC values should be plotted |
| colors | character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") |
| facet_ncol | number of columns for arranging subplots |
An object of the class ggplot.
Other functions for plotting 'model_performance_survival' objects: plot.model_performance_survival(), plot.surv_model_performance()
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
m_perf_roc <- model_performance(exp, type = "roc", times = c(100, 300))
plot(m_perf_roc)
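A cumulative/dynamic ROC curve at time t works as follows. Observations with an event by t are cases, observations still event-free after t are controls, and a threshold is swept over the risk scores. The sketch below assumes that simplified definition and drops observations censored before t instead of reweighting them, so it is not the estimator used by survex. The function cd_roc() and the toy data are hypothetical.

```r
cd_roc <- function(time, status, risk, eval_time) {
  cases    <- time <= eval_time & status == 1  # event observed by eval_time
  controls <- time > eval_time                 # event-free after eval_time
  # Thresholds from strictest (nobody positive) to loosest (everybody positive)
  thresholds <- sort(unique(c(Inf, risk)), decreasing = TRUE)
  tpr <- vapply(thresholds, function(thr) mean(risk[cases] >= thr), numeric(1))
  fpr <- vapply(thresholds, function(thr) mean(risk[controls] >= thr), numeric(1))
  # AUC: share of (case, control) pairs in which the case has the higher
  # risk, with ties counted as one half
  diffs <- outer(risk[cases], risk[controls], "-")
  auc <- mean((diffs > 0) + 0.5 * (diffs == 0))
  list(fpr = fpr, tpr = tpr, auc = auc)
}

time   <- c(2, 4, 6, 8, 10)
status <- c(1, 1, 0, 1, 1)
risk   <- c(0.9, 0.4, 0.7, 0.3, 0.1)
roc_5  <- cd_roc(time, status, risk, eval_time = 5)
print(roc_5$auc)
print(cbind(fpr = roc_5$fpr, tpr = roc_5$tpr))
```

The observation censored at time 6 still counts as a control at t = 5, because it is known to be event-free at that time.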
This function plots objects of class surv_shap (time-dependent explanations of survival models) created using the predict_parts(..., type = "survshap") function.
## S3 method for class 'surv_shap'
plot(x, ..., title = "SurvSHAP(t)", subtitle = "default", max_vars = 7,
  colors = NULL, rug = "all", rug_colors = c("#dd0000", "#222222"))
| Argument | Description |
|---|---|
| x | an object of class |
| ... | additional objects of class |
| title | character, title of the plot |
| subtitle | character, subtitle of the plot, |
| max_vars | maximum number of variables to be plotted (least important variables are ignored) |
| colors | character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue") |
| rug | character, one of |
| rug_colors | character vector containing two colors (containing either hex codes "#FF69B4", or names "blue"). The first color (red by default) will be used to mark event times, whereas the second (grey by default) will be used to mark censor times. |
An object of the class ggplot.
Other functions for plotting 'predict_parts_survival' objects: plot.predict_parts_survival(), plot.surv_lime()
library(survival)
library(survex)
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
p_parts_shap <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survshap")
plot(p_parts_shap)
This function decomposes the model prediction into individual parts, which are attributions of particular variables. The explanations can be made via the SurvLIME and SurvSHAP(t) methods.
predict_parts(explainer, ...)

## S3 method for class 'surv_explainer'
predict_parts(explainer, new_observation, ..., N = NULL, type = "survshap",
  output_type = "survival", explanation_label = NULL)
| Argument | Description |
|---|---|
| explainer | an explainer object - model preprocessed by the |
| ... | other parameters which are passed to |
| new_observation | a new observation for which the prediction needs to be explained |
| N | the number of observations used for calculation of attributions. If |
| type | if |
| output_type | either |
| explanation_label | a label that can overwrite explainer label (useful for multiple explanations for the same explainer/model) |
An object of class "predict_parts_survival" and additional classes depending on the type of explanations. It is a list with the element result containing the results of the calculation.
There are additional parameters that are passed to internal functions.
For survlime:
N
- a positive integer, number of observations generated in the neighbourhood
distance_metric
- character, name of the distance metric to be used, only "euclidean" is implemented
kernel_width
- a numeric or "silverman", parameter used for calculating weights, by default it's sqrt(ncol(data)*0.75). If "silverman", the kernel width is calculated using the method proposed by Silverman and used in the SurvLIMEpy Python package.
sampling_method
- character, name of the method of generating the neighbourhood, only "gaussian" is implemented
sample_around_instance
- logical, if the neighbourhood should be generated with the new observation as the center (default), or should the mean of the whole dataset be used as the center
max_iter
- a numeric, maximum number of iterations for the optimization problem
categorical_variables
- character vector, names of variables that should be treated as categories (factors are included by default)
k
- a small positive number > 1, added to chf before taking log, so that weights aren't negative
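The base-R sketch below shows how a sampled neighbourhood, the default kernel width and Euclidean distances could fit together when weighting sampled points. Only the kernel-width formula sqrt(ncol(data)*0.75) comes from the description above. The standardized scale, the Gaussian sampling around the instance and the exponential kernel sqrt(exp(-d^2 / w^2)), the form used by the original LIME implementation, are assumptions made for illustration, not survex internals.

```r
# Sketch of LIME-style neighbourhood weighting (assumptions, not survex code):
# variables are already standardized, neighbours are drawn from N(x0, I)
# around the explained observation, and weights use an exponential kernel.
set.seed(42)
p  <- 4                      # number of explanatory variables, i.e. ncol(data)
N  <- 100                    # neighbourhood size (the N parameter)
x0 <- c(0.5, -1, 0, 2)       # explained observation on the standardized scale

# "gaussian" sampling around the instance: add x0 to standard normal noise
neighbours <- sweep(matrix(rnorm(N * p), ncol = p), 2, x0, "+")

# Default kernel width from the documentation: sqrt(ncol(data) * 0.75)
kernel_width <- sqrt(p * 0.75)

# Euclidean distance from every neighbour to the explained observation
distances <- sqrt(rowSums(sweep(neighbours, 2, x0, "-")^2))

# Closer neighbours receive larger weights; all weights lie in (0, 1]
weights <- sqrt(exp(-distances^2 / kernel_width^2))
print(summary(weights))
```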
For survshap:
y_true
- a two-element numeric vector or a matrix of one row and two columns, the first element being the true observed time and the second the status of the observation, used for plotting
calculation_method
- a character, either "kernelshap" for use of the kernelshap library (providing faster Kernel SHAP with refinements) or "exact_kernel" for exact Kernel SHAP estimation
aggregation_method
- a character, one of "mean_absolute", "integral", "max_absolute" or "sum_of_squares"
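SurvSHAP(t) returns one attribution curve per variable over time, and aggregation_method reduces each curve to a single number so that variables can be ranked. The sketch below shows one plausible reading of the four options, applied to an invented attribution matrix. In particular, it approximates "integral" with the trapezoidal rule on absolute attributions and applies no normalization, which may not match survex exactly. The helper names are hypothetical.

```r
# Hypothetical SurvSHAP(t)-like attributions: rows are time points,
# columns are variables (values invented for illustration)
times <- c(10, 20, 40, 80)
shap_t <- cbind(
  karno = c(-0.02, -0.05, -0.08, -0.04),
  age   = c( 0.01,  0.01, -0.01,  0.00)
)

# Trapezoidal rule for the area under a curve sampled at x
trapezoid <- function(x, y) sum(diff(x) * (head(y, -1) + tail(y, -1)) / 2)

aggregate_attributions <- function(shap_t, times, method) {
  switch(method,
    mean_absolute  = apply(abs(shap_t), 2, mean),
    max_absolute   = apply(abs(shap_t), 2, max),
    sum_of_squares = apply(shap_t^2, 2, sum),
    integral       = apply(abs(shap_t), 2, function(col) trapezoid(times, col)),
    stop("unknown aggregation method: ", method)
  )
}

for (m in c("mean_absolute", "integral", "max_absolute", "sum_of_squares")) {
  print(m)
  print(aggregate_attributions(shap_t, times, m))
}
```

Every option ranks karno above age here. The options differ in how strongly they reward a short, sharp peak compared with a small but persistent effect.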
[1] Krzyziński, Mateusz, et al. "SurvSHAP(t): Time-dependent explanations of machine learning survival models." Knowledge-Based Systems 262 (2023): 110234.
[2] Kovalev, Maxim S., et al. "SurvLIME: A method for explaining machine learning survival models." Knowledge-Based Systems 203 (2020): 106164.
[3] Pachón-García, Cristian, et al. "SurvLIMEpy: A Python package implementing SurvLIME." Expert Systems with Applications 237 (2024): 121620.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)
cph_predict_parts_survshap <- predict_parts(cph_exp, new_observation = veteran[1, -c(3, 4)])
head(cph_predict_parts_survshap$result)
plot(cph_predict_parts_survshap)
cph_predict_parts_survlime <- predict_parts(
  cph_exp,
  new_observation = veteran[1, -c(3, 4)],
  type = "survlime"
)
head(cph_predict_parts_survlime$result)
plot(cph_predict_parts_survlime, type = "local_importance")
This function calculates Ceteris Paribus Profiles for a specific observation with the possibility to take the time dimension into account.
predict_profile(explainer, new_observation, variables = NULL,
  categorical_variables = NULL, ..., type = "ceteris_paribus",
  output_type = "survival", variable_splits_type = "uniform", center = FALSE)

## S3 method for class 'surv_explainer'
predict_profile(explainer, new_observation, variables = NULL,
  categorical_variables = NULL, ..., type = "ceteris_paribus",
  output_type = "survival", variable_splits_type = "uniform", center = FALSE)
| Argument | Description |
|---|---|
| explainer | an explainer object - model preprocessed by the |
| new_observation | a new observation for which the prediction needs to be explained |
| variables | a character vector containing names of variables to be explained |
| categorical_variables | a character vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the |
| ... | additional parameters passed to |
| type | character, only |
| output_type | either |
| variable_splits_type | character, decides how variable grids should be calculated. Use |
| center | logical, should profiles be centered around the average prediction |
An object of class c("predict_profile_survival", "surv_ceteris_paribus"). It is a list with the final result in the result element.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
cph_exp <- explain(cph)
rsf_src_exp <- explain(rsf_src)
cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)],
  variables = c("trt", "celltype", "karno", "age"),
  categorical_variables = "trt"
)
plot(cph_predict_profile, facet_ncol = 2)
rsf_predict_profile <- predict_profile(rsf_src_exp, veteran[5, -c(3, 4)], variables = "karno")
plot(rsf_predict_profile, numerical_plot_type = "contours")
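A ceteris paribus profile keeps every variable of the observation fixed except one, moves that variable over a grid, and records the prediction at each grid value. For survival models, every grid value yields a whole survival curve. The base-R sketch below uses a hypothetical toy model toy_survival(), simulated data and a hypothetical helper ceteris_paribus(). It contrasts an evenly spaced grid, which is what the "uniform" option name suggests, with a quantile grid shown only for comparison.

```r
# Hypothetical survival model: exponential with a linear predictor,
# S(t | x) = exp(-t * exp(lp) / 100); higher karno lowers the hazard
toy_survival <- function(newdata, times) {
  lp <- -0.03 * (newdata$karno - 60) + 0.02 * (newdata$age - 60)
  exp(-outer(exp(lp), times) / 100)
}

set.seed(7)
data <- data.frame(karno = round(runif(100, 10, 99)),
                   age   = round(runif(100, 35, 80)))
new_observation <- data.frame(karno = 60, age = 69)
times <- c(30, 90, 180)

# Two ways of choosing grid points for the profiled variable
uniform_grid  <- seq(min(data$karno), max(data$karno), length.out = 5)
quantile_grid <- unname(quantile(data$karno, probs = seq(0, 1, by = 0.25)))

# Repeat the observation once per grid value, change only `variable`,
# and predict a survival curve for every modified copy
ceteris_paribus <- function(new_observation, variable, grid, times) {
  profile_data <- new_observation[rep(1, length(grid)), , drop = FALSE]
  profile_data[[variable]] <- grid
  surv <- toy_survival(profile_data, times)
  dimnames(surv) <- list(paste0(variable, "=", signif(grid, 3)),
                         paste0("t=", times))
  surv
}

cp_uniform  <- ceteris_paribus(new_observation, "karno", uniform_grid, times)
cp_quantile <- ceteris_paribus(new_observation, "karno", quantile_grid, times)
print(round(cp_uniform, 3))
```

Each row is one survival curve evaluated at the chosen times. Reading down a column gives the profile of karno at that time.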
This function calculates model predictions in a unified way.
## S3 method for class 'surv_explainer'
predict(object, newdata = NULL, output_type = "survival", times = NULL, ...)
| Argument | Description |
|---|---|
| object | an explainer object - model preprocessed by the |
| newdata | data used for the prediction |
| output_type | character, either |
| times | a numeric vector of times for the survival and cumulative hazard function predictions to be evaluated at. If |
| ... | other arguments, currently ignored |
A vector or matrix containing the prediction.
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., data = veteran,
  respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)
cph_exp <- explain(cph)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)
predict(cph_exp, veteran[1, ], output_type = "survival")[, 1:10]
predict(cph_exp, veteran[1, ], output_type = "risk")
predict(rsf_ranger_exp, veteran[1, ], output_type = "chf")[, 1:10]
Some models do not come with a ready-to-use risk prediction. This function generates one based on the cumulative hazard function.
risk_from_chf(predict_cumulative_hazard_function, times)
| Argument | Description |
|---|---|
| predict_cumulative_hazard_function | a function of three arguments ( |
| times | a numeric vector of times at which the function should be evaluated. |
A function of two arguments (model, newdata) returning a vector of risks.
library(survex)
library(survival)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
chf_function <- transform_to_stepfunction(predict,
  type = "chf",
  prediction_element = "chf",
  times_element = "time.interest"
)
risk_function <- risk_from_chf(chf_function, unique(veteran$time))
explainer <- explain(rsf_src,
  predict_cumulative_hazard_function = chf_function,
  predict_function = risk_function
)
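One simple way to turn a cumulative hazard function into a risk score is to evaluate it on a fixed grid of times and sum the values. Observations with a uniformly higher cumulative hazard then receive a higher risk. The sketch below assumes that summation; it has not been checked here against the survex source. It uses a hypothetical exponential-model CHF predictor toy_chf() with the three-argument (model, newdata, times) form returned by transform_to_stepfunction(), and a hypothetical helper make_risk_from_chf().

```r
# Hypothetical CHF predictor with the (model, newdata, times) signature:
# an exponential model with H(t | x) = t * exp(beta * x)
toy_chf <- function(model, newdata, times) {
  rate <- exp(model$beta * newdata$x)
  outer(rate, times)  # rows: observations, columns: time points
}

# Assumed construction: risk = sum of the CHF values over a fixed time grid,
# returned as a (model, newdata) function like the one described above
make_risk_from_chf <- function(chf_function, times) {
  function(model, newdata) {
    rowSums(chf_function(model, newdata, times = times))
  }
}

toy_model <- list(beta = 0.5)
new_obs   <- data.frame(x = c(-1, 0, 2))
risk_fun  <- make_risk_from_chf(toy_chf, times = c(10, 20, 30))
risks     <- risk_fun(toy_model, new_obs)
print(risks)
```

For this toy model the risk is proportional to exp(beta * x), so the ranking of observations matches the ranking of their hazards.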
Default Theme for survex plots
set_theme_survex(default_theme = "drwhy", default_theme_vertical = default_theme)
theme_default_survex()
theme_vertical_default_survex()
| Argument | Description |
|---|---|
| default_theme | object - string ("drwhy" or "ema") or an object of ggplot theme class. Will be applied by default by |
| default_theme_vertical | object - string ("drwhy" or "ema") or an object of ggplot theme class. Will be applied by default by |
A list with the current default themes.
library(survival)
library(survex)
old <- set_theme_survex("ema")
model <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
exp <- explain(model)
p_parts_lime <- predict_parts(exp, veteran[1, -c(3, 4)], type = "survlime")
old <- set_theme_survex("drwhy")
plot(p_parts_lime)
old <- set_theme_survex(ggplot2::theme_void(), ggplot2::theme_void())
plot(p_parts_lime)
This generic function lets the user extract basic information about a model. The function returns a named list of class model_info that contains information about the package the model comes from, its version and the task type. For wrappers such as mlr or parsnip, both the package and the wrapper information are stored.
surv_model_info(model, ...)

## S3 method for class 'coxph'
surv_model_info(model, ...)

## S3 method for class 'rfsrc'
surv_model_info(model, ...)

## S3 method for class 'ranger'
surv_model_info(model, ...)

## S3 method for class 'model_fit'
surv_model_info(model, ...)

## S3 method for class 'cph'
surv_model_info(model, ...)

## S3 method for class 'LearnerSurv'
surv_model_info(model, ...)

## S3 method for class 'sksurv'
surv_model_info(model, ...)

## S3 method for class 'flexsurvreg'
surv_model_info(model, ...)

## Default S3 method:
surv_model_info(model, ...)
| Argument | Description |
|---|---|
| model | |
| ... | |
Currently supported packages are:
class coxph
- Cox proportional hazards regression model created with the survival package
class model_fit
- models created with the parsnip package
class ranger
- random survival forest models created with the ranger package
class rfsrc
- random forest models created with the randomForestSRC package
A named list of class model_info
library(survival)
library(survex)
cph <- survival::coxph(survival::Surv(time, status) ~ ., data = veteran,
  model = TRUE, x = TRUE, y = TRUE
)
surv_model_info(cph)
library(ranger)
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran,
  num.trees = 50, mtry = 3, max.depth = 5
)
surv_model_info(rsf_ranger)
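Because surv_model_info() is an S3 generic, R chooses the method to run from the model's class and falls back to the default method when no class-specific one exists. The toy sketch below mimics that dispatch pattern with a hypothetical generic toy_model_info() and fake model objects. The returned fields are invented for illustration and are not the fields of survex's model_info objects.

```r
# Hypothetical generic mirroring the dispatch pattern of surv_model_info()
toy_model_info <- function(model, ...) UseMethod("toy_model_info")

# Method selected for objects whose class includes "coxph"
toy_model_info.coxph <- function(model, ...) {
  structure(list(package = "survival", type = "survival"), class = "toy_info")
}

# Fallback used when no class-specific method exists
toy_model_info.default <- function(model, ...) {
  structure(list(package = "unknown", type = "survival"), class = "toy_info")
}

fake_cox   <- structure(list(), class = "coxph")
fake_other <- structure(list(), class = "my_custom_model")

str(unclass(toy_model_info(fake_cox)))
str(unclass(toy_model_info(fake_other)))
```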
Helper function to transform a survival function into a cumulative hazard function (CHF).
survival_to_cumulative_hazard(survival_functions, epsilon = 0)
| Argument | Description |
|---|---|
| survival_functions | matrix or vector, with each row representing a survival function |
| epsilon | a positive numeric number to add, so that the logarithm can be taken |
A matrix or vector transformed to the form of a cumulative hazard function.
library(survex)
vec <- c(1, 0.9, 0.8, 0.7, 0.6)
# each row of matr is one survival function
matr <- matrix(c(1, 0.9, 0.8, 1, 0.8, 0.6), ncol = 3, byrow = TRUE)
survival_to_cumulative_hazard(vec)
survival_to_cumulative_hazard(matr)
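The conversion rests on the identity H(t) = -log S(t). The sketch below re-implements it in base R, assuming that epsilon is simply added before the logarithm is taken, as the argument description suggests. surv_to_chf_sketch() is a hypothetical name, not a survex function.

```r
# Hypothetical re-implementation for illustration, not survex's code.
# Uses H(t) = -log(S(t)); epsilon is added first so that survival
# probabilities equal to 0 do not give infinite values. A positive
# epsilon also shifts every other entry slightly.
surv_to_chf_sketch <- function(survival_functions, epsilon = 0) {
  -log(survival_functions + epsilon)
}

vec <- c(1, 0.9, 0.8, 0.7, 0.6)
chf <- surv_to_chf_sketch(vec)
print(round(chf, 4))

# Back-transformation S(t) = exp(-H(t)) recovers the original curve
print(all.equal(exp(-chf), vec))

# Works element-wise on a matrix with one survival function per row
matr <- matrix(c(1, 0.9, 0.8, 1, 0.8, 0.6), ncol = 3, byrow = TRUE)
print(surv_to_chf_sketch(matr))

# A survival probability of 0 gives Inf unless epsilon > 0
print(surv_to_chf_sketch(c(1, 0.5, 0)))
print(surv_to_chf_sketch(c(1, 0.5, 0), epsilon = 1e-8))
```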
Some models return the survival function or cumulative hazard function prediction at the times of events present in the training data set. This is a convenient utility to allow the prediction to be evaluated at any time.
transform_to_stepfunction(predict_function, eval_times = NULL, ..., type = NULL,
  prediction_element = NULL, times_element = NULL)
| Argument | Description |
|---|---|
| predict_function | a function making the prediction based on |
| eval_times | a numeric vector of times, at which the fixed predictions are made. This can be |
| ... | other parameters passed to predict_function |
| type | the type of function to be returned, either |
| prediction_element | if |
| times_element | if |
The function returns a function with three arguments (model, newdata, times), ready to supply it to an explainer.
library(survex)
library(survival)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
chf_function <- transform_to_stepfunction(predict,
  type = "chf",
  prediction_element = "chf",
  times_element = "time.interest"
)
explainer <- explain(rsf_src, predict_cumulative_hazard_function = chf_function)
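The underlying idea can be sketched with base R's stats::stepfun(). A prediction known only at the training event times becomes a right-continuous step function that can be evaluated at any time. The times, the values and the helper eval_step_matrix() below are hypothetical. The sketch assumes a survival function equal to 1 before the first event time and does not reproduce survex internals.

```r
# Hypothetical predictions for one observation, available only at the
# event times seen in training (values invented for illustration)
fixed_times <- c(3, 7, 12, 20)
fixed_surv  <- c(0.95, 0.80, 0.62, 0.40)

# Right-continuous step function: 1 before the first event time, then the
# predicted value from each event time onwards (stepfun needs one more y than x)
surv_step <- stats::stepfun(fixed_times, c(1, fixed_surv), right = FALSE)
print(surv_step(c(0, 3, 5, 12, 25)))

# Same idea for a matrix of predictions, one row per observation; returns a
# matrix with one row per observation and one column per new time.
# Use start_value = 0 for cumulative hazard predictions.
eval_step_matrix <- function(pred_matrix, times, new_times, start_value = 1) {
  out <- vapply(
    seq_len(nrow(pred_matrix)),
    function(i) stats::stepfun(times, c(start_value, pred_matrix[i, ]))(new_times),
    numeric(length(new_times))
  )
  matrix(out, nrow = nrow(pred_matrix), byrow = TRUE)
}

surv_matrix <- rbind(fixed_surv, c(0.90, 0.70, 0.50, 0.30))
print(eval_step_matrix(surv_matrix, fixed_times, c(5, 20)))
```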