titanic-tidymodels: boost

library(tidyverse)  # data wrangling
library(tidymodels)  # modelling
library(broom)  # tidy model output
library(skimr)  # overview of descriptive statistics
library(parallel)  # detect cores (fork-based parallelism is unix only)

1 Objective

Predicting survival in the Titanic disaster, using a tidymodels approach.

2 Detect available cores

cores <- parallel::detectCores()
cores
#> [1] 8
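
tune can evaluate the tuning candidates in parallel if a foreach backend is registered. A minimal sketch (not run in the original), assuming the doParallel package is available:

library(doParallel)
registerDoParallel(cores = cores)  # register the cores detected above as the backend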

3 Load and prepare data

3.1 Hide details in a function

prepare_data <- function(traindata_url = "https://raw.githubusercontent.com/sebastiansauer/Statistiklehre/main/data/titanic/train.csv",
                         testdata_url = "https://raw.githubusercontent.com/sebastiansauer/Statistiklehre/main/data/titanic/test.csv") {


  # import the data:
  train <- read_csv(traindata_url)
  test <- read_csv(testdata_url)
  
 
  # bind both samples into one:
  data_raw <-
    train %>% 
    bind_rows(test)
  
  
  # drop unused variables:
  data <-
    data_raw %>% 
    select(-c(Name, Cabin, Ticket))
  
  
  # convert string to factors:
  data2 <- 
    data %>% 
    mutate(across(where(is.character), as.factor))
  
  # convert numeric outcome to nominal, to indicate classification:
  data2 <- data2 %>% 
    mutate(Survived = as.factor(Survived))
  
  
  return(data2)
}
data2 <- prepare_data()
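
skimr was loaded above; a quick descriptive overview of the prepared data could look like this (output omitted):

skim(data2)  # per-variable summaries, including missing values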

4 Split data into train and test

That’s kind of a dirty hack: initial_time_split() splits the data in order, without shuffling, so the first 891 rows (the original train sample) end up in the training set and the remaining 418 rows (the original test sample) in the testing set.

split_titanic <- initial_time_split(data = data2, 
                                    prop = 891/1309)
train <- training(split_titanic)
test <- testing(split_titanic)
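
A quick sanity check (a hedged sketch, not part of the original) confirms that the split reproduces the original samples:

nrow(train)  # should be 891, the original train sample
nrow(test)   # should be 418, the original test sample
all(is.na(test$Survived))  # should be TRUE: the test outcome is unknown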

5 Define recipe


titanic_recipe <- 
  
  # define model formula:
  recipe(Survived ~ ., data = train) %>%
  
  # Use "ID" etc as ID, not as predictor:
  update_role(PassengerId, new_role = "ID") %>%   

  # impute missing values:
  step_knnimpute(all_predictors(), neighbors = 3) %>%  
  
  # convert character and factor type variables into dummy variables:
  step_dummy(all_nominal(), -all_outcomes()) %>%   
  
  # exclude near zero variance predictors:
  step_nzv(all_predictors()) %>%  
  
  # exclude highly correlated vars:
  step_corr(all_predictors()) %>% 
  
  # center (set mean to zero):
  step_center(all_predictors(), -all_outcomes()) %>%  
  
  # scale (set sd to 1):
  step_scale(all_predictors(), -all_outcomes())  
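
To check what the recipe produces, it can be prepped on the training data. A minimal sketch (output omitted):

titanic_recipe %>% 
  prep() %>%    # estimate imputations, dummies, filters, centering/scaling
  juice() %>%   # extract the processed training data
  glimpse()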

6 Define model

boost_mod <- 
  boost_tree(mtry = tune(), 
             min_n = tune(), 
             learn_rate = tune(),
             tree_depth = tune()) %>% 
  set_engine("xgboost", 
             num.threads = cores) %>% 
  set_mode("classification")

Translate the specification to the xgboost engine:

translate(boost_mod)
#> Boosted Tree Model Specification (classification)
#> 
#> Main Arguments:
#>   mtry = tune()
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#> 
#> Engine-Specific Arguments:
#>   num.threads = cores
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> parsnip::xgb_train(x = missing_arg(), y = missing_arg(), colsample_bytree = tune(), 
#>     min_child_weight = tune(), max_depth = tune(), eta = tune(), 
#>     num.threads = cores, nthread = 1, verbose = 0)

7 Define cross validation scheme

train_cv <- vfold_cv(train, 
                     v = 10)

Note that although a 10-fold cross-validation scheme is defined here, the tuning below will run on a single validation split to save computation time.

These are the parameters that can be tuned:

boost_mod %>%    
  parameters()  
#> Collection of 4 parameters for tuning
#> 
#>  identifier       type    object
#>        mtry       mtry nparam[?]
#>       min_n      min_n nparam[+]
#>  tree_depth tree_depth nparam[+]
#>  learn_rate learn_rate nparam[+]
#> 
#> Model parameters needing finalization:
#>    # Randomly Selected Predictors ('mtry')
#> 
#> See `?dials::finalize` or `?dials::update.parameters` for more information.
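
mtry depends on the number of predictors, so its upper bound is unknown until the model sees data; tune_grid() finalizes it automatically when grid is given as an integer. To finalize it by hand, a sketch along these lines should work:

boost_mod %>% 
  parameters() %>% 
  finalize(x = train %>% select(-Survived, -PassengerId))  # set mtry's upper bound from the predictors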

8 Define workflow

boost_wf <-
  workflow() %>% 
  add_model(boost_mod) %>% 
  add_recipe(titanic_recipe)

9 Define analysis and validation (holdout) set

set.seed(42)
val_set <- validation_split(train, 
                            strata = Survived, 
                            prop = 0.80)

10 Fit the grid

set.seed(42)

t1 <- Sys.time()
boost_fit <- 
  boost_wf %>% 
  tune_grid(val_set,
            grid = 100,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))
t2 <- Sys.time()

t2 - t1
#> Time difference of 55.77518 secs
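
Passing an integer to grid makes tune construct a space-filling design of that size automatically. An explicit grid could be supplied instead, e.g. a Latin hypercube over the same parameters (a hedged sketch using dials’ default ranges):

boost_grid <- grid_latin_hypercube(
  finalize(mtry(), train %>% select(-Survived, -PassengerId)),  # data-dependent upper bound
  min_n(),
  tree_depth(),
  learn_rate(),
  size = 100
)

This object could then be passed to tune_grid() via grid = boost_grid.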

It may be worthwhile to save the result object to disk, in order to save computation time:

saveRDS(boost_fit, file = "boost_fit.rds")
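
In a later session, the tuning results can then be restored with:

boost_fit <- readRDS("boost_fit.rds")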

11 View results

boost_fit %>% 
  collect_metrics()
#> # A tibble: 100 x 10
#>     mtry min_n tree_depth learn_rate .metric .estimator  mean     n std_err
#>    <int> <int>      <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl>
#>  1     3    23          2    6.70e-8 roc_auc binary     0.843     1      NA
#>  2     6    24         14    1.71e-7 roc_auc binary     0.866     1      NA
#>  3     3    26          4    7.00e-4 roc_auc binary     0.855     1      NA
#>  4     2    12          5    2.36e-6 roc_auc binary     0.854     1      NA
#>  5     3     9          9    9.25e-2 roc_auc binary     0.898     1      NA
#>  6     1    30          8    2.04e-2 roc_auc binary     0.851     1      NA
#>  7     7    16         10    1.24e-6 roc_auc binary     0.883     1      NA
#>  8     7    12         11    4.49e-5 roc_auc binary     0.888     1      NA
#>  9     6    15          4    1.91e-4 roc_auc binary     0.879     1      NA
#> 10     5    23          6    3.46e-7 roc_auc binary     0.866     1      NA
#> # … with 90 more rows, and 1 more variable: .config <chr>
boost_fit %>% 
  show_best(metric = "roc_auc")
#> # A tibble: 5 x 10
#>    mtry min_n tree_depth learn_rate .metric .estimator  mean     n std_err
#>   <int> <int>      <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl>
#> 1     5     3         10    2.80e-2 roc_auc binary     0.899     1      NA
#> 2     3     9          9    9.25e-2 roc_auc binary     0.898     1      NA
#> 3     7     2         12    8.02e-7 roc_auc binary     0.898     1      NA
#> 4     7     5         14    8.71e-5 roc_auc binary     0.895     1      NA
#> 5     5     4          9    9.33e-8 roc_auc binary     0.893     1      NA
#> # … with 1 more variable: .config <chr>
autoplot(boost_fit)

12 Get best model

boost_best <- 
  boost_fit %>% 
  select_best(metric = "roc_auc")
boost_best
#> # A tibble: 1 x 5
#>    mtry min_n tree_depth learn_rate .config               
#>   <int> <int>      <int>      <dbl> <chr>                 
#> 1     5     3         10     0.0280 Preprocessor1_Model062

13 Final fit (on train data)

boost_final_wf <- 
  boost_wf %>% 
  finalize_workflow(boost_best)

So, here is the best model:

boost_final_wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 6 Recipe Steps
#> 
#> ● step_knnimpute()
#> ● step_dummy()
#> ● step_nzv()
#> ● step_corr()
#> ● step_center()
#> ● step_scale()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (classification)
#> 
#> Main Arguments:
#>   mtry = 5
#>   min_n = 3
#>   tree_depth = 10
#>   learn_rate = 0.0280318587584122
#> 
#> Engine-Specific Arguments:
#>   num.threads = cores
#> 
#> Computational engine: xgboost

Let us fit this model to the train data:

last_boost_mod <- 
  boost_final_wf %>% 
  fit(data = train)
#> [17:09:04] WARNING: amalgamation/../src/learner.cc:516: 
#> Parameters: { num_threads } might not be used.
#> 
#>   This may not be accurate due to some parameters are only used in language bindings but
#>   passed down to XGBoost core.  Or some parameters are not used but slip through this
#>   verification. Please open an issue if you find above cases.

The warning is harmless: xgboost expects the engine argument nthread, so the num.threads argument set above is passed through but ignored.
last_boost_mod
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 6 Recipe Steps
#> 
#> ● step_knnimpute()
#> ● step_dummy()
#> ● step_nzv()
#> ● step_corr()
#> ● step_center()
#> ● step_scale()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> ##### xgb.Booster
#> raw: 33 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.0280318587584122, max_depth = 10L, 
#>     gamma = 0, colsample_bytree = 0.625, min_child_weight = 3L, 
#>     subsample = 1), data = x$data, nrounds = 15, watchlist = x$watchlist, 
#>     verbose = 0, objective = "binary:logistic", num.threads = 8L, 
#>     nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.0280318587584122", max_depth = "10", gamma = "0", colsample_bytree = "0.625", min_child_weight = "3", subsample = "1", objective = "binary:logistic", num_threads = "8", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_error
#>        1       0.236813
#>        2       0.164983
#> ---                    
#>       14       0.130191
#>       15       0.127946

13.1 Fit the final workflow and evaluate on the test data

set.seed(42)
last_boost_fit <- 
  last_boost_mod %>% 
  last_fit(split_titanic)
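
last_fit() refits the finalized workflow on the training portion of the split and predicts the held-out portion. With a labeled test set, performance could be inspected as follows (a sketch; here the Kaggle test outcome is unknown, so the metrics would not be informative):

last_boost_fit %>% 
  collect_metrics()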

14 Predict test data

Note that .row is the row index in the combined data set; it coincides with PassengerId here only because the 418 test rows were appended after the 891 train rows.

boost_preds <- 
  last_boost_fit %>% 
  collect_predictions() %>% 
  select(-Survived) %>% 
  select(PassengerID = .row, Survived = .pred_class) 
glimpse(boost_preds)
#> Rows: 418
#> Columns: 2
#> $ PassengerID <int> 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 90…
#> $ Survived    <fct> 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, …

15 Save predictions to disk

boost_preds %>% 
  write_csv(file = "boost_preds.csv")

16 Reproducibility

#> ─ Session info ───────────────────────────────────────────────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS Catalina 10.15.7      
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2020-12-14                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────────────────────────────────────────────
#>  package     * version     date       lib source                             
#>  assertthat    0.2.1       2019-03-21 [1] CRAN (R 4.0.0)                     
#>  backports     1.2.0       2020-11-02 [1] CRAN (R 4.0.2)                     
#>  blogdown      0.21        2020-10-11 [1] CRAN (R 4.0.2)                     
#>  bookdown      0.21        2020-10-13 [1] CRAN (R 4.0.2)                     
#>  broom         0.7.2       2020-10-20 [1] CRAN (R 4.0.2)                     
#>  callr         3.5.1       2020-10-13 [1] CRAN (R 4.0.2)                     
#>  cellranger    1.1.0       2016-07-27 [1] CRAN (R 4.0.0)                     
#>  cli           2.2.0       2020-11-20 [1] CRAN (R 4.0.2)                     
#>  codetools     0.2-16      2018-12-24 [2] CRAN (R 4.0.2)                     
#>  colorspace    2.0-0       2020-11-11 [1] CRAN (R 4.0.2)                     
#>  crayon        1.3.4       2017-09-16 [1] CRAN (R 4.0.0)                     
#>  DBI           1.1.0       2019-12-15 [1] CRAN (R 4.0.0)                     
#>  dbplyr        2.0.0       2020-11-03 [1] CRAN (R 4.0.2)                     
#>  desc          1.2.0       2018-05-01 [1] CRAN (R 4.0.0)                     
#>  devtools      2.3.2       2020-09-18 [1] CRAN (R 4.0.2)                     
#>  digest        0.6.27      2020-10-24 [1] CRAN (R 4.0.2)                     
#>  dplyr       * 1.0.2       2020-08-18 [1] CRAN (R 4.0.2)                     
#>  ellipsis      0.3.1       2020-05-15 [1] CRAN (R 4.0.0)                     
#>  evaluate      0.14        2019-05-28 [1] CRAN (R 4.0.0)                     
#>  fansi         0.4.1       2020-01-08 [1] CRAN (R 4.0.0)                     
#>  forcats     * 0.5.0       2020-03-01 [1] CRAN (R 4.0.0)                     
#>  fs            1.5.0       2020-07-31 [1] CRAN (R 4.0.2)                     
#>  generics      0.1.0       2020-10-31 [1] CRAN (R 4.0.2)                     
#>  ggplot2     * 3.3.2       2020-06-19 [1] CRAN (R 4.0.0)                     
#>  glue          1.4.2       2020-08-27 [1] CRAN (R 4.0.2)                     
#>  gtable        0.3.0       2019-03-25 [1] CRAN (R 4.0.0)                     
#>  haven         2.3.1       2020-06-01 [1] CRAN (R 4.0.0)                     
#>  hms           0.5.3       2020-01-08 [1] CRAN (R 4.0.0)                     
#>  htmltools     0.5.0       2020-06-16 [1] CRAN (R 4.0.0)                     
#>  httr          1.4.2       2020-07-20 [1] CRAN (R 4.0.2)                     
#>  jsonlite      1.7.1       2020-09-07 [1] CRAN (R 4.0.2)                     
#>  knitr         1.30        2020-09-22 [1] CRAN (R 4.0.2)                     
#>  lifecycle     0.2.0       2020-03-06 [1] CRAN (R 4.0.0)                     
#>  lubridate     1.7.9.2     2020-11-13 [1] CRAN (R 4.0.2)                     
#>  magrittr      2.0.1       2020-11-17 [1] CRAN (R 4.0.2)                     
#>  memoise       1.1.0       2017-04-21 [1] CRAN (R 4.0.0)                     
#>  modelr        0.1.8       2020-05-19 [1] CRAN (R 4.0.0)                     
#>  munsell       0.5.0       2018-06-12 [1] CRAN (R 4.0.0)                     
#>  pillar        1.4.7       2020-11-20 [1] CRAN (R 4.0.2)                     
#>  pkgbuild      1.1.0       2020-07-13 [1] CRAN (R 4.0.2)                     
#>  pkgconfig     2.0.3       2019-09-22 [1] CRAN (R 4.0.0)                     
#>  pkgload       1.1.0       2020-05-29 [1] CRAN (R 4.0.0)                     
#>  prettyunits   1.1.1       2020-01-24 [1] CRAN (R 4.0.0)                     
#>  processx      3.4.5       2020-11-30 [1] CRAN (R 4.0.2)                     
#>  ps            1.4.0       2020-10-07 [1] CRAN (R 4.0.2)                     
#>  purrr       * 0.3.4       2020-04-17 [1] CRAN (R 4.0.0)                     
#>  R6            2.5.0       2020-10-28 [1] CRAN (R 4.0.2)                     
#>  Rcpp          1.0.5       2020-07-06 [1] CRAN (R 4.0.2)                     
#>  readr       * 1.4.0       2020-10-05 [1] CRAN (R 4.0.2)                     
#>  readxl        1.3.1       2019-03-13 [1] CRAN (R 4.0.0)                     
#>  remotes       2.2.0       2020-07-21 [1] CRAN (R 4.0.2)                     
#>  reprex        0.3.0       2019-05-16 [1] CRAN (R 4.0.0)                     
#>  rlang         0.4.9       2020-11-26 [1] CRAN (R 4.0.2)                     
#>  rmarkdown     2.5         2020-10-21 [1] CRAN (R 4.0.2)                     
#>  rprojroot     2.0.2       2020-11-15 [1] CRAN (R 4.0.2)                     
#>  rstudioapi    0.13.0-9000 2020-12-09 [1] Github (rstudio/rstudioapi@4baeb39)
#>  rvest         0.3.6       2020-07-25 [1] CRAN (R 4.0.2)                     
#>  scales        1.1.1       2020-05-11 [1] CRAN (R 4.0.0)                     
#>  sessioninfo   1.1.1       2018-11-05 [1] CRAN (R 4.0.0)                     
#>  stringi       1.5.3       2020-09-09 [1] CRAN (R 4.0.2)                     
#>  stringr     * 1.4.0       2019-02-10 [1] CRAN (R 4.0.0)                     
#>  testthat      3.0.0       2020-10-31 [1] CRAN (R 4.0.2)                     
#>  tibble      * 3.0.4       2020-10-12 [1] CRAN (R 4.0.2)                     
#>  tidyr       * 1.1.2       2020-08-27 [1] CRAN (R 4.0.2)                     
#>  tidyselect    1.1.0       2020-05-11 [1] CRAN (R 4.0.0)                     
#>  tidyverse   * 1.3.0       2019-11-21 [1] CRAN (R 4.0.0)                     
#>  usethis       1.6.3       2020-09-17 [1] CRAN (R 4.0.2)                     
#>  vctrs         0.3.5       2020-11-17 [1] CRAN (R 4.0.2)                     
#>  withr         2.3.0       2020-09-22 [1] CRAN (R 4.0.2)                     
#>  xfun          0.19        2020-10-30 [1] CRAN (R 4.0.2)                     
#>  xml2          1.3.2       2020-04-23 [1] CRAN (R 4.0.0)                     
#>  yaml          2.2.1       2020-02-01 [1] CRAN (R 4.0.0)                     
#> 
#> [1] /Users/sebastiansaueruser/Rlibs
#> [2] /Library/Frameworks/R.framework/Versions/4.0/Resources/library