class: right, top, my-title, title-slide

# Decision Trees
### Jo Hardin
### October 28, 2021

---

# Agenda 10/28/21

1. Decision Trees
2. Example

---

## `tidymodels` syntax

1. partition the data
2. build a recipe
3. select a model
4. create a workflow
5. fit the model
6. validate the model

---

## Decision trees in action

<div class="figure" style="text-align: center">
<img src="../images/sfnyc.png" alt="http://www.r2d3.us/visual-intro-to-machine-learning-part-1/ A visual introduction to machine learning." width="100%" />
<p class="caption">http://www.r2d3.us/visual-intro-to-machine-learning-part-1/ A visual introduction to machine learning.</p>
</div>

Yee and Chu created a step-by-step build of a recursive binary tree to model the differences between homes in SF and homes in NYC. http://www.r2d3.us/visual-intro-to-machine-learning-part-1/

---

## Classification and Regression Trees (CART)

**Basic Classification and Regression Trees (CART) Algorithm:**

1. Start with all observations in one group.
2. Find the variable/split that best separates the response variable (successive binary partitions based on the different predictors / explanatory variables).
    * Evaluate "homogeneity" within each group.
    * Divide the data into two groups ("leaves") on that split ("node").
    * Within each split, find the best variable/split that separates the outcomes.
3. Continue until the groups are too small or sufficiently "pure".
4. Prune the tree.

---

## Minimize heterogeneity

For every observation that falls into the region `\(R_m\)`, the prediction is the mean of the response values for the observations in `\(R_m\)`.

`\(\Rightarrow\)` Minimize the Residual Sum of Squares (RSS):

`$$RSS = \sum_{m=1}^{|T|} \sum_{i \in R_m} (y_i - \overline{y}_{R_m})^2$$`

where `\(\overline{y}_{R_m}\)` is the mean response for observations within the `\(m\)`th region and `\(|T|\)` is the number of regions (terminal nodes).
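A minimal base-R sketch of the RSS for a given partition, where each observation's prediction is the mean response of its region. The function name `tree_rss()` and the toy data are hypothetical, chosen only to show that splitting the responses into two regions lowers the RSS compared with a single region.

```r
# Hypothetical helper: RSS of a partition.
# `region` gives the region R_m that each observation falls into.
tree_rss <- function(y, region) {
  fitted <- ave(y, region, FUN = mean)   # prediction = mean of y within R_m
  sum((y - fitted)^2)
}

y <- c(1, 2, 3, 10, 11, 12)

tree_rss(y, rep("R1", 6))                           # one region: 125.5
tree_rss(y, c("R1", "R1", "R1", "R2", "R2", "R2"))  # two regions: 4
```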
---

## Recursive binary splitting

Select the predictor `\(X_j\)` and the cutpoint `\(s\)` such that splitting the predictor space into the regions `\(\{X | X_j< s\}\)` and `\(\{X | X_j \geq s\}\)` leads to the greatest reduction in RSS.

For any `\(j\)` and `\(s\)`, define the pair of half-planes to be

`$$R_1(j,s) = \{X | X_j < s\} \mbox{ and } R_2(j,s) = \{X | X_j \geq s\}$$`

Find the values of `\(j\)` and `\(s\)` that minimize the equation:

`$$\sum_{i:x_i \in R_1(j,s)} (y_i - \overline{y}_{R_1})^2 + \sum_{i:x_i \in R_2(j,s)} (y_i - \overline{y}_{R_2})^2$$`

where `\(\overline{y}_{R_1}\)` is the mean response for observations in `\(R_1(j,s)\)` and `\(\overline{y}_{R_2}\)` is the mean response for observations in `\(R_2(j,s)\)`.

---

## Trees in action

<img src="../images/decisiontrees.gif" width="100%" style="display: block; margin: auto;" />

---

## Measures of impurity

`\(\hat{p}_{mk}\)` = proportion of observations in the `\(m\)`th region from the `\(k\)`th class.

* *classification error rate* = fraction of observations in the node that are not in the most common class:

`$$E_m = 1 - \max_k(\hat{p}_{mk})$$`

* *Gini index*

`$$G_m= \sum_{k=1}^K \hat{p}_{mk}(1-\hat{p}_{mk})$$`

* *cross-entropy*

`$$D_m = - \sum_{k=1}^K \hat{p}_{mk} \log \hat{p}_{mk}$$`

(The Gini index and cross-entropy both take on a value near zero if each of the `\(\hat{p}_{mk}\)` values is near zero or near one, that is, when the node is nearly pure.)
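A minimal base-R sketch of the three impurity measures, plus a search over cutpoints `s` for a single numeric predictor that minimizes the weighted Gini criterion used in recursive binary splitting. All helper names (`node_props()`, `class_error()`, `gini()`, `entropy()`, `best_split_gini()`) are hypothetical, and taking the observed values of `x` as the candidate cutpoints is a simplifying assumption; this is an illustration, not how `rpart` is implemented.

```r
# Hypothetical helpers (illustration only, not rpart internals)

# p-hat_{mk}: class proportions in a node; classes absent from the node
# are dropped, which is harmless because they contribute 0 to each measure
node_props <- function(classes) {
  as.numeric(table(classes)) / length(classes)
}

class_error <- function(p) 1 - max(p)        # E_m
gini        <- function(p) sum(p * (1 - p))  # G_m
entropy     <- function(p) {                 # D_m, with 0 * log(0) taken as 0
  p <- p[p > 0]
  -sum(p * log(p))
}

# Best cutpoint s for one numeric predictor: R1 = {x < s}, R2 = {x >= s},
# minimizing n_R1 * G(R1) + n_R2 * G(R2)
best_split_gini <- function(x, classes) {
  cuts <- sort(unique(x))[-1]   # every candidate leaves both regions non-empty
  score <- sapply(cuts, function(s) {
    left <- x < s
    sum(left) * gini(node_props(classes[left])) +
      sum(!left) * gini(node_props(classes[!left]))
  })
  cuts[which.min(score)]
}

pure  <- node_props(c("a", "a", "a", "a"))
mixed <- node_props(c("a", "a", "b", "b"))
c(class_error(pure), gini(pure), entropy(pure))     # all zero
c(class_error(mixed), gini(mixed), entropy(mixed))  # 0.5, 0.5, log(2)

x       <- c(1, 2, 3, 4, 5, 6)
classes <- c("a", "a", "a", "b", "b", "b")
best_split_gini(x, classes)                         # 4: both regions pure
```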
---

## Recursive binary splitting

For any `\(j\)` and `\(s\)`, define the pair of half-planes to be

`$$R_1(j,s) = \{X | X_j < s\} \mbox{ and } R_2(j,s) = \{X | X_j \geq s\}$$`

Seek the values of `\(j\)` and `\(s\)` that minimize the equation:

`\begin{align}
& \sum_{i:x_i \in R_1(j,s)} \sum_{k=1}^K \hat{p}_{{R_1}k}(1-\hat{p}_{{R_1}k}) + \sum_{i:x_i \in R_2(j,s)} \sum_{k=1}^K \hat{p}_{{R_2}k}(1-\hat{p}_{{R_2}k})\\
\\
\mbox{equivalently: } & \\
& n_{R_1} \sum_{k=1}^K \hat{p}_{{R_1}k}(1-\hat{p}_{{R_1}k}) + n_{R_2} \sum_{k=1}^K \hat{p}_{{R_2}k}(1-\hat{p}_{{R_2}k})\\
\end{align}`

where `\(n_{R_1}\)` and `\(n_{R_2}\)` are the numbers of observations in `\(R_1(j,s)\)` and `\(R_2(j,s)\)`.

---

## Stopping

We can always make the tree more "pure" by continuing to split.

> Too many splits will overfit the model to the training data!

Ways to control:

* `cost_complexity`
* `tree_depth`
* `min_n`

Overfitting: http://www.r2d3.us/visual-intro-to-machine-learning-part-2/

---

## Cost complexity

There is a cost to having a larger (more complex!) tree.

Define the cost complexity criterion, `\(\alpha > 0:\)`

`\begin{align}
\mbox{numerical: } C_\alpha(T) &= \sum_{m=1}^{|T|} \sum_{i \in R_m} (y_i - \overline{y}_{R_m})^2 + \alpha \cdot |T|\\
\mbox{categorical: } C_\alpha(T) &= \sum_{m=1}^{|T|} \sum_{i \in R_m} I(y_i \ne k(m)) + \alpha \cdot |T|
\end{align}`

where `\(k(m)\)` is the class with the majority of observations in node `\(m\)` and `\(|T|\)` is the number of terminal nodes in the tree.

* `\(\alpha\)` small: If `\(\alpha\)` is set to be small, we are saying that the risk is more worrisome than the complexity, and larger trees are favored because they reduce the risk.
* `\(\alpha\)` large: If `\(\alpha\)` is set to be large, then the complexity of the tree is more worrisome, and smaller trees are favored.

---

## In practice

Consider `\(\alpha\)` increasing. As `\(\alpha\)` gets bigger, the "best" tree will be smaller.

The test error will not be monotonically related to the size of the training tree.
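To make the link between `\(\alpha\)` and tree size concrete, here is a minimal base-R sketch that evaluates `\(C_\alpha(T)\)` = risk + `\(\alpha \cdot |T|\)` over a fixed list of candidate subtrees and reports the size of the minimizer for several values of `\(\alpha\)`. The subtree sizes and risks are made-up numbers and `best_size()` is a hypothetical helper; the sketch covers only the selection step, not how pruning produces the nested sequence of subtrees.

```r
# Made-up candidate subtrees: number of terminal nodes |T| and training
# risk (RSS or number misclassified); risk shrinks as the tree grows
subtrees <- data.frame(
  size = c(1, 2, 3, 5, 8),
  risk = c(100, 60, 45, 38, 35)
)

# Hypothetical helper: size of the subtree minimizing risk + alpha * |T|
best_size <- function(alpha, trees) {
  cost <- trees$risk + alpha * trees$size
  trees$size[which.min(cost)]
}

alphas <- c(0, 2, 5, 20, 50)
sapply(alphas, best_size, trees = subtrees)
#> [1] 8 5 3 2 1
```

With `\(\alpha = 0\)` the largest tree wins; each increase in `\(\alpha\)` shifts the minimizer toward a smaller tree.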
<img src="../images/treealpha.jpg" width="80%" style="display: block; margin: auto;" />

---

## A note on `\(\alpha\)`

In the text (*Introduction to Statistical Learning*) and almost everywhere else you might look, the cost complexity is defined as on the previous slides.

However, you might notice that in R the `cost_complexity` value is typically less than 1. From what I can tell, the value of the function that is being minimized in R is the **average** of the squared errors and the misclassification **rate**.

`\begin{align}
\mbox{numerical: } C_\alpha(T) &= \frac{1}{n}\sum_{m=1}^{|T|} \sum_{i \in R_m} (y_i - \overline{y}_{R_m})^2 + \alpha \cdot |T|\\
\mbox{categorical: } C_\alpha(T) &= \frac{1}{n}\sum_{m=1}^{|T|} \sum_{i \in R_m} I(y_i \ne k(m)) + \alpha \cdot |T|
\end{align}`

---

## CART algorithm

******
**Algorithm**: Building a Regression Tree

******

1. Use recursive binary splitting to grow a large tree on the training data, stopping only when each terminal node has fewer than some minimum number of observations.
2. Apply cost complexity pruning to the large tree in order to obtain a sequence of best subtrees, as a function of `\(\alpha\)`.
3. Use `\(V\)`-fold cross-validation to choose `\(\alpha\)`. That is, divide the training observations into `\(V\)` folds. For each `\(v=1, 2, \ldots, V\)`:

   a. Repeat Steps 1 and 2 on all but the `\(v\)`th fold of the training data.

   b. Evaluate the mean squared prediction error on the data in the left-out `\(v\)`th fold, as a function of `\(\alpha\)`.

   For each value of `\(\alpha\)`, average the prediction error (either misclassification or RSS), and pick `\(\alpha\)` to minimize the average error.
4. Return the subtree from Step 2 that corresponds to the chosen value of `\(\alpha\)`.

******

---

## CART example w defaults

.panelset[
.panel[.panel-name[recipe]

```r
penguin_cart_recipe <- recipe(species ~ ., data = penguin_train) %>%
  step_unknown(sex, new_level = "unknown") %>%
  step_mutate(year = as.factor(year))

summary(penguin_cart_recipe)
```

```
## # A tibble: 8 × 4
##   variable          type    role      source
##   <chr>             <chr>   <chr>     <chr>
## 1 island            nominal predictor original
## 2 bill_length_mm    numeric predictor original
## 3 bill_depth_mm     numeric predictor original
## 4 flipper_length_mm numeric predictor original
## 5 body_mass_g       numeric predictor original
## 6 sex               nominal predictor original
## 7 year              numeric predictor original
## 8 species           nominal outcome   original
```
]

.panel[.panel-name[model]

```r
penguin_cart <- decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification")

penguin_cart
```

```
## Decision Tree Model Specification (classification)
## 
## Computational engine: rpart
```
]

.panel[.panel-name[workflow]

```r
penguin_cart_wflow <- workflow() %>%
  add_model(penguin_cart) %>%
  add_recipe(penguin_cart_recipe)

penguin_cart_wflow
```

```
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 2 Recipe Steps
## 
## • step_unknown()
## • step_mutate()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Decision Tree Model Specification (classification)
## 
## Computational engine: rpart
```
]

.panel[.panel-name[fit]

```r
penguin_cart_fit <- penguin_cart_wflow %>%
  fit(data = penguin_train)

penguin_cart_fit
```

```
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 2 Recipe Steps
## 
## • step_unknown()
## • step_mutate()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## n= 258 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 258 139 Adelie (0.4612403 0.1899225 0.3488372)
##   2) flipper_length_mm< 206.5 164 46 Adelie (0.7195122 0.2743902 0.0060976)
##     4) bill_length_mm< 43.15 118 3 Adelie (0.9745763 0.0254237 0.0000000) *
##     5) bill_length_mm>=43.15 46 4 Chinstrap (0.0652174 0.9130435 0.0217391) *
##   3) flipper_length_mm>=206.5 94 5 Gentoo (0.0106383 0.0425532 0.9468085)
##     6) bill_depth_mm>=17.15 7 3 Chinstrap (0.1428571 0.5714286 0.2857143) *
##     7) bill_depth_mm< 17.15 87 0 Gentoo (0.0000000 0.0000000 1.0000000) *
```
]

.panel[.panel-name[pred]

```r
penguin_cart_fit %>%
  predict(new_data = penguin_train) %>%
  cbind(penguin_train) %>%
  select(.pred_class, species) %>%
  table()
```

```
##             species
## .pred_class Adelie Chinstrap Gentoo
##   Adelie       115         3      0
##   Chinstrap      4        46      3
##   Gentoo         0         0     87
```
]
]

---

## Plotting the tree (not tidy)

.panelset[
.panel[.panel-name[plot 1]

.pull-left[
```r
library(rpart.plot)
penguins_cart_plot <- penguin_cart_fit %>%
  extract_fit_parsnip()

rpart.plot(
  penguins_cart_plot$fit,
  roundint = FALSE)
```
]

.pull-right[
![](2021-10-28-cart_files/figure-html/unnamed-chunk-13-1.png)<!-- -->
]
]

.panel[.panel-name[plot 2]

.pull-left[
```r
library(rattle)
penguins_cart_plot <- penguin_cart_fit %>%
  extract_fit_parsnip()

fancyRpartPlot(
  penguins_cart_plot$fit,
  sub = NULL,
  palettes = "RdPu")
```
]

.pull-right[
![](2021-10-28-cart_files/figure-html/unnamed-chunk-15-1.png)<!-- -->
]
]
]

---

## CART example w CV

.panelset[
.panel[.panel-name[new recipe]

```r
penguin_cart_tune_recipe <-
  recipe(sex ~ body_mass_g + bill_length_mm + species,
         data = penguin_train)
```
]

.panel[.panel-name[creating folds]

```r
set.seed(470)
penguin_vfold <- vfold_cv(penguin_train,
                          v = 5,
                          strata = sex)
```
]

.panel[.panel-name[alpha]

```r
cart_grid <- expand.grid(
  cost_complexity = c(0, 10^(seq(-5, -1, 1))),
  tree_depth = seq(1, 6, by = 1))

cart_grid
```

```
##    cost_complexity tree_depth
## 1            0e+00          1
## 2            1e-05          1
## 3            1e-04          1
## 4            1e-03          1
## 5            1e-02          1
## 6            1e-01          1
## 7            0e+00          2
## 8            1e-05          2
## 9            1e-04          2
## 10           1e-03          2
## 11           1e-02          2
## 12           1e-01          2
## 13           0e+00          3
## 14           1e-05          3
## 15           1e-04          3
## 16           1e-03          3
## 17           1e-02          3
## 18           1e-01          3
## 19           0e+00          4
## 20           1e-05          4
## 21           1e-04          4
## 22           1e-03          4
## 23           1e-02          4
## 24           1e-01          4
## 25           0e+00          5
## 26           1e-05          5
## 27           1e-04          5
## 28           1e-03          5
## 29           1e-02          5
## 30           1e-01          5
## 31           0e+00          6
## 32           1e-05          6
## 33           1e-04          6
## 34           1e-03          6
## 35           1e-02          6
## 36           1e-01          6
```
]

.panel[.panel-name[tune wkflow]

```r
penguin_cart_tune <-
  decision_tree(cost_complexity = tune(),
                tree_depth = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

penguin_cart_wflow_tune <- workflow() %>%
  add_model(penguin_cart_tune) %>%
  add_recipe(penguin_cart_tune_recipe)
```
]

.panel[.panel-name[tuning]

```r
penguin_tuned <- penguin_cart_wflow_tune %>%
  tune_grid(resamples = penguin_vfold,
            grid = cart_grid)

penguin_tuned %>%
  collect_metrics() %>%
  filter(.metric == "accuracy")
```

```
## # A tibble: 36 × 8
##    cost_complexity tree_depth .metric  .estimator    mean     n  std_err .config
##              <dbl>      <dbl> <chr>    <chr>        <dbl> <int>    <dbl> <chr>
##  1         0                1 accuracy binary     0.66895     5 0.015219 Prepro…
##  2         0.00001          1 accuracy binary     0.66895     5 0.015219 Prepro…
##  3         0.0001           1 accuracy binary     0.66895     5 0.015219 Prepro…
##  4         0.001            1 accuracy binary     0.66895     5 0.015219 Prepro…
##  5         0.01             1 accuracy binary     0.66895     5 0.015219 Prepro…
##  6         0.1              1 accuracy binary     0.66895     5 0.015219 Prepro…
##  7         0                2 accuracy binary     0.69273     5 0.022573 Prepro…
##  8         0.00001          2 accuracy binary     0.69273     5 0.022573 Prepro…
##  9         0.0001           2 accuracy binary     0.69273     5 0.022573 Prepro…
## 10         0.001            2 accuracy binary     0.69273     5 0.022573 Prepro…
## # … with 26 more rows
```
]
]

---

### Parameter choice

```r
penguin_tuned %>%
  autoplot(metric = "accuracy")
```

<img src="2021-10-28-cart_files/figure-html/unnamed-chunk-21-1.png" style="display: block; margin: auto;" />

```r
penguin_tuned %>%
  select_best(metric = "accuracy")
```

```
## # A tibble: 1 × 3
##   cost_complexity tree_depth .config
##             <dbl>      <dbl> <chr>
## 1               0          5 Preprocessor1_Model25
```

---

# Best model
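Choosing the "best" parameters amounts to keeping the grid row with the largest mean cross-validated accuracy. As a rough base-R illustration (not `tune`'s implementation), the sketch below copies only the ten rows printed on the tuning slide into a data frame; because the other 26 rows are not shown, it picks `tree_depth = 2` rather than the `tree_depth = 5` that `select_best()` reports for the full grid. Breaking ties by taking the first row is an assumption of this sketch.

```r
# The ten rows of accuracy results printed on the tuning slide
# (26 more rows exist but are not shown there)
shown <- data.frame(
  cost_complexity = c(0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 0, 1e-5, 1e-4, 1e-3),
  tree_depth      = c(rep(1, 6), rep(2, 4)),
  mean            = c(rep(0.66895, 6), rep(0.69273, 4))
)

# Keep the row with the largest mean accuracy; which.max() returns the
# first maximum, so ties go to the earliest row in the grid
best_row <- shown[which.max(shown$mean), ]
best_row
```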
```r
penguin_best <- finalize_model(
  penguin_cart_tune,
  select_best(penguin_tuned, metric = "accuracy"))

workflow() %>%
  add_model(penguin_best) %>%
  add_recipe(penguin_cart_tune_recipe) %>%
  fit(data = penguin_train)
```

```
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 0 Recipe Steps
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## n=248 (10 observations deleted due to missingness)
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 248 116 female (0.532258 0.467742)
##    2) body_mass_g< 3712.5 90 15 female (0.833333 0.166667)
##      4) bill_length_mm< 38.95 44 1 female (0.977273 0.022727) *
##      5) bill_length_mm>=38.95 46 14 female (0.695652 0.304348)
##       10) body_mass_g< 3312.5 11 0 female (1.000000 0.000000) *
##       11) body_mass_g>=3312.5 35 14 female (0.600000 0.400000)
##         22) species=Chinstrap 20 5 female (0.750000 0.250000)
##           44) bill_length_mm< 48.3 11 0 female (1.000000 0.000000) *
##           45) bill_length_mm>=48.3 9 4 male (0.444444 0.555556) *
##         23) species=Adelie 15 6 male (0.400000 0.600000) *
##    3) body_mass_g>=3712.5 158 57 male (0.360759 0.639241)
##      6) species=Gentoo 86 39 female (0.546512 0.453488)
##       12) body_mass_g< 4987.5 41 1 female (0.975610 0.024390) *
##       13) body_mass_g>=4987.5 45 7 male (0.155556 0.844444) *
##      7) species=Adelie,Chinstrap 72 10 male (0.138889 0.861111)
##       14) body_mass_g< 3962.5 35 9 male (0.257143 0.742857)
##         28) bill_length_mm< 48.3 25 9 male (0.360000 0.640000)
##           56) bill_length_mm>=40.85 7 3 female (0.571429 0.428571) *
##           57) bill_length_mm< 40.85 18 5 male (0.277778 0.722222) *
##         29) bill_length_mm>=48.3 10 0 male (0.000000 1.000000) *
##       15) body_mass_g>=3962.5 37 1 male (0.027027 0.972973) *
```

---

## Best Model Predictions

```r
workflow() %>%
  add_model(penguin_best) %>%
  add_recipe(penguin_cart_tune_recipe) %>%
  fit(data = penguin_train) %>%
  predict(new_data = penguin_test) %>%
  cbind(penguin_test) %>%
  select(.pred_class, sex) %>%
  table()
```

```
##             sex
## .pred_class female male
##   female        26    5
##   male           7   47
```

---

## Best Model Predictions

.pull-left[
```r
library(rpart.plot)
penguins_cart_plot <- workflow() %>%
  add_model(penguin_best) %>%
  add_recipe(penguin_cart_tune_recipe) %>%
  fit(data = penguin_train) %>%
  extract_fit_parsnip()

rpart.plot(
  penguins_cart_plot$fit,
  roundint = FALSE)
```
]

.pull-right[
![](2021-10-28-cart_files/figure-html/unnamed-chunk-25-1.png)<!-- -->
]

---

## Bias-Variance Tradeoff

<div class="figure" style="text-align: center">
<img src="../images/varbias.png" alt="Test and training error as a function of model complexity. Note that the error goes down monotonically only for the training data. Be careful not to overfit!! image credit: ISLR" width="90%" />
<p class="caption">Test and training error as a function of model complexity. Note that the error goes down monotonically only for the training data. Be careful not to overfit!! image credit: ISLR</p>
</div>

---

## Reflecting on Model Building

<div class="figure">
<img src="../images/modelbuild1.png" alt="Image credit: https://www.tmwr.org/" width="2176" />
<p class="caption">Image credit: https://www.tmwr.org/</p>
</div>

---

## Reflecting on Model Building

<div class="figure">
<img src="../images/modelbuild2.png" alt="Image credit: https://www.tmwr.org/" width="2067" />
<p class="caption">Image credit: https://www.tmwr.org/</p>
</div>

---

## Reflecting on Model Building

<div class="figure" style="text-align: center">
<img src="../images/modelbuild3.png" alt="Image credit: https://www.tmwr.org/" width="70%" />
<p class="caption">Image credit: https://www.tmwr.org/</p>
</div>