Here we present an example of classification. The `status` variable in the `HR` dataset is a factor that takes three values: `fired`, `ok` and `promoted`. We will create a model for this variable and then present explainers for this model.
```
##   gender      age    hours evaluation salary   status
## 1   male 32.58267 41.88626          3      1    fired
## 2 female 41.21104 36.34339          2      5    fired
## 3   male 37.70516 36.81718          3      0    fired
## 4 female 30.06051 38.96032          3      2    fired
## 5   male 21.10283 62.15464          5      3 promoted
## 6   male 40.11812 69.53973          2      0    fired
```
Here we create a random forest model for this dataset.

```r
library("DALEX")           # explain()
library("randomForest")
library("ceterisParibus")
set.seed(59)
model <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
```
By default, `predict.randomForest()` returns predicted classes, not scores. This is why we pass a custom `predict()` function that returns class probabilities (`type = "prob"`). Here we create three explainers, one for each class: `fired`, `ok` and `promoted`.
```r
# predict(type = "prob") returns one column per class, named after the class levels;
# selecting the column by name does not depend on the column order
pred1 <- function(m, x) predict(m, x, type = "prob")[, "fired"]
pred2 <- function(m, x) predict(m, x, type = "prob")[, "ok"]
pred3 <- function(m, x) predict(m, x, type = "prob")[, "promoted"]

explainer_rf_fired <- explain(model, data = HR[, 1:5],
                              y = HR$status == "fired",
                              predict_function = pred1, label = "fired")
explainer_rf_ok <- explain(model, data = HR[, 1:5],
                           y = HR$status == "ok",
                           predict_function = pred2, label = "ok")
explainer_rf_promoted <- explain(model, data = HR[, 1:5],
                                 y = HR$status == "promoted",
                                 predict_function = pred3, label = "promoted")
```
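The three predict functions above differ only in the class they extract. A minimal sketch of a hypothetical helper, `make_class_predict()` (not part of ceterisParibus or DALEX), builds such a function from a class name and fails loudly if that class is missing. So that the sketch runs without randomForest, it uses a made-up `toy_classifier` whose `predict()` method returns a fixed probability matrix with its columns deliberately out of alphabetical order.

```r
# Hypothetical helper: build a predict function that returns the probability
# of one class, selected by column name rather than by position.
make_class_predict <- function(class_name) {
  force(class_name)  # capture the value now, not lazily at the first call
  function(m, x) {
    probs <- predict(m, x, type = "prob")
    if (!class_name %in% colnames(probs)) {
      stop("class '", class_name, "' not found among: ",
           paste(colnames(probs), collapse = ", "))
    }
    probs[, class_name]
  }
}

# Toy stand-in for a fitted classifier (illustration only): every row gets
# the same probabilities, and the columns are NOT in alphabetical order.
toy_model <- structure(list(), class = "toy_classifier")
predict.toy_classifier <- function(object, newdata, type = "prob", ...) {
  n <- nrow(newdata)
  matrix(rep(c(0.2, 0.5, 0.3), each = n), nrow = n,
         dimnames = list(NULL, c("ok", "fired", "promoted")))
}

pred_fired <- make_class_predict("fired")
pred_fired(toy_model, data.frame(hours = c(35, 45)))  # returns c(0.5, 0.5)
```

With the random forest above, the same helper could be passed as `predict_function = make_class_predict("fired")` to `explain()`.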
The `ceterisParibus` package can plot many explainers in a single panel. This is useful for multiclass classification: simply create an explainer for each class and plot all these explainers together.
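Conceptually, plotting several explainers together amounts to computing one set of profiles per explainer, tagging each set with its label and stacking them. A minimal base-R sketch of that idea for one made-up observation, assuming a hypothetical three-class softmax model `toy_probs()` that depends only on `hours` (this is neither the HR model nor how ceterisParibus is implemented internally):

```r
# Hypothetical three-class model: linear class scores in hours, turned into
# probabilities with a softmax so that each row sums to one.
toy_probs <- function(x) {
  scores <- cbind(fired    = 2 - 0.05 * x$hours,
                  ok       = rep(0, nrow(x)),
                  promoted = -3 + 0.06 * x$hours)
  exp(scores) / rowSums(exp(scores))
}

# Ceteris paribus profile of one class: replicate the observation once per
# grid value, change only `variable`, and keep the predicted class probability.
cp_one_class <- function(obs, variable, grid, class_name) {
  newdata <- obs[rep(1, length(grid)), , drop = FALSE]
  newdata[[variable]] <- grid
  data.frame(variable = variable, value = grid,
             yhat = toy_probs(newdata)[, class_name],
             label = class_name)  # plays the role of the explainer label
}

obs  <- data.frame(gender = "male", age = 32.6, hours = 41.9)
grid <- seq(35, 65, by = 5)

# One profile per class, stacked into a single long data frame.
profiles <- do.call(rbind, lapply(c("fired", "ok", "promoted"),
                                  function(cl) cp_one_class(obs, "hours", grid, cl)))
```

A plotting library can then map `label` to color, which is roughly what `color = "_label_"` does in the plots of this vignette.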
The `explainer_rf_fired` explainer is focused on the `fired` class. Let's see Ceteris Paribus profiles for the first 10 individuals. They are colored by the `gender` variable, which is useful because the model contains an interaction between age and gender. Can you spot it?
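A ceteris paribus profile changes one variable of one observation while all other variables stay fixed; repeating this for several observations while keeping their remaining columns is what allows the profiles to be colored by a variable such as `gender`. A minimal base-R sketch, with a logistic regression on the built-in `mtcars` data standing in for the random forest (`am` as the binary target, profiles along `wt`, and the binary `vs` column playing the role of `gender`); the helper `cp_profiles()` is hypothetical:

```r
# Stand-in model: probability of a manual transmission (am = 1)
# given weight (wt, in 1000 lbs) and horsepower (hp).
cars_glm <- glm(am ~ wt + hp, data = mtcars, family = binomial)

# Hypothetical helper: ceteris paribus profiles of `variable` for every row of `obs`.
# Each row is replicated once per grid value and only `variable` is overwritten,
# so columns such as vs keep their original values and can be used for coloring.
cp_profiles <- function(model, obs, variable, grid) {
  rows <- lapply(seq_len(nrow(obs)), function(i) {
    newdata <- obs[rep(i, length(grid)), , drop = FALSE]
    newdata[[variable]] <- grid
    newdata$yhat <- predict(model, newdata, type = "response")
    newdata$id <- i
    newdata
  })
  do.call(rbind, rows)
}

grid <- seq(1.5, 5.5, by = 0.5)
profiles <- cp_profiles(cars_glm, mtcars[1:10, ], "wt", grid)
head(profiles[, c("id", "vs", "hp", "wt", "yhat")])
```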
We can plot explainers for two classes (fired and promoted) in the same plot. Here is an example of how to do this.
The `hours` variable is an interesting one. People with low average working hours are, on average, more likely to be fired rather than promoted. The `salary` variable also seems to be related to the probability of being fired, but not to the probability of being promoted.
Here we compare the explainer (probability of being fired) across the two genders. We know that there is an interaction between the `gender` and `age` variables, and this interaction is easy to read from the plot.
In the plot below we see that both genders behave similarly in all panels except age. For the `age` variable, it looks like younger women are more likely to be fired, while for men older age is a risk factor.
Is it because employers are afraid of maternity leave? Maybe, but please note that this dataset is an artificial/simulated one.
```r
# Profiles for the first 100 women and the first 100 men
cp_rfF <- ceteris_paribus(explainer_rf_fired,
                          HR[which(HR$gender == "female")[1:100], ])
cp_rfF$`_label_` <- "Fired Female"
cp_rfM <- ceteris_paribus(explainer_rf_fired,
                          HR[which(HR$gender == "male")[1:100], ])
cp_rfM$`_label_` <- "Fired Male"

# Average the profiles within each group and color them by the new label
plot(cp_rfM, cp_rfF,
     aggregate_profiles = mean,
     alpha = 1, show_observations = FALSE, color = "_label_")
```
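Averaging profiles within each group is what reveals the interaction. A minimal base-R sketch of that computation, assuming a made-up score `toy_score()` with a built-in age-by-gender interaction (for women it falls with age, for men it rises) plus a small additive effect of hours; the function and the `people` data are hypothetical and only mimic roughly what `aggregate_profiles = mean` displays, not the package's implementation:

```r
# Hypothetical score: age x gender interaction plus an additive hours effect.
toy_score <- function(x) {
  direction <- ifelse(x$gender == "female", -1, 1)
  plogis(0.1 * (x$age - 35) * direction + 0.05 * (x$hours - 40))
}

# Average ceteris paribus profile: set `variable` to each grid value for all
# observations at once, keep everything else fixed, and average the scores.
average_profile <- function(score_fn, obs, variable, grid) {
  sapply(grid, function(v) {
    newdata <- obs
    newdata[[variable]] <- v
    mean(score_fn(newdata))
  })
}

# Made-up observations: 50 women and 50 men with varying age and hours.
people <- data.frame(gender = rep(c("female", "male"), each = 50),
                     age    = rep(seq(22, 58, length.out = 50), times = 2),
                     hours  = rep(seq(30, 50, length.out = 50), times = 2))
grid <- seq(20, 60, by = 10)

female_avg <- average_profile(toy_score, people[people$gender == "female", ], "age", grid)
male_avg   <- average_profile(toy_score, people[people$gender == "male", ], "age", grid)
round(rbind(female = female_avg, male = male_avg), 3)
```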
Here we have examples for three classification models: random forest, support vector machines and a generalized linear model. We focus only on the `fired` class.
Let's start with examples for a single observation. We need to train the models first. Note that predictions are calculated in a different way for different model types, thus we need to prepare model-specific predict functions.
Here is the explainer for the random forest.
```r
library("ceterisParibus")
library("randomForest")
library("e1071")

model_rf <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
pred_rf_fired <- function(m, x) predict(m, x, type = "prob")[, "fired"]
# build the explainer on model_rf, the random forest trained in this section
explainer_rf_fired <- explain(model_rf, data = HR[, 1:5],
                              y = HR$status == "fired",
                              predict_function = pred_rf_fired)
```
Here is the explainer for the generalized linear model.
```r
model_fired <- glm(status == "fired" ~ gender + age + hours + evaluation + salary,
                   data = HR, family = "binomial")
# type = "response" returns probabilities instead of log-odds
pred_glm_fired <- function(m, x) predict(m, x, type = "response")
explainer_glm_fired <- explain(model_fired, data = HR[, 1:5],
                               y = HR$status == "fired",
                               predict_function = pred_glm_fired)
```
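The `type = "response"` argument in `pred_glm_fired` matters: by default, `predict()` on a fitted `glm` returns values on the link scale (log-odds for the binomial family), not probabilities. A small check on the built-in `mtcars` data (`am` as a binary target, chosen only for illustration) shows that the two scales are related by the inverse logit, `plogis()`:

```r
# Illustrative logistic regression: am = 1 means manual transmission.
cars_glm <- glm(am ~ wt + hp, data = mtcars, family = binomial)
newdata  <- mtcars[1:3, ]

log_odds <- predict(cars_glm, newdata)                     # default type = "link"
prob     <- predict(cars_glm, newdata, type = "response")  # probabilities in (0, 1)

# The inverse logit maps the link scale back to probabilities.
all.equal(unname(plogis(log_odds)), unname(prob))  # TRUE
```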
And here is one for support vector machines.
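```r
model_svm_fired <- svm(status ~ gender + age + hours + evaluation + salary,
                       data = HR, probability = TRUE)
# the "probabilities" attribute has one column per class; select "fired" by name,
# since the column order may differ from the order of the factor levels
pred_svm_fired <- function(m, x)
  attr(predict(m, x, probability = TRUE), "probabilities")[, "fired"]
explainer_svm_fired <- explain(model_svm_fired, data = HR[, 1:5],
                               y = HR$status == "fired",
                               predict_function = pred_svm_fired)
```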
Having explainers, we can now create Ceteris Paribus profiles for a selected single observation.
```r
new_obs <- HR[1, ]
cp_rf  <- ceteris_paribus(explainer_rf_fired, new_obs)
cp_glm <- ceteris_paribus(explainer_glm_fired, new_obs)
cp_svm <- ceteris_paribus(explainer_svm_fired, new_obs)
```
And we can plot profiles for this observation. Note that both SVM and RF captured the interaction between gender and age. They also captured the nonlinear relation for salary.
Now we can calculate average profiles for global explanations. Here the average is calculated based on 100 observations. Note that since age and gender interact, the averaged profile shows almost no relation between the score and age: the opposite effects of age for women and for men cancel out.
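Why can the averaged age profile be flat even though the individual profiles are not? A minimal sketch with a deliberately symmetric, made-up score (`toy_score()` and the `people` data are hypothetical): women's scores fall with age exactly as fast as men's rise, so with equally many women and men the pooled average is the same at every age. In the HR model the cancellation is only approximate.

```r
# Hypothetical score with a pure age x gender interaction, symmetric around age 35.
toy_score <- function(x) {
  direction <- ifelse(x$gender == "female", -1, 1)
  plogis(0.1 * (x$age - 35) * direction)
}

# Average ceteris paribus profile over all rows of `obs` for `variable`.
average_profile <- function(score_fn, obs, variable, grid) {
  sapply(grid, function(v) {
    newdata <- obs
    newdata[[variable]] <- v
    mean(score_fn(newdata))
  })
}

# Made-up sample with equally many women and men.
people <- data.frame(gender = rep(c("female", "male"), each = 50), age = 35)
grid   <- seq(20, 60, by = 10)

pooled_avg <- average_profile(toy_score, people, "age", grid)
female_avg <- average_profile(toy_score, people[people$gender == "female", ], "age", grid)
round(pooled_avg, 3)  # 0.5 at every age, because plogis(z) + plogis(-z) = 1
```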
Images and code used in the cheatsheet
```r
# format probabilities as percentage labels for the y axis
percent <- function(x)
  paste0(format(round(x, 2) * 100, big.mark = ",", scientific = FALSE, trim = TRUE), "%")
```
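The `percent()` helper rounds a probability to two decimals and prints it as a whole percentage; `scale_y_continuous(labels = percent)` uses it for the axis labels. A quick check of its behavior:

```r
percent <- function(x)
  paste0(format(round(x, 2) * 100, big.mark = ",", scientific = FALSE, trim = TRUE), "%")

percent(c(0, 0.25, 0.5, 1))  # "0%" "25%" "50%" "100%"
percent(0.123)               # "12%": rounded to two decimals first
```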
```r
library("ggplot2")  # xlab(), scale_y_continuous(), theme_light(), scale_color_discrete()

# Profile of Pr(promoted) along hours for the first employee
data_1 <- HR[1, ]
cp_rf3 <- ceteris_paribus(explainer_rf_promoted, data_1)
plot(cp_rf3, alpha = 0.5, size_points = 4, selected_variables = "hours", as.gg = TRUE) +
  xlab("") + scale_y_continuous(limits = c(0, 1), name = "Pr(promoted)", labels = percent) +
  theme_light()

# Profiles of all three classes for the same employee in a single panel
cp_rf1 <- ceteris_paribus(explainer_rf_fired, HR[1, ])
cp_rf1$`_label_` <- "fired"  # explainer_rf_fired was redefined above without a class label
cp_rf2 <- ceteris_paribus(explainer_rf_ok, HR[1, ])
cp_rf3 <- ceteris_paribus(explainer_rf_promoted, HR[1, ])
plot(cp_rf1, cp_rf2, cp_rf3, alpha = 0.5, color = "_label_", size_points = 4,
     selected_variables = "hours", as.gg = TRUE) +
  xlab("") + scale_y_continuous(limits = c(0, 1), name = "probability", labels = percent) +
  theme_light() + scale_color_discrete(name = "class")
```
```
## R version 3.5.0 (2018-04-23)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS Sierra 10.12.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] e1071_1.7-0 ceterisParibus_0.3.1 gower_0.1.2
## [4] ggplot2_3.1.0 randomForest_4.6-14 DALEX_0.2.5
##
## loaded via a namespace (and not attached):
## [1] splines_3.5.0 gtools_3.5.0 shiny_1.2.0
## [4] assertthat_0.2.0 expm_0.999-2 sp_1.2-7
## [7] highr_0.6 pdp_0.7.0 yaml_2.2.0
## [10] LearnBayes_2.15.1 pillar_1.3.0 backports_1.1.2
## [13] lattice_0.20-35 glue_1.3.0 digest_0.6.18
## [16] promises_1.0.1 colorspace_1.3-2 htmltools_0.3.6
## [19] httpuv_1.4.5 Matrix_1.2-14 plyr_1.8.4
## [22] klaR_0.6-14 pkgconfig_2.0.2 breakDown_0.2.0
## [25] questionr_0.6.2 gmodels_2.16.2 purrr_0.2.5
## [28] xtable_1.8-3 mvtnorm_1.0-7 scales_1.0.0
## [31] gdata_2.18.0 later_0.7.5 proxy_0.4-22
## [34] tibble_1.4.2 combinat_0.0-8 ggpubr_0.1.8
## [37] withr_2.1.2 ALEPlot_1.1 agricolae_1.2-8
## [40] lazyeval_0.2.1 survival_2.41-3 magrittr_1.5
## [43] crayon_1.3.4 mime_0.6 deldir_0.1-15
## [46] memoise_1.1.0 evaluate_0.10.1 fs_1.2.2
## [49] nlme_3.1-137 MASS_7.3-49 class_7.3-14
## [52] xml2_1.2.0 tools_3.5.0 stringr_1.3.1
## [55] munsell_0.5.0 cluster_2.0.7-1 bindrcpp_0.2.2
## [58] compiler_3.5.0 pkgdown_1.0.0 rlang_0.3.0.1
## [61] grid_3.5.0 rstudioapi_0.7 miniUI_0.1.1.1
## [64] labeling_0.3 rmarkdown_1.10 boot_1.3-20
## [67] gtable_0.2.0 roxygen2_6.1.1 reshape2_1.4.3
## [70] R6_2.3.0 AlgDesign_1.1-7.3 gridExtra_2.3
## [73] yaImpute_1.0-29 knitr_1.20 dplyr_0.7.8
## [76] bindr_0.1.1 commonmark_1.5 factorMerger_0.3.6
## [79] rprojroot_1.3-2 spdep_0.7-7 desc_1.2.0
## [82] stringi_1.2.4 Rcpp_1.0.0 spData_0.2.8.3
## [85] tidyselect_0.2.5 coda_0.19-1
```