Load packages and data

library(tidyverse)
## ── Attaching packages ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.0 ──
## ✓ ggplot2 3.3.2     ✓ purrr   0.3.4
## ✓ tibble  3.0.3     ✓ dplyr   1.0.2
## ✓ tidyr   1.1.1     ✓ stringr 1.4.0
## ✓ readr   1.3.1     ✓ forcats 0.5.0
## ── Conflicts ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
df <- readr::read_rds("df.rds")
step_2_b_df <- readr::read_rds("step_2_b_df.rds")

Overview

  1. Exploration
  2. Regression models
  3. Binary classification option b
  4. Binary classification option a
  5. Interpretation and optimization

This R Markdown file tackles part iii: training, evaluating, tuning, and comparing models for the binary outcome outcome_2 as a function of xA, xB, response_1, and x07:x11.

We will use the caret package to handle training, testing, and evaluation.

library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
## 
##     lift

General

Throughout, we use 5-fold cross-validation repeated 5 times as the resampling method, by passing method = "repeatedcv", number = 5, and repeats = 5 to caret::trainControl(). For this classification problem, the area under the ROC curve (AUC) is our primary performance metric. Setting summaryFunction = twoClassSummary makes caret compute the AUC (reported as ROC) along with sensitivity and specificity, so that train() can select tuning parameters by maximizing the AUC. Because the AUC is computed from predicted class probabilities, we also set classProbs = TRUE, and savePredictions = TRUE keeps the held-out predictions from each resample.

my_ctrl <- caret::trainControl(method = "repeatedcv", 
                               number = 5, 
                               repeats = 5, 
                               savePredictions = TRUE, 
                               summaryFunction = twoClassSummary, 
                               classProbs = TRUE)
roc_metric <- "ROC"
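The resampling scheme above can be sketched in base R. This is a simplified illustration, not caret's implementation: caret stratifies the folds by outcome class, whereas this sketch assigns rows to folds completely at random. It shows why each of the 5 × 5 = 25 resamples trains on 1610 or 1611 of the 2013 rows, which matches the "Summary of sample sizes" lines in the outputs below.

```r
# Unstratified sketch of 5-fold CV repeated 5 times on n = 2013 rows
# (caret additionally stratifies folds by the outcome; skipped here).
n <- 2013
k <- 5
n_repeats <- 5
set.seed(12345)
train_sizes <- unlist(lapply(seq_len(n_repeats), function(r) {
  fold_id <- sample(rep(seq_len(k), length.out = n))  # random fold label per row
  # training set for fold f = all rows NOT assigned to fold f
  vapply(seq_len(k), function(f) sum(fold_id != f), integer(1))
}))
length(train_sizes)  # k * n_repeats = 25 held-out evaluations per candidate
table(train_sizes)   # 2013 = 5 * 402 + 3, so folds hold 402 or 403 rows
```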

Logistic regression with additive terms

First we will train a logistic regression model with additive terms, using method = "glm" in caret::train. We will train the model for outcome_2 as a function of xA, xB, response_1, and x07:x11.

The main purpose of this logistic regression model is to provide a baseline for the more complex models that follow.

set.seed(12345)
mod_glm <- caret::train(outcome_2 ~ .,
                            method = "glm", 
                            metric = roc_metric, 
                            trControl = my_ctrl,
                            preProcess = c("center", "scale"),
                            data = step_2_b_df)

mod_glm
## Generalized Linear Model 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## Pre-processing: centered (10), scaled (10) 
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results:
## 
##   ROC        Sens       Spec     
##   0.6230073  0.5429442  0.6488326
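The ROC value reported above is the cross-validated AUC. The AUC is the probability that a randomly chosen Fail case receives a higher predicted P(Fail) than a randomly chosen Pass case. The sketch below assumes, as caret's twoClassSummary does, that the first factor level ("Fail") is the event class. The data and the helper names auc_pairwise and auc_rank are hypothetical; caret computes the AUC internally.

```r
# AUC as the probability that a random event case outranks a random
# non-event case (ties count one half). Hypothetical helper.
auc_pairwise <- function(obs, prob_event, event = levels(obs)[1]) {
  pos <- prob_event[obs == event]
  neg <- prob_event[obs != event]
  cmp <- outer(pos, neg, "-")          # all event/non-event score differences
  (sum(cmp > 0) + 0.5 * sum(cmp == 0)) / length(cmp)
}

# Equivalent rank-based (Mann-Whitney) formula, which scales to large n
auc_rank <- function(obs, prob_event, event = levels(obs)[1]) {
  is_event <- obs == event
  n_pos <- sum(is_event)
  n_neg <- sum(!is_event)
  (sum(rank(prob_event)[is_event]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

obs <- factor(c("Fail", "Fail", "Pass", "Pass", "Fail", "Pass"),
              levels = c("Fail", "Pass"))
p_fail <- c(0.9, 0.4, 0.3, 0.6, 0.7, 0.2)   # hypothetical predicted P(Fail)
c(pairwise = auc_pairwise(obs, p_fail), rank = auc_rank(obs, p_fail))
```

Both give 8/9 here, because 8 of the 9 Fail/Pass pairs are ranked correctly. Read the same way, mod_glm's AUC of about 0.62 means the baseline model ranks roughly 62% of such pairs correctly.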

Look at the cross-validated confusion matrix for mod_glm.

confusionMatrix.train(mod_glm)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 26.6 17.9
##       Pass 22.4 33.1
##                            
##  Accuracy (average) : 0.597
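The summary metrics follow directly from the cells. This sketch recomputes them from the percentages printed above, treating "Fail" as the event class. The cells are averages across resamples, so the results match caret's Sens and Spec only approximately.

```r
# Cell percentages copied from the cross-validated confusion matrix above
# (rows = prediction, columns = reference); "Fail" is the event class.
cm <- matrix(c(26.6, 22.4, 17.9, 33.1), nrow = 2,
             dimnames = list(Prediction = c("Fail", "Pass"),
                             Reference  = c("Fail", "Pass")))
accuracy    <- sum(diag(cm)) / sum(cm)                # correct / total
sensitivity <- cm["Fail", "Fail"] / sum(cm[, "Fail"]) # true Fail caught
specificity <- cm["Pass", "Pass"] / sum(cm[, "Pass"]) # true Pass caught
round(c(Accuracy = accuracy, Sens = sensitivity, Spec = specificity), 3)
```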

Regularized regression with elastic net

We now try a regularization approach. The elastic net penalty is a mixture of the lasso (L1) and ridge (L2) penalties, with the mixing fraction alpha and the overall penalty strength lambda as tuning parameters. We will train two models with interactions: one on all pairwise interactions between the step_2_b_df inputs, and one on all interactions up to third order (triplets).
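For reference, the glmnet documentation describes the elastic net logistic regression fit as minimizing the penalized negative log-likelihood below. Here $y_i \in \{0, 1\}$, the intercept $\beta_0$ is not penalized, $\alpha = 1$ gives the lasso, and $\alpha = 0$ gives ridge regression.

$$
\min_{\beta_0,\,\beta}\; -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i\,(\beta_0 + x_i^\top\beta) - \log\big(1 + e^{\beta_0 + x_i^\top\beta}\big)\Big] + \lambda\Big[\frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\Big]
$$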

All pair interactions

Let’s first fit an elastic net logistic regression model on all main effects and pairwise interactions between the step_2_b_df inputs, using caret::train with method = "glmnet" and caret's default tuning grid. We specify centering and scaling as preprocessing steps.

set.seed(12345)
mod_glmnet_2 <- caret::train(outcome_2 ~ (.)^2,
                             method = "glmnet",
                             preProcess = c("center", "scale"),
                             metric = roc_metric,
                             trControl = my_ctrl,
                             data = step_2_b_df)

mod_glmnet_2
## glmnet 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## Pre-processing: centered (52), scaled (52) 
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results across tuning parameters:
## 
##   alpha  lambda        ROC        Sens       Spec     
##   0.10   0.0002009229  0.6899008  0.5602030  0.7210097
##   0.10   0.0020092294  0.6863258  0.5526904  0.7186739
##   0.10   0.0200922935  0.6856803  0.5472081  0.7260725
##   0.55   0.0002009229  0.6910454  0.5612183  0.7196486
##   0.55   0.0020092294  0.6858822  0.5512690  0.7180914
##   0.55   0.0200922935  0.6879072  0.5352284  0.7350206
##   1.00   0.0002009229  0.6919745  0.5656853  0.7194525
##   1.00   0.0020092294  0.6867074  0.5510660  0.7190632
##   1.00   0.0200922935  0.6866352  0.5181726  0.7540876
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 1 and lambda = 0.0002009229.
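The 52 preprocessed columns come from dummy-encoding xA (2 levels) and xB (4 levels) and then expanding the formula. This sketch counts design-matrix columns on a hypothetical toy data frame with the same column types. The level labels A1, A2, and B1 to B4 are assumptions inferred from coefficient names such as xAA2 and xBB3 that appear later in this report.

```r
# Hypothetical toy data with the same column types as step_2_b_df:
# xA has 2 levels, xB has 4 levels, plus six numeric inputs.
set.seed(1)
n <- 50
toy <- data.frame(
  xA = factor(sample(c("A1", "A2"), n, replace = TRUE), levels = c("A1", "A2")),
  xB = factor(sample(paste0("B", 1:4), n, replace = TRUE), levels = paste0("B", 1:4)),
  response_1 = rnorm(n), x07 = rnorm(n), x08 = rnorm(n),
  x09 = rnorm(n), x10 = rnorm(n), x11 = rnorm(n),
  outcome_2 = factor(sample(c("Fail", "Pass"), n, replace = TRUE))
)

# Design-matrix columns, excluding the intercept, for each formula
n_cols <- function(f) ncol(model.matrix(f, data = toy)) - 1
counts <- c(additive = n_cols(outcome_2 ~ .),
            pairwise = n_cols(outcome_2 ~ (.)^2),
            triplet  = n_cols(outcome_2 ~ (.)^3))
counts
```

The additive count of 10 matches the 10 centered columns in mod_glm. The pairwise count adds 42 interaction columns (for example, xB:response_1 contributes 3 columns). The triplet count is the 150 checked for mod_glmnet_3_b below.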

Create a custom tuning grid enet_grid to try out many possible values of the penalty factor (lambda) and the mixing fraction (alpha).

enet_grid <- expand.grid(alpha = seq(0.1, 0.9, by = 0.1),
                         lambda = exp(seq(-6, 0.5, length.out = 25)))
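Because lambda is spaced evenly on the log scale, consecutive values differ by a constant factor of exp(6.5/24) ≈ 1.31. That is also why the tuning plots use xTrans = log. A quick check of the grid's size and range:

```r
# Same grid as above: 9 alpha values crossed with 25 log-spaced lambdas
enet_grid <- expand.grid(alpha = seq(0.1, 0.9, by = 0.1),
                         lambda = exp(seq(-6, 0.5, length.out = 25)))
nrow(enet_grid)                          # 9 x 25 = 225 candidate settings
signif(range(enet_grid$lambda), 4)       # exp(-6) to exp(0.5)
lambdas <- sort(unique(enet_grid$lambda))
unique(round(lambdas[-1] / lambdas[-length(lambdas)], 4))  # constant ratio
```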

Now retrain the pairwise interactions model using tuneGrid = enet_grid.

set.seed(12345)
mod_glmnet_2_b <- caret::train(outcome_2 ~ (.)^2,
                               method = "glmnet",
                               preProcess = c("center", "scale"),
                               tuneGrid = enet_grid,
                               metric = roc_metric,
                               trControl = my_ctrl,
                               data = step_2_b_df)

mod_glmnet_2_b
## glmnet 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## Pre-processing: centered (52), scaled (52) 
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results across tuning parameters:
## 
##   alpha  lambda       ROC        Sens         Spec     
##   0.1    0.002478752  0.6860739  0.552284264  0.7192583
##   0.1    0.003249784  0.6858461  0.551675127  0.7190651
##   0.1    0.004260651  0.6857079  0.550659898  0.7182856
##   0.1    0.005585954  0.6856006  0.551269036  0.7192565
##   0.1    0.007323502  0.6857026  0.550659898  0.7188700
##   0.1    0.009601525  0.6858698  0.550659898  0.7192583
##   0.1    0.012588142  0.6859417  0.549644670  0.7208165
##   0.1    0.016503766  0.6856930  0.548832487  0.7235387
##   0.1    0.021637371  0.6856744  0.547005076  0.7272394
##   0.1    0.028367816  0.6854550  0.544162437  0.7301577
##   0.1    0.037191811  0.6852410  0.539086294  0.7322965
##   0.1    0.048760568  0.6848695  0.536852792  0.7363827
##   0.1    0.063927861  0.6839911  0.531573604  0.7431901
##   0.1    0.083813041  0.6822212  0.522639594  0.7494132
##   0.1    0.109883635  0.6796847  0.510253807  0.7573971
##   0.1    0.144063659  0.6760792  0.493604061  0.7651774
##   0.1    0.188875603  0.6709093  0.467614213  0.7819124
##   0.1    0.247626595  0.6626101  0.428223350  0.8033152
##   0.1    0.324652467  0.6489469  0.387411168  0.8371613
##   0.1    0.425637741  0.6270342  0.332182741  0.8721695
##   0.1    0.558035146  0.6035123  0.225786802  0.9087568
##   0.1    0.731615629  0.5664170  0.120203046  0.9733545
##   0.1    0.959189457  0.5475750  0.005685279  0.9992224
##   0.1    1.257551613  0.5000000  0.000000000  1.0000000
##   0.1    1.648721271  0.5000000  0.000000000  1.0000000
##   0.2    0.002478752  0.6857670  0.551269036  0.7180933
##   0.2    0.003249784  0.6855144  0.551675127  0.7177021
##   0.2    0.004260651  0.6856550  0.551065990  0.7188681
##   0.2    0.005585954  0.6857729  0.550862944  0.7194516
##   0.2    0.007323502  0.6859659  0.550050761  0.7190651
##   0.2    0.009601525  0.6860676  0.549644670  0.7208193
##   0.2    0.012588142  0.6861723  0.548629442  0.7217864
##   0.2    0.016503766  0.6863817  0.547614213  0.7264561
##   0.2    0.021637371  0.6864606  0.543147208  0.7289870
##   0.2    0.028367816  0.6866722  0.537664975  0.7338546
##   0.2    0.037191811  0.6864140  0.532994924  0.7402785
##   0.2    0.048760568  0.6855325  0.525076142  0.7482520
##   0.2    0.063927861  0.6839517  0.515329949  0.7552583
##   0.2    0.083813041  0.6817590  0.499492386  0.7659616
##   0.2    0.109883635  0.6779782  0.477563452  0.7838541
##   0.2    0.144063659  0.6719985  0.432893401  0.8134189
##   0.2    0.188875603  0.6590478  0.368527919  0.8571868
##   0.2    0.247626595  0.6178985  0.249746193  0.9017457
##   0.2    0.324652467  0.5735423  0.176243655  0.9387184
##   0.2    0.425637741  0.5646142  0.090964467  0.9819114
##   0.2    0.558035146  0.5000000  0.000000000  1.0000000
##   0.2    0.731615629  0.5000000  0.000000000  1.0000000
##   0.2    0.959189457  0.5000000  0.000000000  1.0000000
##   0.2    1.257551613  0.5000000  0.000000000  1.0000000
##   0.2    1.648721271  0.5000000  0.000000000  1.0000000
##   0.3    0.002478752  0.6855262  0.551269036  0.7171196
##   0.3    0.003249784  0.6857567  0.551472081  0.7186758
##   0.3    0.004260651  0.6858984  0.550456853  0.7190623
##   0.3    0.005585954  0.6860252  0.549238579  0.7184816
##   0.3    0.007323502  0.6861626  0.550050761  0.7194554
##   0.3    0.009601525  0.6863108  0.547208122  0.7190670
##   0.3    0.012588142  0.6867950  0.547208122  0.7227611
##   0.3    0.016503766  0.6871967  0.545380711  0.7272337
##   0.3    0.021637371  0.6875365  0.539492386  0.7315188
##   0.3    0.028367816  0.6873415  0.533604061  0.7377457
##   0.3    0.037191811  0.6866385  0.525685279  0.7443561
##   0.3    0.048760568  0.6853485  0.515126904  0.7550604
##   0.3    0.063927861  0.6831125  0.498274112  0.7679005
##   0.3    0.083813041  0.6793538  0.472487310  0.7896888
##   0.3    0.109883635  0.6727592  0.420507614  0.8262628
##   0.3    0.144063659  0.6502934  0.315939086  0.8807265
##   0.3    0.188875603  0.5947134  0.198375635  0.9223727
##   0.3    0.247626595  0.5647183  0.156954315  0.9538944
##   0.3    0.324652467  0.5372668  0.019492386  0.9963022
##   0.3    0.425637741  0.5000000  0.000000000  1.0000000
##   0.3    0.558035146  0.5000000  0.000000000  1.0000000
##   0.3    0.731615629  0.5000000  0.000000000  1.0000000
##   0.3    0.959189457  0.5000000  0.000000000  1.0000000
##   0.3    1.257551613  0.5000000  0.000000000  1.0000000
##   0.3    1.648721271  0.5000000  0.000000000  1.0000000
##   0.4    0.002478752  0.6857773  0.551269036  0.7180924
##   0.4    0.003249784  0.6859614  0.550456853  0.7194525
##   0.4    0.004260651  0.6860992  0.548832487  0.7184807
##   0.4    0.005585954  0.6862911  0.550050761  0.7179001
##   0.4    0.007323502  0.6864629  0.548020305  0.7186787
##   0.4    0.009601525  0.6867617  0.547005076  0.7208146
##   0.4    0.012588142  0.6873807  0.545989848  0.7247066
##   0.4    0.016503766  0.6879338  0.543147208  0.7282074
##   0.4    0.021637371  0.6879212  0.537055838  0.7324925
##   0.4    0.028367816  0.6874136  0.529746193  0.7412437
##   0.4    0.037191811  0.6863198  0.519187817  0.7513635
##   0.4    0.048760568  0.6842918  0.504162437  0.7673152
##   0.4    0.063927861  0.6811679  0.481015228  0.7883239
##   0.4    0.083813041  0.6745089  0.425380711  0.8223737
##   0.4    0.109883635  0.6492606  0.305989848  0.8840322
##   0.4    0.144063659  0.5844083  0.194314721  0.9272328
##   0.4    0.188875603  0.5643279  0.159796954  0.9507781
##   0.4    0.247626595  0.5223748  0.019086294  0.9963022
##   0.4    0.324652467  0.5000000  0.000000000  1.0000000
##   0.4    0.425637741  0.5000000  0.000000000  1.0000000
##   0.4    0.558035146  0.5000000  0.000000000  1.0000000
##   0.4    0.731615629  0.5000000  0.000000000  1.0000000
##   0.4    0.959189457  0.5000000  0.000000000  1.0000000
##   0.4    1.257551613  0.5000000  0.000000000  1.0000000
##   0.4    1.648721271  0.5000000  0.000000000  1.0000000
##   0.5    0.002478752  0.6859227  0.549847716  0.7190632
##   0.5    0.003249784  0.6861373  0.549441624  0.7184798
##   0.5    0.004260651  0.6863195  0.549644670  0.7188719
##   0.5    0.005585954  0.6865409  0.550862944  0.7186777
##   0.5    0.007323502  0.6866983  0.548020305  0.7188709
##   0.5    0.009601525  0.6872111  0.546802030  0.7219825
##   0.5    0.012588142  0.6878915  0.545177665  0.7241203
##   0.5    0.016503766  0.6882423  0.540507614  0.7293763
##   0.5    0.021637371  0.6878747  0.534619289  0.7357982
##   0.5    0.028367816  0.6871196  0.525482234  0.7457229
##   0.5    0.037191811  0.6857026  0.510456853  0.7601193
##   0.5    0.048760568  0.6828236  0.492791878  0.7809339
##   0.5    0.063927861  0.6771403  0.447512690  0.8103121
##   0.5    0.083813041  0.6584815  0.338071066  0.8717765
##   0.5    0.109883635  0.5928573  0.197969543  0.9256765
##   0.5    0.144063659  0.5652270  0.170761421  0.9428037
##   0.5    0.188875603  0.5475750  0.047309645  0.9881317
##   0.5    0.247626595  0.5000000  0.000000000  1.0000000
##   0.5    0.324652467  0.5000000  0.000000000  1.0000000
##   0.5    0.425637741  0.5000000  0.000000000  1.0000000
##   0.5    0.558035146  0.5000000  0.000000000  1.0000000
##   0.5    0.731615629  0.5000000  0.000000000  1.0000000
##   0.5    0.959189457  0.5000000  0.000000000  1.0000000
##   0.5    1.257551613  0.5000000  0.000000000  1.0000000
##   0.5    1.648721271  0.5000000  0.000000000  1.0000000
##   0.6    0.002478752  0.6860662  0.549644670  0.7192574
##   0.6    0.003249784  0.6862759  0.549644670  0.7177021
##   0.6    0.004260651  0.6865044  0.550862944  0.7182894
##   0.6    0.005585954  0.6867741  0.549238579  0.7180924
##   0.6    0.007323502  0.6870432  0.547208122  0.7190642
##   0.6    0.009601525  0.6877801  0.547208122  0.7217864
##   0.6    0.012588142  0.6881916  0.545380711  0.7258735
##   0.6    0.016503766  0.6882200  0.539492386  0.7317149
##   0.6    0.021637371  0.6876132  0.529746193  0.7392991
##   0.6    0.028367816  0.6866873  0.518984772  0.7519526
##   0.6    0.037191811  0.6846122  0.502335025  0.7715974
##   0.6    0.048760568  0.6804390  0.473502538  0.7984438
##   0.6    0.063927861  0.6690011  0.394923858  0.8433786
##   0.6    0.083813041  0.6176788  0.213197970  0.9153692
##   0.6    0.109883635  0.5677658  0.182131980  0.9334615
##   0.6    0.144063659  0.5646142  0.135025381  0.9646005
##   0.6    0.188875603  0.5000000  0.000000000  1.0000000
##   0.6    0.247626595  0.5000000  0.000000000  1.0000000
##   0.6    0.324652467  0.5000000  0.000000000  1.0000000
##   0.6    0.425637741  0.5000000  0.000000000  1.0000000
##   0.6    0.558035146  0.5000000  0.000000000  1.0000000
##   0.6    0.731615629  0.5000000  0.000000000  1.0000000
##   0.6    0.959189457  0.5000000  0.000000000  1.0000000
##   0.6    1.257551613  0.5000000  0.000000000  1.0000000
##   0.6    1.648721271  0.5000000  0.000000000  1.0000000
##   0.7    0.002478752  0.6862795  0.549238579  0.7184807
##   0.7    0.003249784  0.6863470  0.550050761  0.7180942
##   0.7    0.004260651  0.6866220  0.550050761  0.7186787
##   0.7    0.005585954  0.6869401  0.548426396  0.7179010
##   0.7    0.007323502  0.6874660  0.548020305  0.7200379
##   0.7    0.009601525  0.6881534  0.546395939  0.7227601
##   0.7    0.012588142  0.6884292  0.542944162  0.7278210
##   0.7    0.016503766  0.6881562  0.536649746  0.7330751
##   0.7    0.021637371  0.6874909  0.525076142  0.7439697
##   0.7    0.028367816  0.6861433  0.513908629  0.7589534
##   0.7    0.037191811  0.6830915  0.493197970  0.7832688
##   0.7    0.048760568  0.6762752  0.440609137  0.8147819
##   0.7    0.063927861  0.6435505  0.288121827  0.8848061
##   0.7    0.083813041  0.5809687  0.190253807  0.9291793
##   0.7    0.109883635  0.5640952  0.167106599  0.9451366
##   0.7    0.144063659  0.5140434  0.019289340  0.9955226
##   0.7    0.188875603  0.5000000  0.000000000  1.0000000
##   0.7    0.247626595  0.5000000  0.000000000  1.0000000
##   0.7    0.324652467  0.5000000  0.000000000  1.0000000
##   0.7    0.425637741  0.5000000  0.000000000  1.0000000
##   0.7    0.558035146  0.5000000  0.000000000  1.0000000
##   0.7    0.731615629  0.5000000  0.000000000  1.0000000
##   0.7    0.959189457  0.5000000  0.000000000  1.0000000
##   0.7    1.257551613  0.5000000  0.000000000  1.0000000
##   0.7    1.648721271  0.5000000  0.000000000  1.0000000
##   0.8    0.002478752  0.6863824  0.550050761  0.7182865
##   0.8    0.003249784  0.6865248  0.550659898  0.7173166
##   0.8    0.004260651  0.6868808  0.550050761  0.7180942
##   0.8    0.005585954  0.6871951  0.548223350  0.7169254
##   0.8    0.007323502  0.6878127  0.548020305  0.7196495
##   0.8    0.009601525  0.6884802  0.546802030  0.7227611
##   0.8    0.012588142  0.6884157  0.541928934  0.7293791
##   0.8    0.016503766  0.6880747  0.534010152  0.7363817
##   0.8    0.021637371  0.6871249  0.522030457  0.7490296
##   0.8    0.028367816  0.6852552  0.505786802  0.7673142
##   0.8    0.037191811  0.6811506  0.474923858  0.7941634
##   0.8    0.048760568  0.6680624  0.393299492  0.8425982
##   0.8    0.063927861  0.6078368  0.208121827  0.9200350
##   0.8    0.083813041  0.5664792  0.182741117  0.9330713
##   0.8    0.109883635  0.5646142  0.131167513  0.9657665
##   0.8    0.144063659  0.5000000  0.000000000  1.0000000
##   0.8    0.188875603  0.5000000  0.000000000  1.0000000
##   0.8    0.247626595  0.5000000  0.000000000  1.0000000
##   0.8    0.324652467  0.5000000  0.000000000  1.0000000
##   0.8    0.425637741  0.5000000  0.000000000  1.0000000
##   0.8    0.558035146  0.5000000  0.000000000  1.0000000
##   0.8    0.731615629  0.5000000  0.000000000  1.0000000
##   0.8    0.959189457  0.5000000  0.000000000  1.0000000
##   0.8    1.257551613  0.5000000  0.000000000  1.0000000
##   0.8    1.648721271  0.5000000  0.000000000  1.0000000
##   0.9    0.002478752  0.6865801  0.550050761  0.7180924
##   0.9    0.003249784  0.6868156  0.551065990  0.7177059
##   0.9    0.004260651  0.6870774  0.549847716  0.7169264
##   0.9    0.005585954  0.6875774  0.548832487  0.7177059
##   0.9    0.007323502  0.6881523  0.546598985  0.7200388
##   0.9    0.009601525  0.6885486  0.545989848  0.7243211
##   0.9    0.012588142  0.6883489  0.540507614  0.7319091
##   0.9    0.016503766  0.6879464  0.531167513  0.7410504
##   0.9    0.021637371  0.6866768  0.517360406  0.7535060
##   0.9    0.028367816  0.6839671  0.500304569  0.7770429
##   0.9    0.037191811  0.6777720  0.452182741  0.8071930
##   0.9    0.048760568  0.6490495  0.302741117  0.8772190
##   0.9    0.063927861  0.5831544  0.190862944  0.9287900
##   0.9    0.083813041  0.5641494  0.173604061  0.9408563
##   0.9    0.109883635  0.5255309  0.028832487  0.9918257
##   0.9    0.144063659  0.5000000  0.000000000  1.0000000
##   0.9    0.188875603  0.5000000  0.000000000  1.0000000
##   0.9    0.247626595  0.5000000  0.000000000  1.0000000
##   0.9    0.324652467  0.5000000  0.000000000  1.0000000
##   0.9    0.425637741  0.5000000  0.000000000  1.0000000
##   0.9    0.558035146  0.5000000  0.000000000  1.0000000
##   0.9    0.731615629  0.5000000  0.000000000  1.0000000
##   0.9    0.959189457  0.5000000  0.000000000  1.0000000
##   0.9    1.257551613  0.5000000  0.000000000  1.0000000
##   0.9    1.648721271  0.5000000  0.000000000  1.0000000
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0.9 and lambda = 0.009601525.

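On enet_grid, the best AUC (0.6885, at alpha = 0.9 and lambda ≈ 0.0096) is slightly below the default-grid optimum (0.6920, at alpha = 1 and lambda ≈ 0.0002). That optimum lies outside enet_grid, which stops at alpha = 0.9 and at lambda = exp(-6) ≈ 0.0025.

Print out the non-zero coefficients, specifying the optimal value of lambda identified by resampling.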

coef(mod_glmnet_2_b$finalModel, s = mod_glmnet_2_b$bestTune$lambda) %>% 
  as.matrix() %>% 
  as.data.frame() %>% 
  tibble::rownames_to_column("coef_name") %>% 
  tibble::as_tibble() %>% 
  purrr::set_names(c("coef_name", "coef_value")) %>%
  filter(coef_value != 0)
## # A tibble: 18 x 2
##    coef_name       coef_value
##    <chr>                <dbl>
##  1 (Intercept)        0.0262 
##  2 response_1        -0.112  
##  3 x09                0.00277
##  4 x10               -0.0393 
##  5 xAA2:xBB3          0.0210 
##  6 xBB3:response_1    0.718  
##  7 xBB4:response_1   -0.0894 
##  8 xBB4:x07          -0.102  
##  9 xBB4:x08          -0.00938
## 10 xBB2:x10           0.0319 
## 11 xBB4:x10          -0.162  
## 12 response_1:x07    -0.125  
## 13 response_1:x08    -0.323  
## 14 response_1:x10    -0.117  
## 15 response_1:x11    -0.0366 
## 16 x07:x09            0.166  
## 17 x08:x09            0.0229 
## 18 x10:x11           -0.0179
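Most of the 52 coefficients are exactly zero because of the L1 part of the penalty. The sketch below shows the mechanism through the coordinate-descent update glmnet uses for squared-error loss with standardized inputs. For logistic regression, glmnet applies the same kind of update inside iteratively reweighted least squares, which this sketch omits. The function names and input values are hypothetical; lambda and alpha are set to the values selected above.

```r
# Soft-thresholding operator S(z, gamma) = sign(z) * max(|z| - gamma, 0)
soft_threshold <- function(z, gamma) sign(z) * pmax(abs(z) - gamma, 0)

# Elastic-net coordinate update (squared-error loss, standardized inputs):
# z is the partial-residual correlation for one coefficient.
enet_coord_update <- function(z, lambda, alpha) {
  soft_threshold(z, lambda * alpha) / (1 + lambda * (1 - alpha))
}

z <- c(-0.5, 0.005, 0.3)   # hypothetical partial-residual correlations
updated <- enet_coord_update(z, lambda = 0.0096, alpha = 0.9)
updated
```

Any coordinate whose |z| falls below lambda × alpha ≈ 0.0086 is set exactly to zero (the middle value here). The ridge part only shrinks the surviving coefficients by the factor 1 / (1 + lambda × (1 - alpha)).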

Visualize how the cross-validated AUC varies with the mixing fraction alpha and the penalty lambda (on a log scale, via xTrans = log) for the model trained on enet_grid.

plot(mod_glmnet_2_b, xTrans = log)

All triplet interactions

Now fit an elastic net model on all interactions up to third order between the step_2_b_df inputs, using tuneGrid = enet_grid, and display the optimal tuning parameters.

set.seed(12345)
mod_glmnet_3_b <- caret::train(outcome_2 ~ (.)^3,
                               method = "glmnet",
                               preProcess = c("center", "scale"),
                               tuneGrid = enet_grid,
                               metric = roc_metric,
                               trControl = my_ctrl,
                               data = step_2_b_df)

mod_glmnet_3_b$bestTune
##     alpha     lambda
## 207   0.9 0.01258814

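Check the number of coefficients against the design matrix implied by the formula:

# number of coefficients (excluding the intercept)
mod_glmnet_3_b$coefnames %>% length()
## [1] 150
# model.matrix() columns minus its intercept column, minus the count above; 0 means they agree
(model.matrix(outcome_2 ~ (.)^3, data = step_2_b_df) %>% ncol()) - 1 - (mod_glmnet_3_b$coefnames %>% length())
## [1] 0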

Visualize how the cross-validated AUC varies with the mixing fraction alpha and the penalty lambda (on a log scale, via xTrans = log) for the model trained on enet_grid.

plot(mod_glmnet_3_b, xTrans = log)

Compare

Compare resampling results across the two different models.

glmnet_results <- resamples(list(glmnet_2way = mod_glmnet_2_b,
                                 glmnet_3way = mod_glmnet_3_b))

dotplot(glmnet_results)

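On mean cross-validated AUC, the two models are practically indistinguishable. The best glmnet_2way configuration reaches 0.6885 and the best glmnet_3way configuration reaches 0.6881. That gap is far smaller than the resampling standard deviation (about 0.027 for glmnet_3way). Neither model is clearly better, so we keep glmnet_3way as the representative elastic net model.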

mod_glmnet_3_b$results %>%
  filter(alpha == mod_glmnet_3_b$bestTune$alpha,
         lambda == mod_glmnet_3_b$bestTune$lambda)
##   alpha     lambda       ROC      Sens      Spec      ROCSD     SensSD
## 1   0.9 0.01258814 0.6880523 0.5417259 0.7301577 0.02726367 0.03867319
##       SpecSD
## 1 0.03300155
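Every model here is trained right after set.seed(12345) with the same trainControl, so the models appear to share the same resamples (note the identical sample sizes). Their per-resample AUCs can therefore be compared in pairs. caret provides diff() for resamples objects for this purpose. The base-R sketch below instead uses simulated AUCs (hypothetical values, not results from this report) to illustrate the underlying idea, a paired t-test.

```r
# Simulated per-resample AUCs for two models scored on the same 25 folds.
# In this report they would come from the resamples object (glmnet_results$values).
set.seed(2021)
shared   <- rnorm(25, mean = 0.688, sd = 0.027)  # fold-to-fold variation
auc_2way <- shared + rnorm(25, sd = 0.003)
auc_3way <- shared + rnorm(25, sd = 0.003)

# Pairing removes the shared fold effect, so only the model difference remains
paired_test <- t.test(auc_2way, auc_3way, paired = TRUE)
paired_test
```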

Check the confusion matrix for glmnet_3way.

confusionMatrix.train(mod_glmnet_3_b)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 26.5 13.8
##       Pass 22.4 37.3
##                            
##  Accuracy (average) : 0.638

Partial least squares

As described in Dr Yurko’s Ionosphere Caret Demo, “partial least squares (PLS) models are particularly well suited when the inputs are highly correlated to each other”. Our EDA did not reveal strong correlations among the inputs like those in the Ionosphere dataset. We can still try PLS to see how well it performs on the step_2_b_df inputs, tuning the number of components from 1 to 5.
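Unlike principal components, PLS components are built to covary with the outcome. This base-R sketch computes the first PLS weight vector for a single numeric response on simulated data with two nearly collinear inputs. It is only the first step of a NIPALS-style fit. caret handles the two-class case internally, and we do not reproduce that here. All data and names are hypothetical.

```r
set.seed(42)
n  <- 200
x1 <- rnorm(n)
x2 <- x1 + rnorm(n, sd = 0.1)       # nearly collinear with x1
x3 <- rnorm(n)
X  <- scale(cbind(x1, x2, x3))      # center and scale, as in preProcess
y  <- as.numeric(x1 + rnorm(n) > 0) # hypothetical 0/1 outcome
yc <- y - mean(y)

# First PLS weight vector: the direction of maximal covariance with y
w <- crossprod(X, yc)
w <- w / sqrt(sum(w^2))
scores <- X %*% w                   # first PLS component

# First principal component for comparison (it ignores y entirely)
pc1 <- prcomp(X)$rotation[, 1]
round(cbind(pls_weight = drop(w), pca_loading = pc1), 3)
```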

pls_grid <- expand.grid(ncomp = seq(1, 5, by = 1))

set.seed(12345)
mod_pls <- caret::train(outcome_2 ~ ., 
                        method = "pls",
                        preProcess = c("center", "scale"),
                        tuneGrid = pls_grid,
                        metric = roc_metric,
                        trControl = my_ctrl,
                        data = step_2_b_df)

plot(mod_pls)

Check the confusion matrix.

confusionMatrix.train(mod_pls)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 26.5 17.8
##       Pass 22.4 33.2
##                             
##  Accuracy (average) : 0.5973

Neural network

Now we will try several more complex, non-linear methods, which can capture non-linear relationships between the inputs and the outcome.

Single hidden layer

Fit a single-hidden-layer neural network classifier. We keep nnet's default linout = FALSE, which gives a logistic (sigmoid) output unit suited to predicting class probabilities, so it does not need to be set explicitly. First, use caret's default tuning grid.

set.seed(12345)
mod_nnet <- caret::train(outcome_2 ~ .,
                         method = "nnet",
                         preProcess = c("center", "scale"),
                         trace = FALSE,
                         metric = roc_metric,
                         trControl = my_ctrl,
                         data = step_2_b_df)

mod_nnet
## Neural Network 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## Pre-processing: centered (10), scaled (10) 
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results across tuning parameters:
## 
##   size  decay  ROC        Sens       Spec     
##   1     0e+00  0.6095927  0.3638579  0.8519119
##   1     1e-04  0.6139867  0.3571574  0.8603239
##   1     1e-01  0.6162856  0.3880203  0.8340260
##   3     0e+00  0.7303479  0.6393909  0.7192906
##   3     1e-04  0.7347013  0.6316751  0.7404253
##   3     1e-01  0.7701533  0.6864975  0.7330637
##   5     0e+00  0.7550836  0.6830457  0.7079763
##   5     1e-04  0.7555098  0.6775635  0.7124111
##   5     1e-01  0.7959942  0.7179695  0.7426038
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were size = 5 and decay = 0.1.

With the default grid, the best settings are 5 hidden units (in the single hidden layer) and a weight decay of 0.1. Both are the largest values in the default grid, so the optimum may lie beyond it.
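To make size and decay concrete, here is a minimal base-R sketch of the kind of model nnet fits for a two-class outcome. Inputs feed a single layer of logistic hidden units, and a logistic output unit (linout = FALSE) gives a probability. Weight decay adds decay times the sum of squared weights to the cross-entropy. The dimensions (10 inputs, 5 hidden units) mirror this section, but the weights are random and the function names are hypothetical. We also assume the penalty covers all weights, including biases.

```r
sigmoid <- function(z) 1 / (1 + exp(-z))

# Hypothetical forward pass: p inputs -> `size` logistic hidden units -> 1 logistic output
nnet_forward <- function(X, W1, b1, w2, b2) {
  H <- sigmoid(sweep(X %*% W1, 2, b1, "+"))  # n x size hidden activations
  drop(sigmoid(H %*% w2 + b2))               # n predicted probabilities
}

# Cross-entropy plus weight decay (assumed to penalize every weight)
penalized_loss <- function(y, p, weights, decay) {
  -sum(y * log(p) + (1 - y) * log(1 - p)) + decay * sum(unlist(weights)^2)
}

set.seed(1)
n <- 5; p_in <- 10; size <- 5
X  <- matrix(rnorm(n * p_in), n, p_in)
W1 <- matrix(rnorm(p_in * size, sd = 0.1), p_in, size)
b1 <- rnorm(size, sd = 0.1)
w2 <- rnorm(size, sd = 0.1)
b2 <- 0
probs <- nnet_forward(X, W1, b1, w2, b2)
y <- c(1, 0, 1, 1, 0)
penalized_loss(y, probs, list(W1, b1, w2, b2), decay = 0.1)
```

Larger decay pulls the weights toward zero and the predicted probabilities toward 0.5, which is how it counteracts the extra flexibility of more hidden units.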

Next, we use a refined search grid that extends to more hidden units and larger weight decay values.

nnet_grid <- expand.grid(size = c(2, 4, 6, 8, 10, 12),
                         decay = c(1e-4, 0.1, 0.5))

Fit the nnet model on this refined grid. The grid search might take a few minutes.

set.seed(12345)
mod_nnet_b <- caret::train(outcome_2 ~ .,
                           method = "nnet",
                           tuneGrid = nnet_grid,
                           preProcess = c("center", "scale"),
                           trace = FALSE,
                           metric = roc_metric,
                           trControl = my_ctrl,
                           data = step_2_b_df)

mod_nnet_b
## Neural Network 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## Pre-processing: centered (10), scaled (10) 
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results across tuning parameters:
## 
##   size  decay  ROC        Sens       Spec     
##    2    1e-04  0.6947106  0.6150254  0.6879394
##    2    1e-01  0.7218871  0.6475127  0.7031077
##    2    5e-01  0.7277722  0.6450761  0.7075757
##    4    1e-04  0.7439569  0.6678173  0.7054861
##    4    1e-01  0.7879585  0.7053807  0.7390973
##    4    5e-01  0.7863310  0.6944162  0.7439517
##    6    1e-04  0.7546195  0.6751269  0.7146067
##    6    1e-01  0.7940649  0.7076142  0.7478617
##    6    5e-01  0.7977233  0.7059898  0.7476751
##    8    1e-04  0.7540962  0.6871066  0.7079678
##    8    1e-01  0.7910635  0.7059898  0.7423955
##    8    5e-01  0.7990779  0.7047716  0.7552555
##   10    1e-04  0.7470351  0.6777665  0.6993966
##   10    1e-01  0.7923378  0.7147208  0.7414303
##   10    5e-01  0.7959266  0.7065990  0.7494123
##   12    1e-04  0.7444634  0.6747208  0.6922216
##   12    1e-01  0.7850273  0.7135025  0.7361904
##   12    5e-01  0.7909739  0.7078173  0.7418215
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were size = 8 and decay = 0.5.

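The refined search selects 8 hidden units with weight decay 0.5 (AUC ≈ 0.799). However, sizes 6 through 10 with decay 0.1 or 0.5 all reach AUCs between 0.791 and 0.799. Decay 0.5 is again the largest value in the grid.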

Plot the AUC against the number of hidden units, with one line per weight decay value.

plot(mod_nnet_b)

Check the confusion matrix based on the cross-validation results.

confusionMatrix.train(mod_nnet_b)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 34.5 12.5
##       Pass 14.4 38.6
##                             
##  Accuracy (average) : 0.7306

Highest accuracy so far!

Random forest

Random forests have become a popular learning algorithm because they offer good predictive performance with “relatively little hyperparameter tuning”. We use method = "rf", which lets us call caret::train as we have for all other models. By default, a random forest grows 500 trees, each on a bootstrap sample of the training data (bagging). At each split, it considers only mtry randomly selected features as split candidates.

We use a custom grid of mtry values. The data have 8 predictors, which become 10 model-matrix columns once the formula interface dummy-encodes xA and xB, so we try mtry = seq(2, 8, by = 1). The code chunk below might take a few minutes to run.

rf_grid <- expand.grid(mtry = seq(2, 8, by = 1))

set.seed(12345)
mod_rf <- caret::train(outcome_2 ~ .,
                       method = "rf",
                       importance = TRUE,
                       tuneGrid = rf_grid,
                       trControl = my_ctrl,
                       metric = roc_metric,
                       data = step_2_b_df)

mod_rf
## Random Forest 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results across tuning parameters:
## 
##   mtry  ROC        Sens       Spec     
##   2     0.7935896  0.7181726  0.7278125
##   3     0.8004950  0.7390863  0.7173014
##   4     0.8007297  0.7472081  0.7141814
##   5     0.7989170  0.7478173  0.7167104
##   6     0.7989985  0.7449746  0.7157433
##   7     0.7970867  0.7425381  0.7132077
##   8     0.7969455  0.7409137  0.7182704
## 
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 4.
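The two sources of randomness in a random forest can be sketched in base R. Bagging gives each tree a bootstrap sample of the rows, which contains about 63.2% of the distinct rows. Feature subsampling restricts each split to mtry randomly chosen columns (4 here). The column names are assumptions taken from the dummy-coded glmnet coefficient names earlier in this report, and this is an illustration, not the randomForest implementation.

```r
# Assumed model-matrix columns (xA and xB dummy-coded, names as in the glmnet coefficients)
predictor_cols <- c("xAA2", "xBB2", "xBB3", "xBB4",
                    "response_1", "x07", "x08", "x09", "x10", "x11")
n <- 2013

set.seed(12345)
# Bagging: each tree sees n rows drawn with replacement
boot_idx <- sample.int(n, n, replace = TRUE)
frac_in_bag <- mean(seq_len(n) %in% boot_idx)   # close to 1 - exp(-1) = 0.632

# Feature subsampling: one split considers only mtry randomly chosen columns
mtry <- 4
split_candidates <- sample(predictor_cols, mtry)

frac_in_bag
split_candidates
floor(sqrt(length(predictor_cols)))   # randomForest's default mtry for classification
```

The rows left out of a given bootstrap sample (about 36.8%) are that tree's out-of-bag rows. The selected mtry = 4 is slightly above randomForest's default of floor(sqrt(10)) = 3.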

Check the confusion matrix based on the cross-validation results.

confusionMatrix.train(mod_rf)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 36.6 14.6
##       Pass 12.4 36.5
##                             
##  Accuracy (average) : 0.7304

Gradient boosted tree

Gradient boosting machines (GBMs) build shallow trees in sequence, with each tree “learning and improving on the previous one”, whereas random forests build deep trees independently of one another. With careful tuning, ensembles of boosted shallow trees are often among the most accurate predictive models.
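The sequential idea can be sketched in a few lines of base R. This is plain first-order gradient boosting with log-loss and depth-1 stumps on one simulated input. It is not XGBoost's algorithm, which among other differences uses second-order gradient information and regularized leaf weights. All function names and data are hypothetical; eta plays the same role as the eta tuning parameter in the xgbTree output below.

```r
# Fit a depth-1 regression stump to residuals r on a single numeric input x
fit_stump <- function(x, r) {
  cuts <- sort(unique(x))
  cuts <- (head(cuts, -1) + tail(cuts, -1)) / 2   # midpoints between distinct values
  best <- NULL
  for (cut in cuts) {
    left <- x <= cut
    pred <- ifelse(left, mean(r[left]), mean(r[!left]))
    sse  <- sum((r - pred)^2)
    if (is.null(best) || sse < best$sse) {
      best <- list(cut = cut, left = mean(r[left]), right = mean(r[!left]), sse = sse)
    }
  }
  best
}
predict_stump <- function(s, x) ifelse(x <= s$cut, s$left, s$right)

# Each round fits a stump to the negative gradient of log-loss, y - p
boost_logistic <- function(x, y, n_rounds = 50, eta = 0.3) {
  f <- rep(qlogis(mean(y)), length(y))   # start at the base-rate log-odds
  stumps <- vector("list", n_rounds)
  for (m in seq_len(n_rounds)) {
    r <- y - plogis(f)
    stumps[[m]] <- fit_stump(x, r)
    f <- f + eta * predict_stump(stumps[[m]], x)   # shrunken step
  }
  list(f = f, stumps = stumps)
}

set.seed(7)
x <- runif(300)
y <- rbinom(300, 1, plogis(4 * (x - 0.5)))
fit <- boost_logistic(x, y)
log_loss <- function(y, p) -mean(y * log(p) + (1 - y) * log(1 - p))
losses <- c(start = log_loss(y, rep(mean(y), 300)),
            boosted = log_loss(y, plogis(fit$f)))
losses
```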

We use the XGBoost implementation by setting method = "xgbTree" in caret::train, with caret's default tuning grid.

set.seed(12345)
mod_gbm <- caret::train(outcome_2 ~ .,
                        method = "xgbTree",
                        verbose = FALSE,
                        metric = roc_metric,
                        trControl = my_ctrl,
                        data = step_2_b_df)

mod_gbm
## eXtreme Gradient Boosting 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results across tuning parameters:
## 
##   eta  max_depth  colsample_bytree  subsample  nrounds  ROC        Sens     
##   0.3  1          0.6               0.50        50      0.7779795  0.6489340
##   0.3  1          0.6               0.50       100      0.7770389  0.6696447
##   0.3  1          0.6               0.50       150      0.7752452  0.6682234
##   0.3  1          0.6               0.75        50      0.7786550  0.6432487
##   0.3  1          0.6               0.75       100      0.7798483  0.6598985
##   0.3  1          0.6               0.75       150      0.7785176  0.6674112
##   0.3  1          0.6               1.00        50      0.7786337  0.6337056
##   0.3  1          0.6               1.00       100      0.7808250  0.6511675
##   0.3  1          0.6               1.00       150      0.7801575  0.6578680
##   0.3  1          0.8               0.50        50      0.7766658  0.6554315
##   0.3  1          0.8               0.50       100      0.7776187  0.6684264
##   0.3  1          0.8               0.50       150      0.7730245  0.6647716
##   0.3  1          0.8               0.75        50      0.7793421  0.6507614
##   0.3  1          0.8               0.75       100      0.7812941  0.6676142
##   0.3  1          0.8               0.75       150      0.7795003  0.6700508
##   0.3  1          0.8               1.00        50      0.7790067  0.6324873
##   0.3  1          0.8               1.00       100      0.7812061  0.6521827
##   0.3  1          0.8               1.00       150      0.7805229  0.6558376
##   0.3  2          0.6               0.50        50      0.7827854  0.6783756
##   0.3  2          0.6               0.50       100      0.7789977  0.6785787
##   0.3  2          0.6               0.50       150      0.7781678  0.6871066
##   0.3  2          0.6               0.75        50      0.7910640  0.6877157
##   0.3  2          0.6               0.75       100      0.7882538  0.6907614
##   0.3  2          0.6               0.75       150      0.7840055  0.6901523
##   0.3  2          0.6               1.00        50      0.7929283  0.6846701
##   0.3  2          0.6               1.00       100      0.7914720  0.6834518
##   0.3  2          0.6               1.00       150      0.7908281  0.6871066
##   0.3  2          0.8               0.50        50      0.7914532  0.6956345
##   0.3  2          0.8               0.50       100      0.7867046  0.6919797
##   0.3  2          0.8               0.50       150      0.7822556  0.6864975
##   0.3  2          0.8               0.75        50      0.7957215  0.6913706
##   0.3  2          0.8               0.75       100      0.7903248  0.6867005
##   0.3  2          0.8               0.75       150      0.7862135  0.6893401
##   0.3  2          0.8               1.00        50      0.7983836  0.6911675
##   0.3  2          0.8               1.00       100      0.7957468  0.6929949
##   0.3  2          0.8               1.00       150      0.7922797  0.6897462
##   0.3  3          0.6               0.50        50      0.7825930  0.6895431
##   0.3  3          0.6               0.50       100      0.7757822  0.6938071
##   0.3  3          0.6               0.50       150      0.7680179  0.6834518
##   0.3  3          0.6               0.75        50      0.7933062  0.7003046
##   0.3  3          0.6               0.75       100      0.7869022  0.7039594
##   0.3  3          0.6               0.75       150      0.7782612  0.6917766
##   0.3  3          0.6               1.00        50      0.7951543  0.6927919
##   0.3  3          0.6               1.00       100      0.7893617  0.6948223
##   0.3  3          0.6               1.00       150      0.7832788  0.6954315
##   0.3  3          0.8               0.50        50      0.7867280  0.6927919
##   0.3  3          0.8               0.50       100      0.7786522  0.6964467
##   0.3  3          0.8               0.50       150      0.7724253  0.6913706
##   0.3  3          0.8               0.75        50      0.7925116  0.7023350
##   0.3  3          0.8               0.75       100      0.7830211  0.7005076
##   0.3  3          0.8               0.75       150      0.7757153  0.6899492
##   0.3  3          0.8               1.00        50      0.7962037  0.6988832
##   0.3  3          0.8               1.00       100      0.7913282  0.6996954
##   0.3  3          0.8               1.00       150      0.7845493  0.6976650
##   0.4  1          0.6               0.50        50      0.7729907  0.6590863
##   0.4  1          0.6               0.50       100      0.7740366  0.6633503
##   0.4  1          0.6               0.50       150      0.7708851  0.6684264
##   0.4  1          0.6               0.75        50      0.7786480  0.6584772
##   0.4  1          0.6               0.75       100      0.7771192  0.6672081
##   0.4  1          0.6               0.75       150      0.7752912  0.6651777
##   0.4  1          0.6               1.00        50      0.7796603  0.6489340
##   0.4  1          0.6               1.00       100      0.7799885  0.6586802
##   0.4  1          0.6               1.00       150      0.7783323  0.6578680
##   0.4  1          0.8               0.50        50      0.7745897  0.6609137
##   0.4  1          0.8               0.50       100      0.7712045  0.6716751
##   0.4  1          0.8               0.50       150      0.7705070  0.6653807
##   0.4  1          0.8               0.75        50      0.7769092  0.6609137
##   0.4  1          0.8               0.75       100      0.7775787  0.6686294
##   0.4  1          0.8               0.75       150      0.7757975  0.6676142
##   0.4  1          0.8               1.00        50      0.7790819  0.6485279
##   0.4  1          0.8               1.00       100      0.7798347  0.6558376
##   0.4  1          0.8               1.00       150      0.7779126  0.6570558
##   0.4  2          0.6               0.50        50      0.7790184  0.6883249
##   0.4  2          0.6               0.50       100      0.7764061  0.6840609
##   0.4  2          0.6               0.50       150      0.7772341  0.6885279
##   0.4  2          0.6               0.75        50      0.7876783  0.6881218
##   0.4  2          0.6               0.75       100      0.7818489  0.6875127
##   0.4  2          0.6               0.75       150      0.7782556  0.6838579
##   0.4  2          0.6               1.00        50      0.7883856  0.6862944
##   0.4  2          0.6               1.00       100      0.7860620  0.6814213
##   0.4  2          0.6               1.00       150      0.7814446  0.6824365
##   0.4  2          0.8               0.50        50      0.7821316  0.6891371
##   0.4  2          0.8               0.50       100      0.7796126  0.6927919
##   0.4  2          0.8               0.50       150      0.7732114  0.6875127
##   0.4  2          0.8               0.75        50      0.7925472  0.6903553
##   0.4  2          0.8               0.75       100      0.7886042  0.6942132
##   0.4  2          0.8               0.75       150      0.7829557  0.6891371
##   0.4  2          0.8               1.00        50      0.7964976  0.6925888
##   0.4  2          0.8               1.00       100      0.7906264  0.6871066
##   0.4  2          0.8               1.00       150      0.7869048  0.6864975
##   0.4  3          0.6               0.50        50      0.7704341  0.6830457
##   0.4  3          0.6               0.50       100      0.7607032  0.6755330
##   0.4  3          0.6               0.50       150      0.7551643  0.6706599
##   0.4  3          0.6               0.75        50      0.7858409  0.6956345
##   0.4  3          0.6               0.75       100      0.7744090  0.6879188
##   0.4  3          0.6               0.75       150      0.7698534  0.6848731
##   0.4  3          0.6               1.00        50      0.7862716  0.6905584
##   0.4  3          0.6               1.00       100      0.7816847  0.6899492
##   0.4  3          0.6               1.00       150      0.7756945  0.6842640
##   0.4  3          0.8               0.50        50      0.7765383  0.6905584
##   0.4  3          0.8               0.50       100      0.7693592  0.6905584
##   0.4  3          0.8               0.50       150      0.7627729  0.6875127
##   0.4  3          0.8               0.75        50      0.7849647  0.6958376
##   0.4  3          0.8               0.75       100      0.7747129  0.6897462
##   0.4  3          0.8               0.75       150      0.7693298  0.6862944
##   0.4  3          0.8               1.00        50      0.7921649  0.7005076
##   0.4  3          0.8               1.00       100      0.7830085  0.6964467
##   0.4  3          0.8               1.00       150      0.7772760  0.6893401
##   Spec     
##   0.7706275
##   0.7626474
##   0.7564187
##   0.7729623
##   0.7661520
##   0.7608894
##   0.7723808
##   0.7710168
##   0.7725664
##   0.7610940
##   0.7575932
##   0.7531063
##   0.7644035
##   0.7634355
##   0.7603012
##   0.7731632
##   0.7708226
##   0.7723741
##   0.7501823
##   0.7468766
##   0.7427999
##   0.7556448
##   0.7461122
##   0.7482463
##   0.7719896
##   0.7682737
##   0.7655638
##   0.7550585
##   0.7494037
##   0.7466929
##   0.7597092
##   0.7581653
##   0.7544655
##   0.7708065
##   0.7665318
##   0.7634165
##   0.7505944
##   0.7324935
##   0.7247028
##   0.7575847
##   0.7462884
##   0.7328695
##   0.7591362
##   0.7538859
##   0.7462913
##   0.7419986
##   0.7404329
##   0.7279877
##   0.7507668
##   0.7420194
##   0.7396789
##   0.7571868
##   0.7552394
##   0.7480417
##   0.7571906
##   0.7589486
##   0.7540914
##   0.7628473
##   0.7632261
##   0.7604954
##   0.7657665
##   0.7712100
##   0.7677026
##   0.7564234
##   0.7523448
##   0.7544845
##   0.7614795
##   0.7605030
##   0.7591513
##   0.7659569
##   0.7706275
##   0.7676997
##   0.7465025
##   0.7381302
##   0.7414293
##   0.7591456
##   0.7474610
##   0.7451300
##   0.7577713
##   0.7587582
##   0.7472849
##   0.7548567
##   0.7451309
##   0.7387175
##   0.7620829
##   0.7506019
##   0.7422335
##   0.7673029
##   0.7641970
##   0.7567956
##   0.7344343
##   0.7268596
##   0.7176926
##   0.7492058
##   0.7338309
##   0.7254691
##   0.7470888
##   0.7416197
##   0.7353853
##   0.7435785
##   0.7252721
##   0.7211954
##   0.7461075
##   0.7410523
##   0.7293886
##   0.7562273
##   0.7494132
##   0.7361847
## 
## Tuning parameter 'gamma' was held constant at a value of 0
## Tuning
##  parameter 'min_child_weight' was held constant at a value of 1
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 50, max_depth = 2, eta
##  = 0.3, gamma = 0, colsample_bytree = 0.8, min_child_weight = 1 and subsample
##  = 1.

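The selected model uses 50 boosting rounds (nrounds), a maximum tree depth (max_depth) of 2, and a learning rate (eta) of 0.3. It also uses a column subsampling ratio per tree (colsample_bytree) of 0.8 and a row subsampling ratio (subsample) of 1, meaning every tree is grown on all training rows of the resample.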
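If a finer search were wanted, one option is to pass a custom grid centred on the selected values through the tuneGrid argument, as done for the SVM below. The grid values in this sketch are illustrative assumptions, not results from this analysis. caret expects one column for each of the seven xgbTree tuning parameters listed in the output above.

```r
# Hypothetical refined grid for method = "xgbTree", centred on the selected
# values (nrounds = 50, max_depth = 2, eta = 0.3, colsample_bytree = 0.8,
# subsample = 1); gamma and min_child_weight are held at their defaults.
xgb_grid <- expand.grid(nrounds = c(25, 50, 75),
                        max_depth = c(2, 3),
                        eta = c(0.1, 0.2, 0.3),
                        gamma = 0,
                        colsample_bytree = c(0.7, 0.8, 0.9),
                        min_child_weight = 1,
                        subsample = c(0.9, 1))

nrow(xgb_grid)  # 3 * 2 * 3 * 1 * 3 * 1 * 2 = 108 candidate models
```

It would be supplied as tuneGrid = xgb_grid in the caret::train call above. With 5-fold cross-validation repeated 5 times, this means 108 × 25 = 2700 model fits, so the run time grows accordingly.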

Check confusion matrix.

confusionMatrix.train(mod_gbm)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 33.8 11.7
##       Pass 15.1 39.4
##                             
##  Accuracy (average) : 0.7318

Support Vector Machine

The motivation for fitting a Support Vector Machine (SVM) is that SVMs offer several advantages over other methods, as discussed in the Hands-On Machine Learning with R book.

The basic idea of an SVM is to separate the classes with a hyperplane. Using a “kernel trick”, as Dr Yurko puts it, the SVM implicitly maps the inputs from the original space into a new feature space and then looks for a linear boundary separating the classes there. A linear boundary in the feature space corresponds to a nonlinear boundary in the original inputs.
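The kernel trick only ever needs similarities between pairs of observations, never the mapped features themselves. The base R sketch below shows the radial basis function kernel in kernlab's parameterisation, k(x, y) = exp(-sigma * ||x - y||^2), which is the sigma tuned by svmRadial below. The function names rbf_kernel and rbf_gram are hypothetical helpers, not kernlab's API. A larger sigma makes similarity decay faster with distance, which allows a more flexible boundary.

```r
# Radial basis function kernel, kernlab parameterisation:
# k(x, y) = exp(-sigma * ||x - y||^2)
rbf_kernel <- function(x, y, sigma) {
  exp(-sigma * sum((x - y)^2))
}

# Gram (kernel) matrix between all rows of X
rbf_gram <- function(X, sigma) {
  d2 <- as.matrix(dist(X))^2   # squared Euclidean distances between rows
  exp(-sigma * d2)
}

X <- matrix(c(0, 0,
              1, 0,
              0, 2), ncol = 2, byrow = TRUE)
K <- rbf_gram(X, sigma = 0.5)
round(K, 3)
```

Every point has similarity 1 with itself, and similarity shrinks toward 0 as points move apart. The SVM's decision function is a weighted sum of such kernel values against the support vectors.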

First, load the kernlab library.

library(kernlab)
## 
## Attaching package: 'kernlab'
## The following object is masked from 'package:purrr':
## 
##     cross
## The following object is masked from 'package:ggplot2':
## 
##     alpha

Following the common rule of thumb, we use a radial basis function kernel by setting method = "svmRadial" in the caret::train call.

First, check which tuning parameters caret exposes for this method:

caret::getModelInfo("svmRadial")$svmRadial$parameters
##   parameter   class label
## 1     sigma numeric Sigma
## 2         C numeric  Cost

Now fit the model.

set.seed(12345)
mod_svm <- caret::train(outcome_2 ~ .,
                        method = "svmRadial",
                        preProcess = c("center", "scale"),
                        metric = roc_metric,
                        trControl = my_ctrl,
                        data = step_2_b_df)

mod_svm
## Support Vector Machines with Radial Basis Function Kernel 
## 
## 2013 samples
##    8 predictor
##    2 classes: 'Fail', 'Pass' 
## 
## Pre-processing: centered (10), scaled (10) 
## Resampling: Cross-Validated (5 fold, repeated 5 times) 
## Summary of sample sizes: 1610, 1611, 1611, 1610, 1610, 1611, ... 
## Resampling results across tuning parameters:
## 
##   C     ROC        Sens       Spec     
##   0.25  0.7593338  0.6501523  0.7348245
##   0.50  0.7699159  0.6605076  0.7441639
##   1.00  0.7760909  0.6665990  0.7509751
## 
## Tuning parameter 'sigma' was held constant at a value of 0.07027841
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were sigma = 0.07027841 and C = 1.

Plot the results to see cross-validated ROC scores against different cost values.

plot(mod_svm)

Next, run a refined custom grid search that scales the selected sigma by factors from 0.25 to 2 and tries a wider range of cost values.

svm_grid <- expand.grid(sigma = mod_svm$bestTune$sigma * c(0.25, 0.5, 1, 2),
                        C = c(0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0))

set.seed(12345)
mod_svm_b <- caret::train(outcome_2 ~ .,
                          method = "svmRadial",
                          preProcess = c("center", "scale"),
                          tuneGrid = svm_grid,
                          metric = roc_metric,
                          trControl = my_ctrl,
                          data = step_2_b_df)

mod_svm_b$bestTune
##       sigma  C
## 6 0.0175696 16

Plot results.

ggplot(mod_svm_b) + theme_bw()

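In the plot, the red line (sigma = 0.0175696, one quarter of the initial estimate) reaches the highest cross-validated AUC at C = 16. Since this is the smallest sigma in the grid, the optimum may lie at even smaller values.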

Check confusion matrix.

confusionMatrix.train(mod_svm_b)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 30.9 11.1
##       Pass 18.1 39.9
##                             
##  Accuracy (average) : 0.7081

Multivariate Adaptive Regression Splines

The motivation for fitting a Multivariate Adaptive Regression Splines (MARS) model is to capture further nonlinear relationships between the inputs and the outcome. MARS extends linear models by automatically searching for the nonlinearities and interactions in the data that help maximize predictive accuracy.

First, load the earth library for MARS modeling.

library(earth)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
## Loading required package: TeachingDemos

Hands-On Machine Learning with R describes the inner workings of MARS. Instead of requiring us to specify polynomial or natural spline terms explicitly, MARS captures nonlinear relationships by searching for cutpoints (knots) and fitting piecewise linear hinge functions around them. The procedure evaluates each data point of each input as a candidate knot and fits a linear regression model with the resulting candidate feature(s).
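The building block is a pair of hinge functions around a knot c: max(0, x - c) and max(0, c - x). With the knot fixed, the pair turns a change in slope into an ordinary linear regression. MARS automates the search over knots, adds interactions, and prunes terms in a backward pass. The base R sketch below uses simulated data and a knot assumed to be known (c = 4). hinge_right and hinge_left are hypothetical helper names, not earth functions.

```r
# Hinge (ReLU-like) basis functions around a knot
hinge_right <- function(x, knot) pmax(0, x - knot)  # max(0, x - c)
hinge_left  <- function(x, knot) pmax(0, knot - x)  # max(0, c - x)

# Simulated piecewise-linear truth: slope 0.5 below x = 4, slope 2.5 above
set.seed(42)
x <- runif(200, 0, 10)
y <- 1 + 0.5 * x + 2 * hinge_right(x, 4) + rnorm(200, sd = 0.3)

# With the knot fixed, fitting the hinge pair is plain least squares
fit <- lm(y ~ hinge_right(x, 4) + hinge_left(x, 4))
round(coef(fit), 2)
```

Rewriting the truth in terms of the hinge pair gives y = 3 + 2.5 max(0, x - 4) - 0.5 max(0, 4 - x). The estimated coefficients should land close to 3, 2.5 and -0.5.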

To tune this procedure, we define a grid over two tuning parameters and pass it to the caret::train call: the maximum degree of interactions, degree, and the number of terms retained in the final model, nprune. Since there is rarely any benefit in considering interactions beyond third order, we use degree = 1:3. For nprune, we start with 10 evenly spaced values and intend to zoom in later if the approximate optimum suggests it is worthwhile.

mars_grid <- expand.grid(degree = 1:3,
                         nprune = seq(2, 100, length.out = 10) %>% floor())

mars_grid %>% head()
##   degree nprune
## 1      1      2
## 2      2      2
## 3      3      2
## 4      1     12
## 5      2     12
## 6      3     12
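head() shows only the first six rows. The full grid crosses 3 interaction degrees with 10 floored nprune values, giving 30 candidate models. The snippet below lists the nprune candidates explicitly.

```r
# Candidate nprune values used in mars_grid
nprune_values <- floor(seq(2, 100, length.out = 10))
nprune_values
# 2 12 23 34 45 56 67 78 89 100

# Total number of (degree, nprune) combinations
nrow(expand.grid(degree = 1:3, nprune = nprune_values))
```

The spacing of about 11 terms between candidates is coarse. This is why the plan is to zoom in only if the cross-validated ROC suggests it.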

We will use caret::train, as in the previous sections. The grid search might take a few minutes.

set.seed(12345)
mod_mars <- caret::train(outcome_2 ~ .,
                         method = "earth",
                         tuneGrid = mars_grid,
                         metric = roc_metric,
                         trControl = my_ctrl,
                         data = step_2_b_df)

mod_mars$bestTune
##    nprune degree
## 13     23      2

Plot the model.

ggplot(mod_mars)

Because the cross-validated ROC levels off beyond roughly nprune = 23 terms, there is no need for a finer tuning grid.

mod_mars$resample %>% summary()
##       ROC              Sens             Spec          Resample        
##  Min.   :0.7679   Min.   :0.6345   Min.   :0.7039   Length:25         
##  1st Qu.:0.7803   1st Qu.:0.6802   1st Qu.:0.7379   Class :character  
##  Median :0.7971   Median :0.6954   Median :0.7524   Mode  :character  
##  Mean   :0.7972   Mean   :0.6991   Mean   :0.7576                     
##  3rd Qu.:0.8106   3rd Qu.:0.7208   3rd Qu.:0.7756                     
##  Max.   :0.8283   Max.   :0.7513   Max.   :0.8301

Check confusion matrix.

confusionMatrix.train(mod_mars)
## Cross-Validated (5 fold, repeated 5 times) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction Fail Pass
##       Fail 34.2 12.4
##       Pass 14.7 38.7
##                            
##  Accuracy (average) : 0.729

Comparing methods

ROC comparison

Now that we have fit all of the models, we can compare the cross-validation hold-out set performance metrics. We first compile all of the “resample” results together with the resamples() function.

iii_results <- resamples(list(glm = mod_glm,
                              glmnet_2way = mod_glmnet_2_b,
                              glmnet_3way = mod_glmnet_3_b,
                              nnet = mod_nnet_b,
                              rf = mod_rf,
                              xgb = mod_gbm,
                              svm = mod_svm_b,
                              mars = mod_mars,
                              pls = mod_pls))

Then we visually compare the performance metrics.

dotplot(iii_results)

dotplot(iii_results, metric = "ROC")

dotplot(iii_results, metric = "Sens")

dotplot(iii_results, metric = "Spec")

Based on AUC, rf is the best model, although rf, nnet, xgb and mars perform similarly. rf also has the highest sensitivity but comparatively low specificity, whereas mars and xgb are more consistent across all three metrics.

Next, assemble the ROC curves for comparison. First, keep each model's cross-validation hold-out predictions for its best tuning parameters and combine them into a single data frame.

# Keep each model's hold-out predictions for its best tuning parameters
# (tibble::as_tibble() replaces tbl_df(), which is deprecated in dplyr 1.0.0)
cv_pred_results <- mod_glm$pred %>% tibble::as_tibble() %>%
  filter(parameter == mod_glm$bestTune$parameter) %>%
  select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
  mutate(model_name = "glm") %>%
  bind_rows(mod_glmnet_2_b$pred %>% tibble::as_tibble() %>%
              filter(alpha %in% mod_glmnet_2_b$bestTune$alpha,
                     lambda %in% mod_glmnet_2_b$bestTune$lambda) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "glmnet_2_b")) %>%
  bind_rows(mod_glmnet_3_b$pred %>% tibble::as_tibble() %>%
              filter(alpha %in% mod_glmnet_3_b$bestTune$alpha,
                     lambda %in% mod_glmnet_3_b$bestTune$lambda) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "glmnet_3_b")) %>%
  bind_rows(mod_pls$pred %>% tibble::as_tibble() %>%
              filter(ncomp %in% mod_pls$bestTune$ncomp) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "pls")) %>%
  bind_rows(mod_nnet_b$pred %>% tibble::as_tibble() %>%
              filter(size == mod_nnet_b$bestTune$size,
                     decay == mod_nnet_b$bestTune$decay) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "nnet")) %>%
  bind_rows(mod_rf$pred %>% tibble::as_tibble() %>%
              filter(mtry == mod_rf$bestTune$mtry) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "rf")) %>%
  bind_rows(mod_gbm$pred %>% tibble::as_tibble() %>%
              filter(nrounds == mod_gbm$bestTune$nrounds,
                     max_depth == mod_gbm$bestTune$max_depth,
                     eta %in% mod_gbm$bestTune$eta,
                     gamma %in% mod_gbm$bestTune$gamma,
                     colsample_bytree %in% mod_gbm$bestTune$colsample_bytree,
                     min_child_weight == mod_gbm$bestTune$min_child_weight,
                     subsample == mod_gbm$bestTune$subsample) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "xgb")) %>%
  bind_rows(mod_svm_b$pred %>% tibble::as_tibble() %>%
              filter(sigma %in% mod_svm_b$bestTune$sigma,
                     C %in% mod_svm_b$bestTune$C) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "svm")) %>%
  bind_rows(mod_mars$pred %>% tibble::as_tibble() %>%
              filter(nprune == mod_mars$bestTune$nprune,
                     degree == mod_mars$bestTune$degree) %>%
              select(pred, obs, Fail, Pass, rowIndex, Resample) %>%
              mutate(model_name = "mars"))
### nrounds = 50, max_depth = 2, eta = 0.3, gamma = 0, 
### colsample_bytree = 0.8, min_child_weight = 1 and subsample = 1.

Load plotROC to plot ROC curves.

library(plotROC)

Visualize the ROC curve for each resample, faceted by model.

cv_pred_results %>% 
  ggplot(mapping = aes(m = Fail,
                       d = ifelse(obs == "Fail",
                                  1, 
                                  0))) +
  geom_roc(cutoffs.at = 0.5,
           mapping = aes(color = Resample)) +
  geom_roc(cutoffs.at = 0.5) +
  coord_equal() +
  facet_wrap(~model_name) +
  style_roc()

The black curve in each panel is computed from the hold-out predictions pooled across all folds and repeats. Examine the rf model more closely, since it has the highest cross-validated AUC.

cv_pred_results %>% 
  filter(model_name == "rf") %>% 
  ggplot(mapping = aes(m = Fail,
                       d = ifelse(obs == "Fail",
                                  1, 
                                  0))) +
  geom_roc(cutoffs.at = 0.5,
           mapping = aes(color = Resample)) +
  geom_roc(cutoffs.at = 0.5) +
  coord_equal() +
  facet_wrap(~model_name) +
  style_roc()

Compare the pooled cross-validation ROC curves across models.

cv_pred_results %>% 
  ggplot(mapping = aes(m = Fail,
                       d = ifelse(obs == "Fail",
                                  1, 
                                  0),
                       color = model_name)) +
  geom_roc(cutoffs.at = 0.5) +
  coord_equal() +
  style_roc() +
  ggthemes::scale_color_calc()

As we expected, rf, nnet, xgb and mars all perform comparatively well.
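For reference, the area under each of these curves can be computed directly from hold-out predictions with the rank-based (Mann-Whitney) formula. AUC is the probability that a randomly chosen Fail row receives a higher predicted P(Fail) than a randomly chosen Pass row, with ties counted as one half. This is a minimal base R sketch, and auc_rank is a hypothetical helper name.

```r
# AUC via the Mann-Whitney U statistic
auc_rank <- function(score, is_event) {
  is_event <- as.logical(is_event)
  n_pos <- sum(is_event)
  n_neg <- sum(!is_event)
  r <- rank(score)  # ties receive their average rank
  (sum(r[is_event]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

auc_rank(c(0.9, 0.8, 0.3, 0.1), c(TRUE, TRUE, FALSE, FALSE))  # perfect ranking
auc_rank(c(0.9, 0.2, 0.6, 0.1), c(TRUE, TRUE, FALSE, FALSE))  # 3 of 4 pairs ordered
```

Applied to the combined predictions, something like `cv_pred_results %>% group_by(model_name) %>% summarise(auc = auc_rank(Fail, obs == "Fail"))` would give the pooled AUC behind each curve above. This can differ slightly from caret's ROC column, which averages the AUC of each resample.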

Consider the calibration curves of the cross-validation hold-out predictions for these four models. For reference, we add the regularized linear model with triplet interactions (glmnet_3_b).

# Best-tune hold-out predictions of P(Fail) for each model to be calibrated
# (tibble::as_tibble() replaces the deprecated tbl_df())
rf_test_pred_good <- mod_rf$pred %>% tibble::as_tibble() %>%
  filter(mtry == mod_rf$bestTune$mtry) %>%
  select(obs, Fail, rowIndex, Resample)

nnet_test_pred_good <- mod_nnet_b$pred %>% tibble::as_tibble() %>%
  filter(size == mod_nnet_b$bestTune$size,
         decay == mod_nnet_b$bestTune$decay) %>%
  select(obs, Fail, rowIndex, Resample)

xgb_test_pred_good <- mod_gbm$pred %>% tibble::as_tibble() %>%
  filter(nrounds == mod_gbm$bestTune$nrounds,
         max_depth == mod_gbm$bestTune$max_depth,
         eta %in% mod_gbm$bestTune$eta,
         gamma %in% mod_gbm$bestTune$gamma,
         colsample_bytree %in% mod_gbm$bestTune$colsample_bytree,
         min_child_weight == mod_gbm$bestTune$min_child_weight,
         subsample == mod_gbm$bestTune$subsample) %>%
  select(obs, Fail, rowIndex, Resample)

mars_test_pred_good <- mod_mars$pred %>% tibble::as_tibble() %>%
  filter(nprune == mod_mars$bestTune$nprune,
         degree == mod_mars$bestTune$degree) %>%
  select(obs, Fail, rowIndex, Resample)

glmnet_3_b_test_pred_good <- mod_glmnet_3_b$pred %>% tibble::as_tibble() %>%
  filter(alpha %in% mod_glmnet_3_b$bestTune$alpha,
         lambda %in% mod_glmnet_3_b$bestTune$lambda) %>%
  select(obs, Fail, rowIndex, Resample)

cal_holdout_preds <- rf_test_pred_good %>% rename(rf = Fail) %>% 
  left_join(nnet_test_pred_good %>% rename(nnet = Fail),
            by = c("obs", "rowIndex", "Resample")) %>% 
  left_join(xgb_test_pred_good %>% rename(xgb = Fail),
            by = c("obs", "rowIndex", "Resample")) %>% 
  left_join(mars_test_pred_good %>% rename(mars = Fail),
            by = c("obs", "rowIndex", "Resample")) %>% 
  left_join(glmnet_3_b_test_pred_good %>% rename(glmnet = Fail),
            by = c("obs", "rowIndex", "Resample")) %>% 
  select(outcome_2 = obs, rf, nnet, xgb, mars, glmnet)

Generate calibration curves.

cal_object <- calibration(outcome_2 ~ rf + nnet + xgb + mars + glmnet,
                          data = cal_holdout_preds,
                          cuts = 10)

ggplot(cal_object) + theme_bw() + theme(legend.position = "top")

glmnet appears mostly well calibrated, except for lower predicted probabilities of around 30-40%.

cal_object <- calibration(outcome_2 ~ rf + nnet + xgb + mars + glmnet,
                          data = cal_holdout_preds,
                          cuts = 5)

ggplot(cal_object) + theme_bw() + theme(legend.position = "top")

These calibration curves show that the linear model with triplet interactions glmnet_3_b is well-calibrated, although its point-wise predictive accuracy metrics were lower than the non-linear models.
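A calibration curve bins the predicted event probabilities and compares each bin's observed event rate with its predictions. For a well-calibrated model, the two should match. The base R sketch below uses equal-width bins. It is similar in spirit to caret::calibration, but its binning details are an assumption and may differ from caret's. calibration_table is a hypothetical helper name.

```r
# Bin predicted probabilities into equal-width bins and compare the mean
# prediction with the observed event rate in each bin (NA for empty bins).
calibration_table <- function(prob, is_event, cuts = 10) {
  bins <- cut(prob, breaks = seq(0, 1, length.out = cuts + 1),
              include.lowest = TRUE)
  data.frame(
    bin       = levels(bins),
    n         = as.vector(table(bins)),
    mean_pred = as.vector(tapply(prob, bins, mean)),
    obs_rate  = as.vector(tapply(is_event, bins, mean))
  )
}

# Tiny illustration with two bins
tab <- calibration_table(prob = c(0.05, 0.15, 0.95, 0.85),
                         is_event = c(FALSE, TRUE, TRUE, TRUE),
                         cuts = 2)
tab
```

With the hold-out predictions above, `calibration_table(cal_holdout_preds$xgb, cal_holdout_preds$outcome_2 == "Fail", cuts = 5)` would tabulate one of the plotted curves, treating Fail as the event. Bins where obs_rate sits far from mean_pred correspond to points that drift away from the diagonal.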

Accuracy comparison

Ranking the models by accuracy changes the order at the top slightly, although the same four models (xgb, nnet, rf and mars) still perform best.

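# Cross-validated accuracy from caret's averaged confusion matrix.
# The table entries are percentages, so the diagonal sum divided by 100
# gives the average proportion of correct predictions.
calc_accuracy <- function(model) {
  cf <- confusionMatrix.train(model)
  sum(diag(cf$table)) / 100
}

# Note: svm refers to the initial fit mod_svm, not the refined mod_svm_b used in
# resamples() above, so its accuracy differs from mod_svm_b's 0.7081.
models <- list(glm = mod_glm, glmnet_2way = mod_glmnet_2_b, glmnet_3way = mod_glmnet_3_b,
               nnet = mod_nnet_b, rf = mod_rf, xgb = mod_gbm, pls = mod_pls,
               svm = mod_svm, mars = mod_mars)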

accuracy_results <- purrr::map_dbl(models, calc_accuracy)

accuracy_results %>% sort(decreasing = TRUE)
##         xgb        nnet          rf        mars         svm glmnet_3way 
##   0.7318430   0.7305514   0.7303527   0.7289617   0.7096870   0.6379533 
## glmnet_2way         pls         glm 
##   0.6370591   0.5973174   0.5970194

xgb has the highest cross-validated accuracy, although the top four models are within about 0.3 percentage points of one another.

Variable importance rankings

Complex non-linear models can be difficult to interpret. One aid is to rank the relative importance of the input variables in the step_2_b_df dataset. Plot variable importance based on the rf model.

plot(varImp(mod_rf))

Plot variable importance based on the xgb model.

plot(varImp(mod_gbm))

x07, x08 and response_1 seem to be the three most important inputs.
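varImp() returns model-specific importance scores, so rankings from different model types are not directly comparable. Permutation importance is a model-agnostic alternative that could cross-check them. Each input is shuffled in turn, and the average drop in a performance metric is recorded. The base R sketch below demonstrates the idea on simulated data with lm. permutation_importance and neg_mse are hypothetical helpers; the metric must be one where higher is better.

```r
# Permutation importance: mean drop in a higher-is-better metric when one
# input is randomly shuffled, breaking its link with the outcome.
permutation_importance <- function(fit, data, outcome, metric, n_rep = 5) {
  y <- data[[outcome]]
  baseline <- metric(y, predict(fit, newdata = data))
  inputs <- setdiff(names(data), outcome)
  sapply(inputs, function(v) {
    drops <- replicate(n_rep, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])
      baseline - metric(y, predict(fit, newdata = shuffled))
    })
    mean(drops)
  })
}

# Simulated data: only x1 drives y
set.seed(1)
sim <- data.frame(x1 = rnorm(300), x2 = rnorm(300))
sim$y <- 3 * sim$x1 + rnorm(300, sd = 0.5)
fit <- lm(y ~ x1 + x2, data = sim)

neg_mse <- function(obs, pred) -mean((obs - pred)^2)
imp <- permutation_importance(fit, sim, "y", neg_mse)
sort(imp, decreasing = TRUE)
```

Applying this to mod_rf or mod_gbm would need a metric that accepts caret's class predictions, such as accuracy. Alternatively, a small wrapper could call `predict(model, newdata = ..., type = "prob")` and score P(Fail).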

Saving

# Rename models list item names
names(models) <- c("iii_glm", "iii_glmnet_2way", "iii_glmnet_3way", "iii_nnet", "iii_rf", "iii_xgb", "iii_pls", "iii_svm", "iii_mars")
# Save models list
models %>% readr::write_rds("iii_models.rds")