Introduction

Here we present an example for a classification model. The status variable in the HR dataset is a factor with three levels: fired, ok and promoted. We first fit a model for this variable and then build explainers for that model.
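The HR data frame ships with the DALEX package, which is attached in the session info at the end. Here is a minimal sketch of loading it and checking the levels of status, assuming DALEX is installed. The `head(HR)` call produces a preview like the one below.

```r
# Sketch: HR is distributed with DALEX (assumed to be installed)
library("DALEX")

# status is the target: a factor with three levels
levels(HR$status)
head(HR)
```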

##   gender      age    hours evaluation salary   status
## 1   male 32.58267 41.88626          3      1    fired
## 2 female 41.21104 36.34339          2      5    fired
## 3   male 37.70516 36.81718          3      0    fired
## 4 female 30.06051 38.96032          3      2    fired
## 5   male 21.10283 62.15464          5      3 promoted
## 6   male 40.11812 69.53973          2      0    fired

Here we create a random forest model for this dataset.

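library("DALEX")            # the HR dataset comes from DALEX
library("randomForest")
library("ceterisParibus")
set.seed(59)

# Random forest classifier for the three-level status factor
model <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)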

By default, predict.randomForest() returns predicted classes rather than class probabilities. This is why we use custom predict functions. Here we use two explainers: one explains the fired class, while the second takes care of the promoted class.
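Here is a sketch of the two explainers described above. It assumes DALEX's `explain(model, data, y, predict_function, label)` interface and that `predict.randomForest(..., type = "prob")` returns a matrix with one probability column per level of status. The column is selected by name rather than by position so that a change in level order cannot silently swap classes. The helper names `features`, `pred_fired` and `pred_promoted`, and the 0/1 coding of `y`, are our own choices.

```r
# Sketch; features, pred_fired and pred_promoted are hypothetical helper names
library("DALEX")          # explain() and the HR data
library("randomForest")
set.seed(59)

model <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
features <- c("gender", "age", "hours", "evaluation", "salary")

# With type = "prob" the forest returns class probabilities (one column per
# level of status); each wrapper keeps the column of a single class.
pred_fired    <- function(m, x) predict(m, x, type = "prob")[, "fired"]
pred_promoted <- function(m, x) predict(m, x, type = "prob")[, "promoted"]

# One explainer per class of interest; y is a 0/1 indicator of that class
explainer_rf_fired <- explain(model, data = HR[, features],
                              y = as.numeric(HR$status == "fired"),
                              predict_function = pred_fired, label = "fired")
explainer_rf_promoted <- explain(model, data = HR[, features],
                                 y = as.numeric(HR$status == "promoted"),
                                 predict_function = pred_promoted,
                                 label = "promoted")
```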

Multiclass

The ceterisParibus package can plot many explainers in a single panel. This is useful for multiclass classification: you can simply create an explainer for each class and plot all of them together.
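One way to build an explainer for every class without copy-pasting is to loop over the levels of status. This sketch assumes that `plot()` on ceterisParibus profiles accepts additional profile objects through `...`, and that each profile stores its explainer label in a `_label_` column that `color` can refer to. The names `features`, `explainers` and `profiles` are hypothetical.

```r
# Sketch; features, explainers and profiles are hypothetical names
library("DALEX")
library("randomForest")
library("ceterisParibus")
set.seed(59)

model <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
features <- c("gender", "age", "hours", "evaluation", "salary")

# One explainer per level of status (fired, ok, promoted); each predict
# function extracts the probability column of its own class.
explainers <- lapply(levels(HR$status), function(cls) {
  explain(model, data = HR[, features],
          y = as.numeric(HR$status == cls),
          predict_function = function(m, x) predict(m, x, type = "prob")[, cls],
          label = cls)
})

# Ceteris Paribus profiles of one employee, computed once per class
profiles <- lapply(explainers, ceteris_paribus, observations = HR[1, features])

# Pass all profile objects to plot(); assumes `_label_` holds the class name
do.call(plot, c(profiles, list(color = "_label_")))
```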

Individual curves

The explainer_rf_fired explainer focuses on the class fired. Let’s look at Ceteris Paribus profiles for the first 10 individuals.
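Here is a sketch of that step. It assumes `ceteris_paribus(explainer, observations)` and that the `color` argument of `plot()` accepts the name of a variable. The explainer is rebuilt so the snippet runs on its own, and `features` is a hypothetical helper name.

```r
# Sketch; features is a hypothetical helper name
library("DALEX")
library("randomForest")
library("ceterisParibus")
set.seed(59)

model <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
features <- c("gender", "age", "hours", "evaluation", "salary")
explainer_rf_fired <- explain(model, data = HR[, features],
                              y = as.numeric(HR$status == "fired"),
                              predict_function = function(m, x)
                                predict(m, x, type = "prob")[, "fired"],
                              label = "fired")

# Profiles for the first 10 employees: each variable is varied over a grid
# while the other variables keep that employee's own values.
cp_rf_fired <- ceteris_paribus(explainer_rf_fired, HR[1:10, features])

# Color each profile by the gender of the employee it belongs to (assumed API)
plot(cp_rf_fired, color = "gender")
```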

The profiles are colored by the gender variable. This is useful because the model contains an interaction between age and gender. Can you spot it?
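The computation behind a Ceteris Paribus profile can be sketched in a few lines of base R. You copy one observation once per grid value, overwrite a single variable, and predict. The `toy_fired_score` function below is a hypothetical stand-in for the fitted random forest, not the model itself. It is built with an age-by-gender interaction like the one mentioned above, so the two genders get age profiles with opposite slopes. `cp_profile` is also our own helper, not a ceterisParibus function.

```r
# Sketch: cp_profile and toy_fired_score are hypothetical helpers
cp_profile <- function(predict_fun, obs, variable, grid) {
  stopifnot(nrow(obs) == 1)
  # one copy of the observation per grid value
  newdata <- obs[rep(1, length(grid)), , drop = FALSE]
  # change only the chosen variable; every other variable stays fixed
  newdata[[variable]] <- grid
  data.frame(value = grid, yhat = predict_fun(newdata))
}

# Hypothetical score with an age x gender interaction: risk falls with age
# for women and rises with age for men
toy_fired_score <- function(x) {
  slope <- ifelse(x$gender == "female", -0.1, 0.1)
  plogis(slope * (x$age - 40))
}

age_grid <- seq(20, 60, by = 10)
female <- data.frame(gender = "female", age = 30, hours = 40)
male   <- data.frame(gender = "male",   age = 30, hours = 40)

cp_female <- cp_profile(toy_fired_score, female, "age", age_grid)
cp_male   <- cp_profile(toy_fired_score, male,   "age", age_grid)
cp_female
cp_male
```

The two observations differ only in gender, yet their age profiles move in opposite directions. An interaction shows up in a Ceteris Paribus plot exactly this way.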

Partial Dependence Plots

We can plot profiles from the explainers for two classes (fired and promoted) in the same plot. Here is an example of how to do this.
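Here is a sketch of such a plot. It assumes that the plot method's `aggregate_profiles` argument (here `mean`) averages the individual profiles into one partial dependence curve per explainer, and that `show_observations` and `alpha` behave as their names suggest. `make_explainer` is a hypothetical helper. The 100 employees are simply the first rows of HR.

```r
# Sketch; features and make_explainer are hypothetical names
library("DALEX")
library("randomForest")
library("ceterisParibus")
set.seed(59)

model <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
features <- c("gender", "age", "hours", "evaluation", "salary")

make_explainer <- function(cls) {
  explain(model, data = HR[, features],
          y = as.numeric(HR$status == cls),
          predict_function = function(m, x) predict(m, x, type = "prob")[, cls],
          label = cls)
}
explainer_rf_fired    <- make_explainer("fired")
explainer_rf_promoted <- make_explainer("promoted")

# Profiles for the same 100 employees under both explainers
cp_rf_fired    <- ceteris_paribus(explainer_rf_fired,    HR[1:100, features])
cp_rf_promoted <- ceteris_paribus(explainer_rf_promoted, HR[1:100, features])

# Averaging the profiles gives a partial dependence curve per class
# (assumed behaviour of aggregate_profiles); observations are hidden.
plot(cp_rf_fired, cp_rf_promoted,
     aggregate_profiles = mean, show_observations = FALSE,
     alpha = 1, color = "_label_")
```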

The hours variable is an interesting one: people with low average working hours are, on average, more likely to be fired than promoted. Salary also seems to be related to the probability of being fired, but not to the probability of being promoted.

Partial Dependence Plots for groups of observations

Here we compare the explainer for the probability of being fired across the two genders. We know that there is an interaction between the gender and age variables, and it is easy to read this interaction from the plot.
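Here is a sketch of the grouped plot. We assume that when `color` names a variable, `aggregate_profiles = mean` averages the profiles within each color group, giving one curve per gender. Please check this against the documentation of your ceterisParibus version. `features` is a hypothetical helper name.

```r
# Sketch; features is a hypothetical helper name
library("DALEX")
library("randomForest")
library("ceterisParibus")
set.seed(59)

model <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
features <- c("gender", "age", "hours", "evaluation", "salary")
explainer_rf_fired <- explain(model, data = HR[, features],
                              y = as.numeric(HR$status == "fired"),
                              predict_function = function(m, x)
                                predict(m, x, type = "prob")[, "fired"],
                              label = "fired")

cp_rf_fired <- ceteris_paribus(explainer_rf_fired, HR[1:100, features])

# Assumed: aggregation is done separately for each value of the color variable
plot(cp_rf_fired, aggregate_profiles = mean, show_observations = FALSE,
     alpha = 1, color = "gender")
```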

In the plot below, both genders behave similarly in all panels except age. For the age variable, younger women appear more likely to be fired, while for men older age is a risk factor.
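A grouped curve like this is a partial dependence profile computed separately within each group. For each grid value, you set the variable to that value in every row of the group, predict, and average. Below is a base-R sketch on a small hypothetical sample. It uses a hypothetical `toy_fired_score` with an age-by-gender interaction as a stand-in for the random forest. `pd_profile` is our own helper.

```r
# Sketch: pd_profile, toy_fired_score and sample_hr are hypothetical
toy_fired_score <- function(x) {
  slope <- ifelse(x$gender == "female", -0.1, 0.1)
  plogis(slope * (x$age - 40))
}

# Partial dependence: set `variable` to each grid value for all rows of
# `data`, predict, and average the predictions.
pd_profile <- function(predict_fun, data, variable, grid) {
  yhat <- vapply(grid, function(v) {
    data[[variable]] <- v          # local copy; other variables untouched
    mean(predict_fun(data))
  }, numeric(1))
  data.frame(value = grid, yhat = yhat)
}

sample_hr <- data.frame(gender = rep(c("female", "male"), each = 3),
                        age    = c(25, 35, 50, 28, 41, 55),
                        hours  = c(35, 42, 50, 38, 45, 60))
age_grid <- seq(20, 60, by = 10)

# One averaged profile per gender
by_gender <- lapply(split(sample_hr, sample_hr$gender), pd_profile,
                    predict_fun = toy_fired_score,
                    variable = "age", grid = age_grid)
by_gender$female
by_gender$male
```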

Is it because employers are afraid of maternity leave? Maybe, but please note that this dataset is artificial (simulated).

More models

Here we have examples for three classification models: random forest, support vector machine and generalized linear model. We focus only on the ‘fired’ class.

Let’s start with examples for a single observation. We need to train the models first. Note that predictions are calculated differently for different model functions, so we need to prepare model-specific predict functions.

Here is the one for the random forest.
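Here is a sketch of the random forest version. It assumes `predict(..., type = "prob")` from randomForest and DALEX's `explain()`. The names `model_rf`, `features` and `pred_rf_fired` are hypothetical. `explainer_rf_fired` matches the name used in the code further down.

```r
# Sketch; model_rf, features and pred_rf_fired are hypothetical names
library("DALEX")
library("randomForest")
set.seed(59)

features <- c("gender", "age", "hours", "evaluation", "salary")
model_rf <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)

# Random forest: request class probabilities and keep the "fired" column
pred_rf_fired <- function(m, x) predict(m, x, type = "prob")[, "fired"]

explainer_rf_fired <- explain(model_rf, data = HR[, features],
                              y = as.numeric(HR$status == "fired"),
                              predict_function = pred_rf_fired, label = "rf")
```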

Here is the one for the generalized linear model.
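For the GLM, we assume a logistic regression of the binary event "fired" against the other two classes. `predict.glm()` with `type = "response"` returns probabilities on the 0 to 1 scale instead of log-odds. The names `model_glm`, `features` and `pred_glm_fired` are hypothetical.

```r
# Sketch; model_glm, features and pred_glm_fired are hypothetical names
library("DALEX")

features <- c("gender", "age", "hours", "evaluation", "salary")

# Logistic regression for "fired" versus the other two classes
# (no interaction terms, so it cannot capture the age x gender effect)
model_glm <- glm(status == "fired" ~ gender + age + hours + evaluation + salary,
                 data = HR, family = "binomial")

# type = "response" gives probabilities rather than the linear predictor
pred_glm_fired <- function(m, x) predict(m, x, type = "response")

explainer_glm_fired <- explain(model_glm, data = HR[, features],
                               y = as.numeric(HR$status == "fired"),
                               predict_function = pred_glm_fired, label = "glm")
```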

And here is the one for the support vector machine.
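With `svm()` from e1071, class probabilities are only available if the model was trained with `probability = TRUE`. `predict()` then returns class labels and attaches the probabilities as a `"probabilities"` attribute, a matrix with one column per class. This sketch fits a three-class SVM and extracts the "fired" column. The names `model_svm`, `features` and `pred_svm_fired` are hypothetical.

```r
# Sketch; model_svm, features and pred_svm_fired are hypothetical names
library("DALEX")
library("e1071")

features <- c("gender", "age", "hours", "evaluation", "salary")

# probability = TRUE at training time enables probability estimates later
model_svm <- svm(status ~ gender + age + hours + evaluation + salary,
                 data = HR, probability = TRUE)

# predict() returns labels; the class probabilities sit in an attribute
pred_svm_fired <- function(m, x) {
  attr(predict(m, x, probability = TRUE), "probabilities")[, "fired"]
}

explainer_svm_fired <- explain(model_svm, data = HR[, features],
                               y = as.numeric(HR$status == "fired"),
                               predict_function = pred_svm_fired, label = "svm")
```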

With the explainers in place, we can now create Ceteris Paribus profiles for a single selected observation.

new_obs <- HR[1,]

cp_rf <- ceteris_paribus(explainer_rf_fired, new_obs)
cp_glm <- ceteris_paribus(explainer_glm_fired, new_obs)
cp_svm <- ceteris_paribus(explainer_svm_fired, new_obs)

Now we can plot the profiles for this observation. Note that both SVM and RF captured the interaction between gender and age, as well as the nonlinear relation for salary.
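Here is a self-contained sketch that rebuilds the three explainers and draws their profiles for one observation in a single plot. It assumes that `plot()` accepts several profile objects and that `color = "_label_"` colors them by explainer label. The labels "rf", "glm" and "svm", and the helper names `features` and `y_fired`, are our own.

```r
# Sketch; features, y_fired and the labels rf/glm/svm are hypothetical
library("DALEX")
library("randomForest")
library("e1071")
library("ceterisParibus")
set.seed(59)

features <- c("gender", "age", "hours", "evaluation", "salary")
y_fired  <- as.numeric(HR$status == "fired")

model_rf  <- randomForest(status ~ gender + age + hours + evaluation + salary, data = HR)
model_glm <- glm(status == "fired" ~ gender + age + hours + evaluation + salary,
                 data = HR, family = "binomial")
model_svm <- svm(status ~ gender + age + hours + evaluation + salary,
                 data = HR, probability = TRUE)

# Model-specific predict functions, all returning P(fired)
explainer_rf_fired <- explain(model_rf, data = HR[, features], y = y_fired,
  predict_function = function(m, x) predict(m, x, type = "prob")[, "fired"],
  label = "rf")
explainer_glm_fired <- explain(model_glm, data = HR[, features], y = y_fired,
  predict_function = function(m, x) predict(m, x, type = "response"),
  label = "glm")
explainer_svm_fired <- explain(model_svm, data = HR[, features], y = y_fired,
  predict_function = function(m, x)
    attr(predict(m, x, probability = TRUE), "probabilities")[, "fired"],
  label = "svm")

new_obs <- HR[1, features]
cp_rf  <- ceteris_paribus(explainer_rf_fired,  new_obs)
cp_glm <- ceteris_paribus(explainer_glm_fired, new_obs)
cp_svm <- ceteris_paribus(explainer_svm_fired, new_obs)

# One panel per variable, curves colored by model (assumed `_label_` column)
plot(cp_rf, cp_glm, cp_svm, color = "_label_")
```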

We can also calculate average profiles for global explanations. Here the average is calculated over 100 observations. Note that because age and gender interact, on average there is no relation between the score and age.
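Here is a base-R sketch of why the interaction can vanish on average. It uses a hypothetical `toy_fired_score`, a stand-in for a fitted model, whose age effect is decreasing for women and increasing for men with mirror-image slopes. On a balanced hypothetical sample, the two effects cancel exactly, so the averaged age profile is flat at 0.5 even though each individual profile is steep. This holds because `plogis(z) + plogis(-z) = 1`. `pd_profile` and `sample_hr` are our own helpers.

```r
# Sketch: toy_fired_score, pd_profile and sample_hr are hypothetical
toy_fired_score <- function(x) {
  slope <- ifelse(x$gender == "female", -0.1, 0.1)
  plogis(slope * (x$age - 40))
}

# Average of Ceteris Paribus profiles over all rows of `data`
pd_profile <- function(predict_fun, data, variable, grid) {
  yhat <- vapply(grid, function(v) {
    data[[variable]] <- v
    mean(predict_fun(data))
  }, numeric(1))
  data.frame(value = grid, yhat = yhat)
}

# Balanced sample: the same ages and hours for both genders
sample_hr <- data.frame(gender = rep(c("female", "male"), each = 4),
                        age    = rep(c(25, 35, 45, 55), times = 2),
                        hours  = rep(c(35, 40, 45, 50), times = 2))

pd_all <- pd_profile(toy_fired_score, sample_hr, "age", seq(20, 60, by = 10))
pd_all   # every yhat is 0.5 (up to floating-point rounding)
```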

Session info

## R version 3.5.0 (2018-04-23)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS Sierra 10.12.6
## 
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] e1071_1.7-0          ceterisParibus_0.3.1 gower_0.1.2         
## [4] ggplot2_3.1.0        randomForest_4.6-14  DALEX_0.2.5         
## 
## loaded via a namespace (and not attached):
##  [1] splines_3.5.0      gtools_3.5.0       shiny_1.2.0       
##  [4] assertthat_0.2.0   expm_0.999-2       sp_1.2-7          
##  [7] highr_0.6          pdp_0.7.0          yaml_2.2.0        
## [10] LearnBayes_2.15.1  pillar_1.3.0       backports_1.1.2   
## [13] lattice_0.20-35    glue_1.3.0         digest_0.6.18     
## [16] promises_1.0.1     colorspace_1.3-2   htmltools_0.3.6   
## [19] httpuv_1.4.5       Matrix_1.2-14      plyr_1.8.4        
## [22] klaR_0.6-14        pkgconfig_2.0.2    breakDown_0.2.0   
## [25] questionr_0.6.2    gmodels_2.16.2     purrr_0.2.5       
## [28] xtable_1.8-3       mvtnorm_1.0-7      scales_1.0.0      
## [31] gdata_2.18.0       later_0.7.5        proxy_0.4-22      
## [34] tibble_1.4.2       combinat_0.0-8     ggpubr_0.1.8      
## [37] withr_2.1.2        ALEPlot_1.1        agricolae_1.2-8   
## [40] lazyeval_0.2.1     survival_2.41-3    magrittr_1.5      
## [43] crayon_1.3.4       mime_0.6           deldir_0.1-15     
## [46] memoise_1.1.0      evaluate_0.10.1    fs_1.2.2          
## [49] nlme_3.1-137       MASS_7.3-49        class_7.3-14      
## [52] xml2_1.2.0         tools_3.5.0        stringr_1.3.1     
## [55] munsell_0.5.0      cluster_2.0.7-1    bindrcpp_0.2.2    
## [58] compiler_3.5.0     pkgdown_1.0.0      rlang_0.3.0.1     
## [61] grid_3.5.0         rstudioapi_0.7     miniUI_0.1.1.1    
## [64] labeling_0.3       rmarkdown_1.10     boot_1.3-20       
## [67] gtable_0.2.0       roxygen2_6.1.1     reshape2_1.4.3    
## [70] R6_2.3.0           AlgDesign_1.1-7.3  gridExtra_2.3     
## [73] yaImpute_1.0-29    knitr_1.20         dplyr_0.7.8       
## [76] bindr_0.1.1        commonmark_1.5     factorMerger_0.3.6
## [79] rprojroot_1.3-2    spdep_0.7-7        desc_1.2.0        
## [82] stringi_1.2.4      Rcpp_1.0.0         spData_0.2.8.3    
## [85] tidyselect_0.2.5   coda_0.19-1