\(k\)-Nearest Neighbors

October 23, 2024

Jo Hardin

Agenda 10/23/24

  1. Redux - model process
  2. cross validation
  3. \(k\)-Nearest Neighbors

tidymodels syntax

  1. partition the data
  2. build a recipe
  3. select a model
  4. create a workflow
  5. fit the model
  6. validate the model
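Step 1 is usually done in tidymodels with rsample's initial_split(), whose default training proportion is 3/4. That call is not shown in this section. The base R sketch below illustrates only the idea behind it: a partition is a random sample of row indices. It assumes 344 rows (the size of the palmerpenguins data) and a 3/4 training proportion.

```r
# Minimal sketch of "partition the data" in base R (not the rsample API).
# Assumption: 344 rows, as in the palmerpenguins data, and a 3/4 training split.
set.seed(470)
n <- 344
train_rows <- sample(seq_len(n), size = floor(0.75 * n))  # 258 rows to build/tune models
test_rows  <- setdiff(seq_len(n), train_rows)              # 86 rows kept for the final check

length(train_rows)
length(test_rows)
```

Unlike a stratified split, this sketch does not balance any variable (such as species) across the two pieces.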

All together

penguin_lm_recipe <-
  recipe(body_mass_g ~ species + island + bill_length_mm + 
           bill_depth_mm + flipper_length_mm + sex + year,
         data = penguin_train) |>
  step_mutate(year = as.factor(year)) |>
  step_unknown(sex, new_level = "unknown") |>
  step_relevel(sex, ref_level = "female") |>
  update_role(island, new_role = "id variable")

penguin_lm_recipe
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs 
Number of variables by role
outcome:     1
predictor:   6
id variable: 1
── Operations 
• Variable mutation for: as.factor(year)
• Unknown factor level assignment for: sex
• Re-order factor level to ref_level for: sex
penguin_lm <- linear_reg() |>
  set_engine("lm")

penguin_lm
Linear Regression Model Specification (regression)

Computational engine: lm 
penguin_lm_wflow <- workflow() |>
  add_model(penguin_lm) |>
  add_recipe(penguin_lm_recipe)

penguin_lm_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_mutate()
• step_unknown()
• step_relevel()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Computational engine: lm 
penguin_lm_fit <- penguin_lm_wflow |>
  fit(data = penguin_train)

penguin_lm_fit |> tidy()
# A tibble: 10 × 5
   term              estimate std.error statistic  p.value
   <chr>                <dbl>     <dbl>     <dbl>    <dbl>
 1 (Intercept)        -2417.     665.      -3.64  3.36e- 4
 2 speciesChinstrap    -208.      92.9     -2.24  2.58e- 2
 3 speciesGentoo        985.     152.       6.48  5.02e-10
 4 bill_length_mm        13.5      8.29     1.63  1.04e- 1
 5 bill_depth_mm         80.9     22.1      3.66  3.10e- 4
 6 flipper_length_mm     20.8      3.62     5.74  2.81e- 8
 7 sexmale              351.      52.6      6.67  1.72e-10
 8 sexunknown            47.6    103.       0.460 6.46e- 1
 9 year2008             -24.8     47.5     -0.521 6.03e- 1
10 year2009             -61.9     46.0     -1.35  1.80e- 1

model parameters

  • Some model parameters are tuned from the data (some aren’t).

    • linear model coefficients are optimized (not tuned)
    • \(k\)-NN value of \(k\) is tuned
  • If the model is tuned using the data, the same data cannot be used to assess the model.

  • With cross validation, you iteratively put some of the data "in your pocket" (hold it out) while building the model.

  • For example, keep 1/5 of the data in your pocket, and build the model on the remaining 4/5 of the data.
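Keeping 1/5 of the data in your pocket, five times over, amounts to randomly assigning each observation to one of 5 folds. Each fold is held out exactly once. Here is a minimal base R sketch; the sample size of 20 is arbitrary and purely illustrative.

```r
# Assign each of n observations to one of V = 5 folds at random.
set.seed(470)
n <- 20
V <- 5
# rep_len() recycles 1:V to length n so the folds have (nearly) equal sizes;
# sample() then shuffles those fold labels
fold_id <- sample(rep_len(1:V, n))

for (v in 1:V) {
  in_pocket <- which(fold_id == v)   # held out ("in your pocket")
  for_model <- which(fold_id != v)   # used to build the model
  cat("fold", v, ": build on", length(for_model),
      "obs, assess on", length(in_pocket), "obs\n")
}
```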

Cross validation

for tuning parameters

Image credit: Alison Hill

\(k\)-Nearest Neighbors

The \(k\)-Nearest Neighbor algorithm does exactly what it sounds like it does: it classifies a new point using the \(k\) points in the training data that are nearest to it.

  • user decides on the integer value for \(k\)

  • user decides on a distance metric (most \(k\)-NN algorithms default to Euclidean distance)

  • a new point is classified into the group that is most common among its \(k\) closest points in the training data (a majority vote).
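The three bullets translate almost line for line into code. Below is a minimal base R sketch that assumes numeric predictors, Euclidean distance, and an unweighted majority vote. The function knn_predict() and the toy data are hypothetical, written only for illustration; this is not the kknn engine used later in these slides.

```r
# Hypothetical helper: classify each row of new_x by a majority vote
# among its k nearest rows of train_x (Euclidean distance).
knn_predict <- function(train_x, train_y, new_x, k) {
  train_x <- as.matrix(train_x)
  new_x <- as.matrix(new_x)
  apply(new_x, 1, function(pt) {
    # Euclidean distance from this new point to every training point
    d <- sqrt(colSums((t(train_x) - pt)^2))
    # labels of the k closest training points
    nbrs <- train_y[order(d)[1:k]]
    # most common label among the neighbors (ties go to the first label
    # in table() order, an arbitrary choice for this sketch)
    votes <- table(nbrs)
    names(votes)[which.max(votes)]
  })
}

# Toy data: two well-separated groups
train_x <- data.frame(x1 = c(1, 1, 2, 6, 7, 7),
                      x2 = c(1, 2, 1, 6, 6, 7))
train_y <- c("A", "A", "A", "B", "B", "B")
new_x <- data.frame(x1 = c(1.5, 6.5), x2 = c(1.5, 6.5))

pred <- knn_predict(train_x, train_y, new_x, k = 3)
pred   # "A" "B"
```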

\(k\)-NN visually

Consider a population, a training set, and a decision boundary:

image credit: Ricardo Gutierrez-Osuna

\(k\)-NN visually

Choosing \(k\) well is one of the most important aspects of the algorithm.

image credit: Ricardo Gutierrez-Osuna

\(k\)-NN to predict penguin species

penguin_knn_recipe <-
  recipe(species ~ body_mass_g + island + bill_length_mm + 
           bill_depth_mm + flipper_length_mm,
         data = penguin_train) |>
  update_role(island, new_role = "id variable") |>
  step_normalize(all_predictors())

summary(penguin_knn_recipe)
# A tibble: 6 × 4
  variable          type      role        source  
  <chr>             <list>    <chr>       <chr>   
1 body_mass_g       <chr [2]> predictor   original
2 island            <chr [3]> id variable original
3 bill_length_mm    <chr [2]> predictor   original
4 bill_depth_mm     <chr [2]> predictor   original
5 flipper_length_mm <chr [2]> predictor   original
6 species           <chr [3]> outcome     original
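step_normalize() centers and scales each predictor to mean 0 and standard deviation 1. That matters for \(k\)-NN because distances are computed on the raw numbers, so a variable measured in grams can swamp one measured in millimeters. The base R sketch below uses three hypothetical penguins (made-up values, not rows of the data) and standardizes them with scale(). Penguin 2's nearest neighbor changes once the predictors are normalized.

```r
# Three hypothetical penguins (illustrative values, not from penguin_train)
peng <- data.frame(body_mass_g    = c(3700, 3750, 4700),
                   bill_length_mm = c(39,   49,   47))

# Euclidean distances on the raw scale: body mass differences (hundreds of
# grams) dominate bill length differences (a few millimeters)
d_raw <- as.matrix(dist(peng))

# Center and scale each column to mean 0, sd 1 (the transformation
# step_normalize() applies, here using these three rows' means and sds)
peng_norm <- as.data.frame(scale(peng))
d_norm <- as.matrix(dist(peng_norm))

# Nearest neighbor of penguin 2 under each scaling
names(which.min(d_raw[2, -2]))    # "1" on the raw scale
names(which.min(d_norm[2, -2]))   # "3" after normalizing
```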
penguin_knn <- nearest_neighbor() |>
  set_engine("kknn") |>
  set_mode("classification")

penguin_knn
K-Nearest Neighbor Model Specification (classification)

Computational engine: kknn 
penguin_knn_wflow <- workflow() |>
  add_model(penguin_knn) |>
  add_recipe(penguin_knn_recipe)

penguin_knn_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()

── Preprocessor ────────────────────────────────────────────────────────────────
1 Recipe Step

• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
K-Nearest Neighbor Model Specification (classification)

Computational engine: kknn 
penguin_knn_fit <- penguin_knn_wflow |>
  fit(data = penguin_train)
penguin_knn_fit |> 
  predict(new_data = penguin_test) |>
  cbind(penguin_test) |>
  metrics(truth = species, estimate = .pred_class) |>
  filter(.metric == "accuracy")
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy multiclass     0.988

what is \(k\) ???

It turns out that the default value for \(k\) in the kknn engine is 7. Is 7 best?

Cross Validation!!!

The red observations are used to fit the model; the black observations are used to assess the model.

Image credit: Alison Hill

Cross validation

Randomly split the training data into V distinct blocks of roughly equal size.

  • leave out the first block (the assessment data) and fit a model on the remaining blocks (the analysis data).
  • the model is used to predict the held-out block of assessment data.
  • continue the process until all V assessment blocks have been predicted.

The tuning parameter value is usually chosen to be the one that produces the best performance averaged across the V blocks.

The final performance of the chosen model is usually reported on the test data.
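The procedure above can be written out by hand to see what tune_grid() is doing. Below is a minimal base R sketch under several assumptions. Simulated two-class data stands in for penguin_train. The folds are not stratified (unlike vfold_cv(strata = species)). knn_predict() is a hypothetical helper using Euclidean distance and a majority vote. The grid of \(k\) values matches k_grid.

```r
set.seed(470)

# Hypothetical helper: majority vote among the k nearest training points
knn_predict <- function(train_x, train_y, new_x, k) {
  train_x <- as.matrix(train_x)
  new_x <- as.matrix(new_x)
  apply(new_x, 1, function(pt) {
    d <- sqrt(colSums((t(train_x) - pt)^2))    # distance to each training point
    votes <- table(train_y[order(d)[1:k]])     # labels of the k nearest
    names(votes)[which.max(votes)]             # most common label
  })
}

# Simulated data: two overlapping clouds of 45 points each
n <- 90
x <- data.frame(x1 = c(rnorm(n / 2, mean = 0), rnorm(n / 2, mean = 1.5)),
                x2 = c(rnorm(n / 2, mean = 0), rnorm(n / 2, mean = 1.5)))
y <- rep(c("A", "B"), each = n / 2)

V <- 3
fold_id <- sample(rep_len(1:V, n))   # random blocks of equal size
k_values <- seq(1, 15, by = 4)       # same grid as k_grid: 1, 5, 9, 13

cv_acc <- sapply(k_values, function(k) {
  fold_acc <- sapply(1:V, function(v) {
    assess <- fold_id == v           # held-out assessment block
    pred <- knn_predict(x[!assess, ], y[!assess], x[assess, ], k)
    mean(pred == y[assess])          # accuracy on the assessment block
  })
  mean(fold_acc)                     # average across the V blocks
})

data.frame(neighbors = k_values, mean_accuracy = cv_acc)
best_k <- k_values[which.max(cv_acc)]
best_k
```

which.max() breaks ties toward the smaller \(k\). In tidymodels, tune's select_best() plays the role of the last line.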

Extending the modeling process

set.seed(470)
penguin_vfold <- vfold_cv(penguin_train,
                          v = 3, strata = species)
k_grid <- data.frame(neighbors = seq(1, 15, by = 4))
k_grid
  neighbors
1         1
2         5
3         9
4        13
penguin_knn_tune <- nearest_neighbor(neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("classification")

penguin_knn_wflow_tune <- workflow() |>
  add_model(penguin_knn_tune) |>
  add_recipe(penguin_knn_recipe)
penguin_knn_wflow_tune |>
  tune_grid(resamples = penguin_vfold, 
            grid = k_grid) |>
  collect_metrics() |>
  filter(.metric == "accuracy")
# A tibble: 4 × 7
  neighbors .metric  .estimator  mean     n   std_err .config             
      <dbl> <chr>    <chr>      <dbl> <int>     <dbl> <chr>               
1         1 accuracy multiclass 0.971     2 0.00595   Preprocessor1_Model1
2         5 accuracy multiclass 0.977     2 0.000134  Preprocessor1_Model2
3         9 accuracy multiclass 0.988     2 0.0000668 Preprocessor1_Model3
4        13 accuracy multiclass 0.983     2 0.00568   Preprocessor1_Model4

We choose \(k\) = 9 !

6. Validate the model

penguin_knn_final <- nearest_neighbor(neighbors = 9) |>
  set_engine("kknn") |>
  set_mode("classification")

penguin_knn_final
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = 9

Computational engine: kknn 
penguin_knn_wflow_final <- workflow() |>
  add_model(penguin_knn_final) |>
  add_recipe(penguin_knn_recipe)

penguin_knn_wflow_final
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()

── Preprocessor ────────────────────────────────────────────────────────────────
1 Recipe Step

• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = 9

Computational engine: kknn 
penguin_knn_fit_final <- penguin_knn_wflow_final |>
  fit(data = penguin_train)
penguin_knn_fit_final |> 
  predict(new_data = penguin_test) |>
  cbind(penguin_test) |>
  metrics(truth = species, estimate = .pred_class) |>
  filter(.metric == "accuracy")
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy multiclass     0.977

Huh. On the test data, \(k=9\) (accuracy 0.977) didn’t do as well as \(k=7\) (accuracy 0.988), the default value we used at the very beginning, before cross validating.

Well, it turns out that’s the nature of variability, randomness, and model building.

We don’t know the truth, and we won’t ever find a perfect model.
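How big is the gap between 0.988 and 0.977? Here is a rough arithmetic check that assumes, hypothetically, a test set of 86 penguins (a quarter of the 344 penguins in the palmerpenguins data; the split actually used is not shown). Under that assumption, the two accuracies correspond to 85 and 84 correctly classified penguins, so the difference could be a single bird.

```r
# Assumption: 86 test penguins (1/4 of 344); the real split is not shown.
n_test <- 86
n_correct <- c(k7 = 85, k9 = 84)        # one misclassified penguin apart
acc <- round(n_correct / n_test, 3)
acc                                     # k7 = 0.988, k9 = 0.977
1 / n_test                              # one penguin is worth about 0.012 accuracy
```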

Bias-Variance Tradeoff

Test and training error as a function of model complexity. Note that the error goes down monotonically only for the training data. Be careful not to overfit!! Image credit: ISLR
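For \(k\)-NN, model complexity goes up as \(k\) goes down. At the extreme, \(k = 1\) drives the training error to zero, because every training point is its own nearest neighbor. The base R sketch below uses simulated labels that are pure coin flips, so there is nothing real to learn, and a hypothetical knn_predict() helper (Euclidean distance, majority vote). It still scores perfectly on its own training data.

```r
set.seed(1)

# Hypothetical helper: majority vote among the k nearest training points
knn_predict <- function(train_x, train_y, new_x, k) {
  train_x <- as.matrix(train_x)
  new_x <- as.matrix(new_x)
  apply(new_x, 1, function(pt) {
    d <- sqrt(colSums((t(train_x) - pt)^2))
    votes <- table(train_y[order(d)[1:k]])
    names(votes)[which.max(votes)]
  })
}

n <- 60
x <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y <- sample(c("A", "B"), n, replace = TRUE)   # labels unrelated to x

# Training accuracy: predict the training points from the training points
train_acc_k1  <- mean(knn_predict(x, y, x, k = 1) == y)
train_acc_k15 <- mean(knn_predict(x, y, x, k = 15) == y)
train_acc_k1    # exactly 1: each point's nearest neighbor is itself
train_acc_k15   # below 1
```

This is why training error alone cannot be used to choose \(k\).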

Reflecting on Model Building

Image credit: https://www.tmwr.org/
