This function calculates ceteris paribus profiles, i.e. series of model predictions calculated for observations in which a single coordinate (one variable) is altered while all other variables are held fixed.
calculate_profiles(data, variable_splits, model, predict_function = predict, ...)
Argument | Description |
---|---|
data | set of observations. A profile will be calculated for every observation (every row) |
variable_splits | named list of vectors. Each element is named after a variable and holds the points at which profiles for that variable should be calculated. See the examples for more details. |
model | a model that will be passed to the predict_function |
predict_function | function that takes data and model and returns numeric predictions. Note that the ... arguments will be passed to this function. |
... | other parameters that will be passed to the predict_function |
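The predict_function argument matters when the default predict() does not return plain numeric predictions, for example when a model predicts on the link scale by default. The following is a minimal sketch, not part of the package: the wrapper name predict_prob and the toy data are made up for illustration, and the argument order (model first, then data) is an assumption based on the default being predict.

```r
# Toy data: a binary outcome driven by one numeric predictor
set.seed(1)
toy <- data.frame(x = rnorm(200))
toy$y <- rbinom(200, size = 1, prob = plogis(1.5 * toy$x))

fit <- glm(y ~ x, data = toy, family = binomial())

# Hypothetical wrapper: same (model, data) order as predict(), extra
# arguments forwarded through ...; returns probabilities as a plain
# numeric vector instead of values on the link scale
predict_prob <- function(model, newdata, ...) {
  as.numeric(predict(model, newdata = newdata, type = "response", ...))
}

p <- predict_prob(fit, toy[1:5, , drop = FALSE])
```

A wrapper of this shape could then be supplied as predict_function = predict_prob.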
a data frame with profiles for selected variables and selected observations
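To make the shape of the result concrete, the sketch below builds a profile for one observation by hand in base R. It is an illustration under stated assumptions, not the package's implementation: manual_profile is a hypothetical helper, the model is a toy lm on mtcars, and the _yhat_, _vname_ and _ids_ column names are copied from the example output on this page.

```r
# Toy model on a built-in data set
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Hypothetical helper: profile of one observation along one variable
manual_profile <- function(observation, variable, split_points, model,
                           predict_function = predict, ...) {
  # Repeat the single observation once per split point
  profile <- observation[rep(1, length(split_points)), , drop = FALSE]
  # Alter a single coordinate; every other column stays fixed
  profile[[variable]] <- split_points
  # Predictions for the altered rows, plus bookkeeping columns
  profile[["_yhat_"]] <- predict_function(model, profile, ...)
  profile[["_vname_"]] <- variable
  profile[["_ids_"]] <- rownames(observation)
  profile
}

obs <- mtcars[1, ]
prof <- manual_profile(obs, "wt", seq(2, 5, by = 0.5), fit)
```

Each row of prof is the same car with a different weight, so the change in _yhat_ across rows reflects the effect of wt alone for that observation.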
Note that calculate_profiles is an S3 generic.
If you want to work with non-standard data sources (such as H2O ddf or external databases), you should overload it.
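One way such an overload could look is sketched below. Everything here is hypothetical: the remote_table class, its fetch field, and the stand-in generic and default method, which only mimic the documented signature so that the sketch runs on its own. With the package attached, only the method for the new class would need to be defined.

```r
# Stand-in generic and default method, NOT the package's code
calculate_profiles <- function(data, variable_splits, model,
                               predict_function = predict, ...) {
  UseMethod("calculate_profiles")
}

calculate_profiles.default <- function(data, variable_splits, model,
                                       predict_function = predict, ...) {
  pieces <- lapply(names(variable_splits), function(v) {
    pts <- variable_splits[[v]]
    # Each row repeated once per split point, then one column overwritten
    out <- data[rep(seq_len(nrow(data)), each = length(pts)), , drop = FALSE]
    out[[v]] <- rep(pts, times = nrow(data))
    out[["_yhat_"]] <- predict_function(model, out, ...)
    out[["_vname_"]] <- v
    out
  })
  do.call(rbind, pieces)
}

# Hypothetical class wrapping a data source that must be fetched first
remote_table <- function(fetch) {
  structure(list(fetch = fetch), class = "remote_table")
}

# The overload: materialise the rows locally, then dispatch again
calculate_profiles.remote_table <- function(data, variable_splits, model,
                                            predict_function = predict, ...) {
  local_data <- data$fetch()
  calculate_profiles(local_data, variable_splits, model,
                     predict_function = predict_function, ...)
}

fit <- lm(mpg ~ wt + hp, data = mtcars)
src <- remote_table(function() mtcars[1:3, ])
res <- calculate_profiles(src, list(wt = c(2, 3, 4)), fit)
```

The method keeps the generic's signature and delegates the actual computation, so the only source-specific code is the step that turns the external object into a data frame.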
library("DALEX")
library("randomForest")
set.seed(59)

apartments_rf_model <- randomForest(m2.price ~ construction.year + surface + floor +
                                      no.rooms + district, data = apartments)

vars <- c("construction.year", "surface", "floor", "no.rooms", "district")
variable_splits <- calculate_variable_splits(apartments, vars)
new_apartment <- apartmentsTest[1:10, ]
profiles <- calculate_profiles(new_apartment, variable_splits,
                               apartments_rf_model)
profiles
#>        m2.price construction.year surface floor no.rooms    district   _yhat_
#> 1001       4644              1920     131     3        5 Srodmiescie 4255.354
#> 1001.1     4644              1921     131     3        5 Srodmiescie 4300.702
#> 1001.2     4644              1922     131     3        5 Srodmiescie 4301.926
#> 1001.3     4644              1923     131     3        5 Srodmiescie 4305.352
#> 1001.4     4644              1924     131     3        5 Srodmiescie 4267.723
#> 1001.5     4644              1925     131     3        5 Srodmiescie 4264.109
#>                  _vname_ _ids_
#> 1001   construction.year  1001
#> 1001.1 construction.year  1001
#> 1001.2 construction.year  1001
#> 1001.3 construction.year  1001
#> 1001.4 construction.year  1001
#> 1001.5 construction.year  1001

# only subset of observations
small_apartments <- select_sample(apartmentsTest, n = 10)
small_apartments
#>      m2.price construction.year surface floor no.rooms    district
#> 8946     2174              1959     123     8        4        Wola
#> 4458     4319              1927      68     8        2      Ochota
#> 7384     5501              1929      95     5        3 Srodmiescie
#> 5450     2810              1982     124    10        5      Ochota
#> 6744     1770              1982     143     9        6     Ursynow
#> 6688     2796              1938      75     7        3        Wola
#> 3167     5701              1971      55     1        3 Srodmiescie
#> 1902     2672              1977      98     6        3       Ursus
#> 5925     3916              1924      33     7        1      Bemowo
#> 4293     3474              1979     113     5        4     Mokotow
small_profiles <- calculate_profiles(small_apartments, variable_splits,
                                     apartments_rf_model)
small_profiles
#>        m2.price construction.year surface floor no.rooms district   _yhat_
#> 8946       2174              1920     123     8        4     Wola 2871.626
#> 8946.1     2174              1921     123     8        4     Wola 2896.045
#> 8946.2     2174              1922     123     8        4     Wola 2901.677
#> 8946.3     2174              1923     123     8        4     Wola 2891.101
#> 8946.4     2174              1924     123     8        4     Wola 2890.361
#> 8946.5     2174              1925     123     8        4     Wola 2891.720
#>                  _vname_ _ids_
#> 8946   construction.year  8946
#> 8946.1 construction.year  8946
#> 8946.2 construction.year  8946
#> 8946.3 construction.year  8946
#> 8946.4 construction.year  8946
#> 8946.5 construction.year  8946

# neighbors for a selected observation
new_apartment <- apartments[1, 2:6]
small_apartments <- select_neighbours(apartmentsTest, new_apartment, n = 10)
small_apartments
#>      m2.price construction.year surface floor no.rooms    district
#> 2285     5875              1970      27     3        1 Srodmiescie
#> 1073     5886              1960      36     2        1 Srodmiescie
#> 8110     5614              1957      44     4        1 Srodmiescie
#> 9527     6080              1947      27     1        1 Srodmiescie
#> 3261     5859              1945      39     2        1 Srodmiescie
#> 4309     5794              1947      31     3        2 Srodmiescie
#> 1198     5821              1947      43     2        1 Srodmiescie
#> 6647     5952              1938      30     2        1 Srodmiescie
#> 4027     6457              1926      29     3        1 Srodmiescie
#> 2655     5596              1950      25     6        1 Srodmiescie
small_profiles <- calculate_profiles(small_apartments, variable_splits,
                                     apartments_rf_model)
new_apartment
#>   construction.year surface floor no.rooms    district
#> 1              1953      25     3        1 Srodmiescie
small_profiles
#>        m2.price construction.year surface floor no.rooms    district   _yhat_
#> 2285       5875              1920      27     3        1 Srodmiescie 5438.443
#> 2285.1     5875              1921      27     3        1 Srodmiescie 5478.624
#> 2285.2     5875              1922      27     3        1 Srodmiescie 5477.707
#> 2285.3     5875              1923      27     3        1 Srodmiescie 5494.789
#> 2285.4     5875              1924      27     3        1 Srodmiescie 5501.781
#> 2285.5     5875              1925      27     3        1 Srodmiescie 5468.007
#>                  _vname_ _ids_
#> 2285   construction.year  2285
#> 2285.1 construction.year  2285
#> 2285.2 construction.year  2285
#> 2285.3 construction.year  2285
#> 2285.4 construction.year  2285
#> 2285.5 construction.year  2285