This function calculates ceteris paribus profiles, i.e., series of model predictions for observations in which a single variable is altered while all other variables are held fixed.
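The idea can be illustrated with a minimal base-R sketch. The helper cp_profiles_sketch is hypothetical and is not the package's implementation. It uses lm() on the built-in mtcars data so that it runs without DALEX. For every variable in the splits and every observation, the observation is repeated once per split point, only that variable is changed, and the model is asked for predictions. The extra columns mimic the _yhat_, _vname_ and _ids_ columns seen in the examples.

```r
# Hypothetical sketch of ceteris paribus profiles (not the package code).
cp_profiles_sketch <- function(data, variable_splits, model,
                               predict_function = predict, ...) {
  pieces <- list()
  for (vname in names(variable_splits)) {
    split_points <- variable_splits[[vname]]
    for (i in seq_len(nrow(data))) {
      # Repeat observation i once per split point ...
      rows <- data[rep(i, length(split_points)), , drop = FALSE]
      # ... and alter only the chosen variable; all others stay fixed.
      rows[[vname]] <- split_points
      rows$`_yhat_` <- predict_function(model, rows, ...)
      rows$`_vname_` <- vname
      rows$`_ids_` <- rownames(data)[i]
      pieces[[length(pieces) + 1]] <- rows
    }
  }
  do.call(rbind, pieces)
}

# Usage with a built-in dataset and a linear model
fit <- lm(mpg ~ wt + hp, data = mtcars)
splits <- list(wt = c(2, 3, 4), hp = c(100, 200))
prof <- cp_profiles_sketch(mtcars[1:2, ], splits, fit)
head(prof[, c("wt", "hp", "_yhat_", "_vname_", "_ids_")])
```

Two observations, three wt points and two hp points give 2 * 3 + 2 * 2 = 10 rows.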

calculate_profiles(data, variable_splits, model,
  predict_function = predict, ...)

Arguments

data

a set of observations. A profile will be calculated for every observation (every row).

variable_splits

a named list of vectors. Each name identifies a variable, and the corresponding vector contains the points at which the profile for that variable should be calculated. See the examples for more details.
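A variable_splits list can also be written by hand. The grids below (a regular numeric grid, quartiles, and the observed values of a discrete variable) are illustrative choices on the built-in mtcars data; calculate_variable_splits may construct its grids differently.

```r
# Each name is a variable in the data; each element holds the values
# at which that variable will be set when computing its profile.
variable_splits <- list(
  wt  = seq(1.5, 5.5, by = 0.5),                              # regular grid
  hp  = unname(quantile(mtcars$hp, probs = seq(0, 1, 0.25))),  # quartiles
  cyl = sort(unique(mtcars$cyl))                              # observed values
)
str(variable_splits)
```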

model

a model that will be passed to predict_function

predict_function

a function that takes a model and a data set and returns numeric predictions. Note that the ... arguments are passed on to this function.
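As an illustration, a custom predict_function for a logistic regression could return probabilities instead of the link-scale values that predict() gives by default for glm objects. The name predict_prob is hypothetical, and the sketch assumes the function is called like the default, i.e. predict_function(model, data, ...).

```r
# Hypothetical predict_function: probabilities from a binomial glm.
# Extra arguments in ... are forwarded to predict().
predict_prob <- function(model, newdata, ...) {
  as.numeric(predict(model, newdata, type = "response", ...))
}

fit <- glm(am ~ wt + hp, data = mtcars, family = binomial)
p <- predict_prob(fit, mtcars[1:3, ])
p
```

It would then be supplied as calculate_profiles(data, variable_splits, fit, predict_function = predict_prob).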

...

other parameters that will be passed to predict_function

Value

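a data frame with profiles for the selected variables and observations. Besides the original columns, it contains _yhat_ (the model prediction), _vname_ (the name of the variable being altered) and _ids_ (the identifier of the original observation).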

Details

Note that calculate_profiles is an S3 generic. If you want to work with non-standard data sources (such as H2O distributed data frames or external databases), you should write a method for it.
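The S3 pattern involved can be sketched without the package. Here profiles_generic stands in for calculate_profiles and remote_table is a hypothetical class for a non-standard data source; both names are assumptions made so that the sketch runs on its own. With the package loaded, one would instead define calculate_profiles.<your_class> with the same arguments as the generic.

```r
# Stand-in generic (hypothetical) dispatching on the class of `data`.
profiles_generic <- function(data, ...) UseMethod("profiles_generic")

# Default method: works on ordinary in-memory data frames.
profiles_generic.default <- function(data, ...) {
  paste("computing on", nrow(data), "rows")
}

# Method for a hypothetical remote table: materialise the data locally,
# then dispatch again, which reaches the default method.
profiles_generic.remote_table <- function(data, ...) {
  local_data <- as.data.frame(unclass(data))
  profiles_generic(local_data, ...)
}

remote <- structure(list(x = 1:3, y = 4:6), class = "remote_table")
profiles_generic(remote)
```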

Examples

library("DALEX")
library("randomForest")
set.seed(59)
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface + floor +
                                      no.rooms + district, data = apartments)
vars <- c("construction.year", "surface", "floor", "no.rooms", "district")
variable_splits <- calculate_variable_splits(apartments, vars)
new_apartment <- apartmentsTest[1:10, ]
profiles <- calculate_profiles(new_apartment, variable_splits, apartments_rf_model)
profiles
#>        m2.price construction.year surface floor no.rooms    district   _yhat_
#> 1001       4644              1920     131     3        5 Srodmiescie 4255.354
#> 1001.1     4644              1921     131     3        5 Srodmiescie 4300.702
#> 1001.2     4644              1922     131     3        5 Srodmiescie 4301.926
#> 1001.3     4644              1923     131     3        5 Srodmiescie 4305.352
#> 1001.4     4644              1924     131     3        5 Srodmiescie 4267.723
#> 1001.5     4644              1925     131     3        5 Srodmiescie 4264.109
#>                  _vname_ _ids_
#> 1001   construction.year  1001
#> 1001.1 construction.year  1001
#> 1001.2 construction.year  1001
#> 1001.3 construction.year  1001
#> 1001.4 construction.year  1001
#> 1001.5 construction.year  1001
# only a subset of observations
small_apartments <- select_sample(apartmentsTest, n = 10)
small_apartments
#>      m2.price construction.year surface floor no.rooms    district
#> 8946     2174              1959     123     8        4        Wola
#> 4458     4319              1927      68     8        2      Ochota
#> 7384     5501              1929      95     5        3 Srodmiescie
#> 5450     2810              1982     124    10        5      Ochota
#> 6744     1770              1982     143     9        6     Ursynow
#> 6688     2796              1938      75     7        3        Wola
#> 3167     5701              1971      55     1        3 Srodmiescie
#> 1902     2672              1977      98     6        3       Ursus
#> 5925     3916              1924      33     7        1      Bemowo
#> 4293     3474              1979     113     5        4     Mokotow
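For intuition, a random subset of rows can be drawn in base R as below. The helper sample_rows is hypothetical; select_sample may differ in details such as seeding or whether sampling is done with replacement.

```r
# Hypothetical helper: n randomly chosen rows (without replacement),
# capped at the number of available rows.
sample_rows <- function(data, n, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  data[sample(nrow(data), min(n, nrow(data))), , drop = FALSE]
}

small_cars <- sample_rows(mtcars, n = 5, seed = 1)
small_cars
```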
small_profiles <- calculate_profiles(small_apartments, variable_splits, apartments_rf_model)
small_profiles
#>        m2.price construction.year surface floor no.rooms district   _yhat_
#> 8946       2174              1920     123     8        4     Wola 2871.626
#> 8946.1     2174              1921     123     8        4     Wola 2896.045
#> 8946.2     2174              1922     123     8        4     Wola 2901.677
#> 8946.3     2174              1923     123     8        4     Wola 2891.101
#> 8946.4     2174              1924     123     8        4     Wola 2890.361
#> 8946.5     2174              1925     123     8        4     Wola 2891.720
#>                  _vname_ _ids_
#> 8946   construction.year  8946
#> 8946.1 construction.year  8946
#> 8946.2 construction.year  8946
#> 8946.3 construction.year  8946
#> 8946.4 construction.year  8946
#> 8946.5 construction.year  8946
# neighbors for a selected observation
new_apartment <- apartments[1, 2:6]
small_apartments <- select_neighbours(apartmentsTest, new_apartment, n = 10)
small_apartments
#>      m2.price construction.year surface floor no.rooms    district
#> 2285     5875              1970      27     3        1 Srodmiescie
#> 1073     5886              1960      36     2        1 Srodmiescie
#> 8110     5614              1957      44     4        1 Srodmiescie
#> 9527     6080              1947      27     1        1 Srodmiescie
#> 3261     5859              1945      39     2        1 Srodmiescie
#> 4309     5794              1947      31     3        2 Srodmiescie
#> 1198     5821              1947      43     2        1 Srodmiescie
#> 6647     5952              1938      30     2        1 Srodmiescie
#> 4027     6457              1926      29     3        1 Srodmiescie
#> 2655     5596              1950      25     6        1 Srodmiescie
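Finding neighbours requires a distance that handles both numeric and categorical variables, like the apartments data with its district column. A common choice is a Gower-style distance: range-scaled absolute differences for numeric columns, a 0/1 mismatch for other columns, averaged over the variables. The helper nearest_rows is hypothetical, and the distance actually used by select_neighbours may differ.

```r
# Hypothetical helper: the n rows of `data` closest to the one-row data
# frame `x`, using a Gower-style distance averaged over `variables`.
nearest_rows <- function(data, x, n = 10, variables = names(x)) {
  d <- numeric(nrow(data))
  for (v in variables) {
    col <- data[[v]]
    if (is.numeric(col)) {
      # numeric: absolute difference scaled by the column's range
      rng <- diff(range(col))
      if (rng > 0) d <- d + abs(col - x[[v]]) / rng
    } else {
      # categorical: 0 if the values match, 1 otherwise
      d <- d + as.numeric(as.character(col) != as.character(x[[v]]))
    }
  }
  d <- d / length(variables)
  # order() keeps tied rows in their original order
  data[order(d)[seq_len(min(n, nrow(data)))], , drop = FALSE]
}

# Usage on mtcars, with `am` recoded as a categorical column
cars <- transform(mtcars, am = ifelse(am == 1, "manual", "auto"))
target <- data.frame(wt = 3, hp = 120, am = "manual")
nearest_rows(cars, target, n = 3)
```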
small_profiles <- calculate_profiles(small_apartments, variable_splits, apartments_rf_model)
new_apartment
#>   construction.year surface floor no.rooms    district
#> 1              1953      25     3        1 Srodmiescie
small_profiles
#>        m2.price construction.year surface floor no.rooms    district   _yhat_
#> 2285       5875              1920      27     3        1 Srodmiescie 5438.443
#> 2285.1     5875              1921      27     3        1 Srodmiescie 5478.624
#> 2285.2     5875              1922      27     3        1 Srodmiescie 5477.707
#> 2285.3     5875              1923      27     3        1 Srodmiescie 5494.789
#> 2285.4     5875              1924      27     3        1 Srodmiescie 5501.781
#> 2285.5     5875              1925      27     3        1 Srodmiescie 5468.007
#>                  _vname_ _ids_
#> 2285   construction.year  2285
#> 2285.1 construction.year  2285
#> 2285.2 construction.year  2285
#> 2285.3 construction.year  2285
#> 2285.4 construction.year  2285
#> 2285.5 construction.year  2285