library(tidyverse)
library(tidymodels)
theme_set(theme_minimal())
Predicting Coffee Quality With ML
1 Coffee Quality Ratings
Coffee can be a huge productivity boost and I can’t imagine working through the day without it. This data set, originally from the Coffee Quality Database, was re-posted to Kaggle and subsequently featured on TidyTuesday.
In it, coffees from different countries are awarded cup points, scored by panelists who sample the coffee and assess it on a number of factors such as aroma, acidity, uniformity and sweetness. But do other factors, such as country, altitude and processing method, also affect coffee quality scores?
This blog post sets out to investigate the data with exploratory data analysis. Next, utilizing the tidymodels collection, I will create new predictors with feature engineering, and subsequently specify, tune, and compare in-sample results based on RMSE for three popular machine learning models (LASSO, random forest and XGBoost). Variable importance of features will also be compared. Finally, the model that delivers the lowest out-of-sample RMSE when predicting coffee quality points will be selected.
2 Import libraries & data
coffee <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-07-07/coffee_ratings.csv')
dim(coffee)
[1] 1339 43
While not big data, this data set has over 1.3k rows and 43 columns! This is the data dictionary, and full descriptions and examples can be found here.
3 Exploratory Data Analysis
3.1 Distribution of total_cup_points
Let’s first get a sense of the distribution of total_cup_points, the rating given to each coffee on a scale of 0-100.
coffee %>%
  ggplot(aes(total_cup_points)) +
  geom_histogram(binwidth = 1, fill = "#00AFBB", color = "#e9ecef", alpha = 0.6) +
  labs(
    x = "Total_cup_points",
    y = "Count",
    title = "Analyzing Distribution of Coffee Quality Points: Histogram",
    subtitle = "Majority of scores are clustered between 80-90,\nwith some significant outliers - potentially data errors",
    caption = "Source: Coffee Quality Database"
  ) +
  theme(
    plot.title = element_text(face = "bold", size = 20),
    plot.subtitle = element_text(size = 17)
  )
coffee %>%
  arrange(total_cup_points) %>%
  head()
# A tibble: 6 × 43
total_cup_points species owner country_of_origin farm_name lot_number mill
<dbl> <chr> <chr> <chr> <chr> <chr> <chr>
1 0 Arabica bismarc… Honduras los hica… 103 cigr…
2 59.8 Arabica juan lu… Guatemala finca el… <NA> bene…
3 63.1 Arabica exporta… Nicaragua finca la… 017-053-0… bene…
4 67.9 Arabica myriam … Haiti 200 farms <NA> coeb…
5 68.3 Arabica juan ca… Mexico el cente… <NA> la e…
6 69.2 Arabica cadexsa Honduras cerro bu… <NA> cade…
# ℹ 36 more variables: ico_number <chr>, company <chr>, altitude <chr>,
# region <chr>, producer <chr>, number_of_bags <dbl>, bag_weight <chr>,
# in_country_partner <chr>, harvest_year <chr>, grading_date <chr>,
# owner_1 <chr>, variety <chr>, processing_method <chr>, aroma <dbl>,
# flavor <dbl>, aftertaste <dbl>, acidity <dbl>, body <dbl>, balance <dbl>,
# uniformity <dbl>, clean_cup <dbl>, sweetness <dbl>, cupper_points <dbl>,
# moisture <dbl>, category_one_defects <dbl>, quakers <dbl>, color <chr>, …
There’s one outlier with zero total_cup_points - probably a data entry error. Let’s remove that. At the same time, there does not appear to be a unique identifier for each coffee, so let’s add one.
coffee <- coffee %>%
  filter(total_cup_points > 0) %>%
  mutate(id = row_number()) %>%
  select(id, everything())
coffee
# A tibble: 1,338 × 44
id total_cup_points species owner country_of_origin farm_name lot_number
<int> <dbl> <chr> <chr> <chr> <chr> <chr>
1 1 90.6 Arabica metad … Ethiopia "metad p… <NA>
2 2 89.9 Arabica metad … Ethiopia "metad p… <NA>
3 3 89.8 Arabica ground… Guatemala "san mar… <NA>
4 4 89 Arabica yidnek… Ethiopia "yidneka… <NA>
5 5 88.8 Arabica metad … Ethiopia "metad p… <NA>
6 6 88.8 Arabica ji-ae … Brazil <NA> <NA>
7 7 88.8 Arabica hugo v… Peru <NA> <NA>
8 8 88.7 Arabica ethiop… Ethiopia "aolme" <NA>
9 9 88.4 Arabica ethiop… Ethiopia "aolme" <NA>
10 10 88.2 Arabica diamon… Ethiopia "tulla c… <NA>
# ℹ 1,328 more rows
# ℹ 37 more variables: mill <chr>, ico_number <chr>, company <chr>,
# altitude <chr>, region <chr>, producer <chr>, number_of_bags <dbl>,
# bag_weight <chr>, in_country_partner <chr>, harvest_year <chr>,
# grading_date <chr>, owner_1 <chr>, variety <chr>, processing_method <chr>,
# aroma <dbl>, flavor <dbl>, aftertaste <dbl>, acidity <dbl>, body <dbl>,
# balance <dbl>, uniformity <dbl>, clean_cup <dbl>, sweetness <dbl>, …
3.2 Investigating missingness
How much missing data is in this data set?
coffee %>%
  skimr::skim() %>%
  select(skim_variable, complete_rate) %>%
  arrange(complete_rate)
# A tibble: 44 × 2
skim_variable complete_rate
<chr> <dbl>
1 lot_number 0.206
2 farm_name 0.732
3 mill 0.765
4 producer 0.827
5 altitude_low_meters 0.828
6 altitude_high_meters 0.828
7 altitude_mean_meters 0.828
8 altitude 0.831
9 variety 0.831
10 color 0.837
# ℹ 34 more rows
Most of the columns have more than 80% completeness. I’ll filter for columns with more than 70% completeness, and map() a count() across all of them to further investigate columns that could be of interest. In my EDA I do this for all columns but, for the sake of brevity, I’ll only select a few columns here to illustrate the output.
coffee %>%
  select(owner_1:processing_method) %>%
  map(~ count(data.frame(x = .x), x, sort = TRUE)) %>%
  map(~ head(., n = 10))
$owner_1
x n
1 Juan Luis Alvarado Romero 155
2 Racafe & Cia S.C.A 60
3 Exportadora de Cafe Condor S.A 54
4 Kona Pacific Farmers Cooperative 52
5 Ipanema Coffees 50
6 CQI Taiwan ICP CQI台灣合作夥伴 46
7 Lin, Che-Hao Krude 林哲豪 29
8 NUCOFFEE 29
9 CARCAFE LTDA CI 27
10 The Coffee Source Inc. 23
$variety
x n
1 Caturra 255
2 Bourbon 226
3 <NA> 226
4 Typica 211
5 Other 110
6 Catuai 74
7 Hawaiian Kona 44
8 Yellow Bourbon 35
9 Mundo Novo 33
10 Catimor 20
$processing_method
x n
1 Washed / Wet 815
2 Natural / Dry 258
3 <NA> 169
4 Semi-washed / Semi-pulped 56
5 Other 26
6 Pulped natural / honey 14
#what I actually do
cols <- coffee %>%
  skimr::skim() %>%
  select(skim_variable, complete_rate) %>%
  arrange(complete_rate) %>%
  filter(complete_rate > 0.7) %>%
  pull(skim_variable)

#what I actually do
coffee %>%
  select(all_of(cols)) %>%
  map(~ count(data.frame(x = .x), x, sort = TRUE)) %>%
  map(~ head(., n = 10))
Let’s dig further into the data before finalizing the columns.
3.3 How are total_cup_points calculated?
From this page in the Coffee Institute’s database, it appears that total_cup_points is the sum of the columns aroma through cupper_points.
coffee %>%
  select(total_cup_points, aroma:cupper_points)
# A tibble: 1,338 × 11
total_cup_points aroma flavor aftertaste acidity body balance uniformity
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 90.6 8.67 8.83 8.67 8.75 8.5 8.42 10
2 89.9 8.75 8.67 8.5 8.58 8.42 8.42 10
3 89.8 8.42 8.5 8.42 8.42 8.33 8.42 10
4 89 8.17 8.58 8.42 8.42 8.5 8.25 10
5 88.8 8.25 8.5 8.25 8.5 8.42 8.33 10
6 88.8 8.58 8.42 8.42 8.5 8.25 8.33 10
7 88.8 8.42 8.5 8.33 8.5 8.25 8.25 10
8 88.7 8.25 8.33 8.5 8.42 8.33 8.5 10
9 88.4 8.67 8.67 8.58 8.42 8.33 8.42 9.33
10 88.2 8.08 8.58 8.5 8.5 7.67 8.42 10
# ℹ 1,328 more rows
# ℹ 3 more variables: clean_cup <dbl>, sweetness <dbl>, cupper_points <dbl>
The code below verifies just that.
set.seed(123)

coffee %>%
  group_by(id) %>%
  mutate(sum = sum(across(aroma:cupper_points))) %>%
  select(id, total_cup_points, sum) %>%
  ungroup() %>%
  slice_sample(n = 5)
# A tibble: 5 × 3
id total_cup_points sum
<int> <dbl> <dbl>
1 415 83.3 83.3
2 463 83.2 83.2
3 179 84.4 84.4
4 526 83 83.0
5 195 84.2 84.2
3.4 Correlation of total_cup_points
library(GGally)

coffee %>%
  select(total_cup_points, aroma:cupper_points) %>%
  ggcorr()
As expected, total_cup_points shows a positive correlation with all of the underlying scores, though less so with uniformity, clean_cup and sweetness.
3.5 Replacing missingness in altitude data
From some googling,
High altitudes are considered ideal for growing the coffee plant, with cooler temperatures delaying the growth cycle. This allows the bean to go through a longer maturation process, thus creating a much fuller, richer, and more pronounced flavour.
If that is indeed the case, we can expect altitude to be a significant predictor. However, upon inspecting the completeness of the altitude-related columns, around 20% of the data is missing.
coffee %>%
  skimr::skim() %>%
  select(skim_variable, complete_rate) %>%
  filter(skim_variable %in% c("altitude", "altitude_low_meters",
                              "altitude_high_meters", "altitude_mean_meters"))
# A tibble: 4 × 2
skim_variable complete_rate
<chr> <dbl>
1 altitude 0.831
2 altitude_low_meters 0.828
3 altitude_high_meters 0.828
4 altitude_mean_meters 0.828
Before trying to replace these missing values, let’s ask some questions:
3.5.1 Relationship between altitude and altitude_mean_meters?
set.seed(123)

coffee %>%
  select(altitude, altitude_mean_meters) %>%
  slice_sample(n = 20)
# A tibble: 20 × 2
altitude altitude_mean_meters
<chr> <dbl>
1 <NA> NA
2 1800 1800
3 <NA> NA
4 1600 - 1950 msnm 1775
5 1500 1500
6 1400 msnm 1400
7 1250 1250
8 750m 750
9 4300 1311.
10 <NA> NA
11 <NA> NA
12 1750 msnm 1750
13 1200 1200
14 934 934
15 800++ 800
16 1100 1100
17 1130 1130
18 de 1.600 a 1.950 msnm 1775
19 1100 1100
20 1 1
It looks like altitude_mean_meters is a cleaned-up version of altitude - so I’ll focus on using altitude_mean_meters for now.
3.5.2 Missingness of other altitude columns
When altitude_mean_meters is missing, are altitude_low_meters and altitude_high_meters missing too?
coffee %>%
  filter(is.na(altitude_mean_meters)) %>%
  select(contains("meters")) %>%
  #checking for NAs
  summarise(not_na = sum(!is.na(across(everything()))))
# A tibble: 1 × 1
not_na
<int>
1 0
Yes - altitude_low_meters and altitude_high_meters are always missing whenever altitude_mean_meters is missing. I was hoping to use the former two columns to replace missing values in altitude_mean_meters, but that option is off the table.
3.5.3 Standardizing altitude measurements to meters
After converting all altitude measurements recorded in feet to meters, are there any inconsistencies? (1 ft = 0.3048 meters)
outlier <- coffee %>%
  mutate(
    meters = case_when(
      str_detect(unit_of_measurement, "ft") ~ altitude_mean_meters * 0.3048,
      TRUE ~ altitude_mean_meters
    ),
    country_of_origin = fct_lump(country_of_origin, 4)
  ) %>%
  filter(!is.na(meters)) %>%
  filter(meters > 8000) %>%
  pull(id)
library(fishualize)
library(ggforce)

coffee %>%
  mutate(
    meters = case_when(
      str_detect(unit_of_measurement, "ft") ~ altitude_mean_meters * 0.3048,
      TRUE ~ altitude_mean_meters
    ),
    country_of_origin = fct_lump(country_of_origin, 4)
  ) %>%
  filter(!is.na(country_of_origin)) %>%
  ggplot(aes(total_cup_points, meters)) +
  geom_point(aes(colour = country_of_origin),
             size = 1.5, alpha = 0.9) +
  geom_mark_ellipse(aes(filter = id %in% outlier)) +
  scale_colour_fish_d(option = "Etheostoma_spectabile") +
  scale_y_log10(labels = scales::comma) +
  labs(
    x = "Total_cup_points",
    y = "Meters (log scale)",
    colour = "Country of Origin",
    title = "Plotting Altitude (meters) against Coffee Quality Points",
    subtitle = "Outliers are circled and are likely data entry errors",
    caption = "Source: Coffee Quality Database"
  ) +
  theme(plot.title = element_text(face = "bold", size = 20),
        plot.subtitle = element_text(size = 17))
As a sanity check, the highest mountain in the world is Mount Everest with its peak at 8,848 meters. Clearly there are some data entry errors recording over 100,000 meters in altitude. Let’s exclude any altitude records above 8000m.
coffee <- coffee %>%
  mutate(meters = case_when(
    str_detect(unit_of_measurement, "ft") ~ altitude_mean_meters * 0.3048,
    TRUE ~ altitude_mean_meters
  )) %>%
  #explicitly keep NAs because missing values will be replaced later
  filter(is.na(meters) | meters <= 8000) %>%
  select(-altitude, -altitude_low_meters, -altitude_high_meters, -altitude_mean_meters)
3.5.4 Replacing NAs
Now that all altitude measurements are standardized in the meters column, we can begin to replace NAs.
sum(is.na(coffee$meters))
[1] 230
coffee %>%
  filter(is.na(meters)) %>%
  count(country_of_origin, region, sort = TRUE)
# A tibble: 67 × 3
country_of_origin region n
<chr> <chr> <int>
1 United States (Hawaii) kona 64
2 Colombia huila 18
3 Guatemala oriente 14
4 Colombia <NA> 9
5 Thailand <NA> 9
6 United States (Hawaii) <NA> 7
7 Brazil monte carmelo 6
8 Ethiopia sidamo 5
9 Brazil campos altos - cerrado 4
10 Brazil cerrado 4
# ℹ 57 more rows
coffee %>%
  filter(country_of_origin == "United States (Hawaii)",
         is.na(meters)) %>%
  count(country_of_origin, region, sort = TRUE)
# A tibble: 2 × 3
country_of_origin region n
<chr> <chr> <int>
1 United States (Hawaii) kona 64
2 United States (Hawaii) <NA> 7
In total we have 230 NAs for meters, of which Hawaii accounts for 71, or 31%.
coffee %>%
  select(id, country_of_origin, region, meters) %>%
  filter(str_detect(country_of_origin, "(Hawaii)")) %>%
  na.omit()
# A tibble: 2 × 4
id country_of_origin region meters
<int> <chr> <chr> <dbl>
1 14 United States (Hawaii) kona 186.
2 52 United States (Hawaii) kona 130.
Unfortunately, we only have two values for Hawaii. Verifying with some googling, our two data points look about right:
The Kona growing region is about 2 miles long and ranges in altitude from 600 ft. (183m) to 2500 ft (762m).
I’ll replace all NAs related to Hawaii with the mean of our existing data points. The coalesce() function fills NAs in the first vector with values from the second vector at corresponding positions.
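As a toy illustration of coalesce() (not part of the pipeline):

#NAs in the first vector are filled with the values from the
#second vector at the same positions
coalesce(c(1, NA, 3, NA), c(9, 9, 9, 9))
#expected: [1] 1 9 3 9

Applying this to the Hawaii rows: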
hawaii <- coffee %>%
  filter(str_detect(country_of_origin, "(Hawaii)")) %>%
  select(id, meters) %>%
  mutate(meters = replace_na(meters, (186 + 130) / 2))

coffee_refined <- coffee %>%
  left_join(hawaii, by = "id") %>%
  mutate(meters = coalesce(meters.x, meters.y)) %>%
  select(-meters.x, -meters.y)
coffee_refined %>%
  filter(is.na(meters)) %>%
  count(country_of_origin, region, sort = TRUE)
# A tibble: 65 × 3
country_of_origin region n
<chr> <chr> <int>
1 Colombia huila 18
2 Guatemala oriente 14
3 Colombia <NA> 9
4 Thailand <NA> 9
5 Brazil monte carmelo 6
6 Ethiopia sidamo 5
7 Brazil campos altos - cerrado 4
8 Brazil cerrado 4
9 Brazil <NA> 4
10 Ethiopia yirgacheffe 4
# ℹ 55 more rows
Let’s do the same for Huila (Colombia) and Oriente (Guatemala).
huila <- coffee %>%
  filter(str_detect(region, "huila")) %>%
  select(id, meters) %>%
  mutate(meters = replace_na(meters, mean(meters, na.rm = TRUE)))

coffee_refined <- coffee_refined %>%
  left_join(huila, by = "id") %>%
  mutate(meters = coalesce(meters.x, meters.y)) %>%
  select(-meters.x, -meters.y)

oriente <- coffee %>%
  filter(str_detect(region, "oriente")) %>%
  select(id, meters) %>%
  mutate(meters = replace_na(meters, mean(meters, na.rm = TRUE)))

coffee_refined <- coffee_refined %>%
  left_join(oriente, by = "id") %>%
  mutate(meters = coalesce(meters.x, meters.y)) %>%
  select(-meters.x, -meters.y)
We’ll replace the remaining missing values shortly, using the recipes package during our feature preprocessing stage.
3.6 Analyzing variety missingness
variety appears to be another interesting column …
coffee_refined %>%
  count(variety, sort = TRUE)
# A tibble: 29 × 2
variety n
<chr> <int>
1 Caturra 255
2 <NA> 226
3 Bourbon 224
4 Typica 211
5 Other 109
6 Catuai 74
7 Hawaiian Kona 44
8 Yellow Bourbon 35
9 Mundo Novo 33
10 Catimor 20
# ℹ 19 more rows
… but missing values (NAs) constitute nearly 17% of the data.
coffee_refined %>%
  skimr::skim() %>%
  select(skim_variable, complete_rate) %>%
  filter(skim_variable == "variety")
# A tibble: 1 × 2
skim_variable complete_rate
<chr> <dbl>
1 variety 0.831
Let’s visualize the missingness of data in variety. Specifically, is there a relationship between country_of_origin and missing data in variety?
coffee_refined %>%
  group_by(country_of_origin) %>%
  filter(sum(is.na(variety)) > 10) %>%
  ungroup() %>%
  ggplot(aes(total_cup_points, meters, colour = is.na(variety))) +
  geom_point(size = 2, alpha = 0.5) +
  scale_colour_fish_d(option = "Etheostoma_spectabile") +
  facet_wrap(~ country_of_origin) +
  labs(
    x = "Total_cup_points",
    y = "Meters",
    colour = "Is Variety Missing?",
    title = "Which countries contain the highest amount of missing data in Variety?",
    subtitle = "Filtering for countries with more than 10 missing variety-related data entries",
    caption = "Source: Coffee Quality Database"
  ) +
  theme(plot.title = element_text(face = "bold", size = 20),
        plot.subtitle = element_text(size = 17))
Most of the missing variety data come from a select few countries - Colombia, Ethiopia, India, Thailand and Uganda. These missing values will also be addressed later on using recipes.
3.7 Visualizing highest scores across countries
library(ggridges)

country_freq <- coffee_refined %>%
  group_by(country_of_origin) %>%
  add_count(name = "total_entry") %>%
  ungroup() %>%
  mutate(freq = total_entry / sum(n())) %>%
  distinct(country_of_origin, .keep_all = TRUE) %>%
  select(country_of_origin, freq) %>%
  arrange(desc(freq)) %>%
  slice(1:10) %>%
  pull(country_of_origin)

coffee_refined %>%
  filter(country_of_origin %in% country_freq,
         total_cup_points > 75) %>%
  mutate(country_of_origin = fct_reorder(country_of_origin, total_cup_points)) %>%
  ggplot(aes(total_cup_points, country_of_origin, fill = country_of_origin)) +
  geom_density_ridges(scale = 1, alpha = 0.8, show.legend = FALSE) +
  scale_fill_fish_d(option = "Antennarius_multiocellatus", begin = 0.5, end = 0) +
  theme_ridges(center_axis_labels = TRUE) +
  labs(
    x = "Total Cup Points",
    y = NULL,
    fill = "",
    title = "Visualizing Distribution of Coffee Ratings Across Countries",
    subtitle = "Below: Top 10 countries based on absolute # of coffee ratings given,\nsorted according to score",
    caption = "Source: Coffee Quality Database"
  ) +
  theme(
    plot.title = element_text(face = "bold", size = 20),
    plot.subtitle = element_text(size = 17),
    strip.background = element_blank(),
    strip.text = element_text(face = "bold", size = 15),
    legend.title = element_text(face = "bold", size = 15)
  )
3.8 Finalizing the dataset
That was quite a bit of EDA just to replace missing values! Finalizing the data set, we have included:

- Unique identifier: id
- Outcome (what we want to predict): total_cup_points
- Predictors: country_of_origin, in_country_partner, certification_body, variety, processing_method, meters

Recall that meters is an engineered feature created earlier from the various altitude columns.
coffee_df <- coffee_refined %>%
  select(id, total_cup_points,
         country_of_origin, in_country_partner,
         certification_body, variety, processing_method, #aroma:cupper_points,
         meters) %>%
  #converting all character columns to factors
  mutate(across(where(is.character), as.factor))
3.8.1 Excluding aroma:cupper_points
Importantly, you might have noticed that I have chosen not to include the columns aroma:cupper_points. The reason is that our outcome is the sum of these columns, so using them as predictors would lead all three models to ignore every predictor outside this group - the variable importance of these columns would be overwhelmingly high.
By excluding these columns, the models face a more challenging prediction task - somewhat akin to the complexity faced in real-life predictive modeling.
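As a quick exploratory check of this point (a sketch, separate from the modeling pipeline below), regressing the outcome on the sub-scores alone should recover it almost perfectly:

#since total_cup_points is the sum of these columns, a plain linear
#model on them alone should yield an R-squared of essentially 1,
#leaving nothing for the remaining predictors to explain
lm(total_cup_points ~ aroma + flavor + aftertaste + acidity + body +
     balance + uniformity + clean_cup + sweetness + cupper_points,
   data = coffee_refined) %>%
  broom::glance() %>%
  select(r.squared)
#expected: r.squared of ~1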
4 Data preparation
The objective is to run three models - LASSO, random forest and XGBoost - and compare their performance in predicting total_cup_points.
4.1 Initial Split
set.seed(2020)

coffee_split <- initial_split(coffee_df)
coffee_train <- training(coffee_split)
coffee_test <- testing(coffee_split)
4.2 Resampling
All three models will undergo hyperparameter tuning using cross-validation. Here, I opt for 10-fold cross-validation.
set.seed(2020)
folds <- vfold_cv(coffee_train, v = 10)
5 Start with LASSO Model
5.1 Model Specification
Here I’m specifying the LASSO model I intend to fit. Hyperparameters tagged with tune() will subsequently be tuned using a grid search.
model_lasso <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("regression")
5.2 Feature Preprocessing
I came across this great illustration by Allison Horst describing the recipes package:
Just a quick check on which columns have missing data:
coffee_train %>%
  map_df(~ sum(is.na(.))) %>%
  t()
[,1]
id 0
total_cup_points 0
country_of_origin 1
in_country_partner 0
certification_body 0
variety 160
processing_method 123
meters 95
Here we specify the recipe:

- step_other: Collapses infrequent factor levels into “other” if they don’t meet a predefined threshold
- step_dummy: Turns nominal (character/factor) columns into numeric binary data. Necessary because the LASSO can only process numeric data
- step_impute_knn: Imputes the remaining missing values using k-nearest neighbours (the default is 5 neighbours). Here, after imputing missing values of meters, I use it to impute missing values of variety, and subsequently use both for imputing missing values of processing_method
- step_normalize: Normalizes numeric data to have a standard deviation of one and a mean of zero. Necessary since the LASSO penalty is sensitive to the scale of the predictors
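Before applying these steps to our data, here’s a toy sketch of step_other in action - a made-up one-column tibble, separate from the blog’s pipeline:

#toy demo: the rare level "c" (5% of rows) falls below the 10%
#threshold and is collapsed into "uncommon"; "a" and "b" are kept
toy <- tibble(grp = factor(c(rep("a", 50), rep("b", 45), rep("c", 5))))

recipe(~ grp, data = toy) %>%
  step_other(grp, threshold = 0.10, other = "uncommon") %>%
  prep() %>%
  juice() %>%
  count(grp)
#expected counts: a 50, b 45, uncommon 5

Now, the full recipe for our data: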
coffee_rec <- coffee_train %>%
  recipe(total_cup_points ~ .) %>%
  update_role(id, new_role = "id") %>%
  step_other(
    country_of_origin,
    in_country_partner,
    certification_body,
    variety,
    processing_method,
    threshold = 0.02,
    other = "uncommon"
  ) %>%
  step_unknown(country_of_origin, new_level = "unknown_country") %>%
  step_dummy(all_nominal(), -variety, -processing_method) %>%
  step_impute_knn(meters,
                  impute_with = imp_vars(contains(c("country", "certification")))) %>%
  step_impute_knn(variety,
                  impute_with = imp_vars(contains(c("country", "certification", "meters")))) %>%
  step_impute_knn(processing_method,
                  impute_with = imp_vars(contains(c("country", "certification", "meters", "variety")))) %>%
  step_dummy(variety, processing_method) %>%
  step_normalize(all_numeric(), -all_outcomes()) %>%
  step_impute_knn(all_predictors())
prep() estimates the required parameters from the training set, and juice() applies those parameters to the training data, returning the data in a tibble. The code below confirms that there are no more missing values after our pre-processing.
coffee_rec %>%
  prep() %>%
  juice() %>%
  summarise(is_na = sum(is.na(across(everything()))))
# A tibble: 1 × 1
is_na
<int>
1 0
5.3 Workflows
The workflows package introduces workflow objects that help manage modeling pipelines more easily - pieces that fit together like Lego blocks.
lasso_wf <- workflow() %>%
  add_recipe(coffee_rec) %>%
  add_model(model_lasso)
5.4 Hyperparameter Tuning
I’m setting up three respective grids for our three models. First up, for the LASSO model I’ll use grid_random to generate a random grid.
set.seed(2020)
lasso_grid <- grid_random(penalty(), size = 50)
Once parallel processing has been set up, the tuning can now commence!
all_cores <- parallel::detectCores(logical = FALSE)

library(doParallel)
cl <- makePSOCKcluster(all_cores)
registerDoParallel(cl)

set.seed(2020)
lasso_res <- tune_grid(
  object = lasso_wf,
  resamples = folds,
  grid = lasso_grid,
  control = control_grid(save_pred = TRUE)
)

lasso_res
# Tuning results
# 10-fold cross-validation
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [900/100]> Fold01 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
2 <split [900/100]> Fold02 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
3 <split [900/100]> Fold03 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
4 <split [900/100]> Fold04 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
5 <split [900/100]> Fold05 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
6 <split [900/100]> Fold06 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
7 <split [900/100]> Fold07 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
8 <split [900/100]> Fold08 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
9 <split [900/100]> Fold09 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
10 <split [900/100]> Fold10 <tibble [100 × 5]> <tibble [2 × 3]> <tibble>
There were issues with some computations:
- Warning(s) x10: A correlation computation is required, but `estimate` is constant...
- Warning(s) x9: skipping variable with zero or non-finite range.
- Warning(s) x1: skipping variable with zero or non-finite range., ! The followin...
Run `show_notes(.Last.tune.result)` for more information.
5.5 Training Performance Assessment
The results can be obtained with collect_metrics() and subsequently visualized, and the tuned hyperparameters associated with the lowest in-sample RMSE can be obtained with show_best().
lasso_res %>%
  collect_metrics() %>%
  ggplot(aes(penalty, mean)) +
  geom_line(aes(color = .metric),
            linewidth = 1.5,
            show.legend = FALSE) +
  facet_wrap(. ~ .metric, nrow = 2) +
  scale_x_log10(label = scales::number_format()) +
  labs(
    x = "Penalty",
    y = "RMSE",
    title = "LASSO: Assessing In-Sample Performance of Tuned Hyperparameters",
    subtitle = "RMSE appears to be minimized when penalty is <0.01"
  ) +
  theme(
    plot.title = element_text(face = "bold", size = 20),
    plot.subtitle = element_text(size = 17),
    strip.background = element_blank(),
    strip.text = element_text(face = "bold", size = 15)
  )
lasso_res %>%
  show_best(metric = "rmse")
# A tibble: 5 × 7
penalty .metric .estimator mean n std_err .config
<dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.0183 rmse standard 2.42 10 0.0978 Preprocessor1_Model40
2 0.0153 rmse standard 2.42 10 0.0973 Preprocessor1_Model39
3 0.00954 rmse standard 2.42 10 0.0962 Preprocessor1_Model38
4 0.00441 rmse standard 2.42 10 0.0952 Preprocessor1_Model37
5 0.00303 rmse standard 2.42 10 0.0949 Preprocessor1_Model36
5.6 Finalizing Hyperparameters
Let’s use select_best() to obtain the optimal penalty hyperparameter that minimizes RMSE; finalize_workflow() then plugs that value back into the LASSO workflow, ready to be fit on the training data.
lasso_best <- lasso_res %>%
  select_best(metric = "rmse")

lasso_best
# A tibble: 1 × 2
penalty .config
<dbl> <chr>
1 0.0183 Preprocessor1_Model40
lasso_final_wf <- lasso_wf %>%
  finalize_workflow(lasso_best)

lasso_final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ────────────────────────────────────────────────────────────────
9 Recipe Steps
• step_other()
• step_unknown()
• step_dummy()
• step_impute_knn()
• step_impute_knn()
• step_impute_knn()
• step_dummy()
• step_normalize()
• step_impute_knn()
── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)
Main Arguments:
penalty = 0.0182665677542162
mixture = 1
Computational engine: glmnet
6 Try Random Forest & XGBoost Models
Just like the LASSO model, the random forest and XGBoost models can be specified in the same way. The process is also similar for creating the respective workflows, tuning hyperparameters, selecting the hyperparameters that correspond to the lowest RMSE, and finalizing the workflows.
6.1 Model Specification
Let’s set trees = 1000 for both random forest and XGBoost, and tune the remaining hyperparameters.
model_rf <- rand_forest(mtry = tune(),
                        min_n = tune(),
                        trees = 1000) %>%
  set_engine("ranger", importance = "permutation") %>%
  set_mode("regression")

model_xgboost <- boost_tree(
  trees = 1000,
  #model complexity
  tree_depth = tune(),
  min_n = tune(),
  loss_reduction = tune(),
  #randomness
  sample_size = tune(),
  mtry = tune(),
  #step size
  learn_rate = tune()
) %>%
  set_engine("xgboost") %>%
  set_mode("regression")
6.2 Workflows
The beauty of tidymodels is that we can conveniently re-use the same preprocessing recipe, coffee_rec, in the random forest and XGBoost workflows.
rf_wf <- workflow() %>%
  add_recipe(coffee_rec) %>%
  add_model(model_rf)

xgb_wf <- workflow() %>%
  add_recipe(coffee_rec) %>%
  add_model(model_xgboost)
6.3 Hyperparameter Tuning
Similar to the LASSO, grid_random will be used for the random forest model. Note that finalize() is used, together with our training set, to determine the upper bound for the mtry() hyperparameter (the number of predictors that will be randomly sampled at each split when building the trees).
For the XGBoost model, however, we use a space-filling Latin hypercube grid, which generates a near-random sample of parameter values from a multidimensional distribution.
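As a quick illustration of what finalize() does to mtry() (illustrative output, based on the columns of our training set):

mtry()
#Range: [1, ?] - the upper bound is unknown until data is supplied

finalize(mtry(), coffee_train)
#Range: [1, 8] - the upper bound becomes the number of columns in
#coffee_train (note this count includes id and the outcome)

With that established, here are the grids: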
set.seed(2020)
rf_grid <- grid_random(finalize(mtry(), coffee_train), min_n(), size = 50)

set.seed(2020)
xgb_grid <- grid_latin_hypercube(
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), coffee_train),
  learn_rate(),
  size = 50
)
With the grids set up, now we can tune both models’ hyperparameters.
all_cores <- parallel::detectCores(logical = FALSE)

library(doParallel)
cl <- makePSOCKcluster(all_cores)
registerDoParallel(cl)

set.seed(2020)
rf_res <- tune_grid(
  object = rf_wf,
  resamples = folds,
  grid = rf_grid,
  control = control_grid(save_pred = TRUE)
)

set.seed(2020)
xgb_res <- tune_grid(
  object = xgb_wf,
  resamples = folds,
  grid = xgb_grid,
  control = control_grid(save_pred = TRUE)
)
6.4 Training Performance Assessment
We can also visually assess the performance across all tuned hyperparameters and their effect on RMSE for the random forest model.
rf_res %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  pivot_longer(mtry:min_n, names_to = "parameter", values_to = "value") %>%
  ggplot(aes(value, mean)) +
  geom_point(aes(color = parameter),
             size = 2,
             show.legend = FALSE) +
  facet_wrap(. ~ parameter, scales = "free_x") +
  labs(
    x = "",
    y = "RMSE",
    title = "Random Forest: Assessing In-Sample Performance of Tuned Hyperparameters",
    subtitle = "RMSE appears to be minimized at low levels of min_n and mtry"
  ) +
  theme(
    plot.title = element_text(face = "bold", size = 20),
    plot.subtitle = element_text(size = 17),
    strip.background = element_blank(),
    strip.text = element_text(face = "bold", size = 15)
  )
rf_res %>%
  show_best(metric = "rmse")
# A tibble: 5 × 8
mtry min_n .metric .estimator mean n std_err .config
<int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 6 14 rmse standard 2.39 10 0.105 Preprocessor1_Model04
2 6 18 rmse standard 2.39 10 0.105 Preprocessor1_Model40
3 5 11 rmse standard 2.39 10 0.104 Preprocessor1_Model36
4 7 17 rmse standard 2.39 10 0.105 Preprocessor1_Model32
5 8 8 rmse standard 2.39 10 0.107 Preprocessor1_Model19
Let’s do the same for the XGBoost model.
xgb_res %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  select(mean, mtry:sample_size) %>%
  pivot_longer(mtry:sample_size, names_to = "parameter", values_to = "value") %>%
  ggplot(aes(value, mean)) +
  geom_point(aes(color = parameter),
             size = 2,
             show.legend = FALSE) +
  facet_wrap(. ~ parameter, scales = "free_x") +
  labs(
    x = "",
    y = "RMSE",
    title = "XGBoost: Assessing In-Sample Performance of Tuned Hyperparameters",
    subtitle = "Several combinations of parameters do well to minimize RMSE"
  ) +
  theme(
    plot.title = element_text(face = "bold", size = 20),
    plot.subtitle = element_text(size = 17),
    strip.background = element_blank(),
    strip.text = element_text(face = "bold", size = 15)
  )
xgb_res %>%
  show_best(metric = "rmse")
# A tibble: 5 × 12
mtry min_n tree_depth learn_rate loss_reduction sample_size .metric
<int> <int> <int> <dbl> <dbl> <dbl> <chr>
1 4 6 14 0.0225 4.08e- 1 0.797 rmse
2 5 11 3 0.0810 3.43e- 4 0.948 rmse
3 6 29 13 0.0336 2.95e+ 1 0.465 rmse
4 8 28 2 0.00648 1.20e-10 0.804 rmse
5 2 24 4 0.0105 1.01e- 5 0.592 rmse
# ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
# .config <chr>
6.5 Finalizing Hyperparameters
The workflows for both the random forest and XGBoost models are now finalized.
rf_best <- rf_res %>%
  select_best(metric = "rmse")

xgb_best <- xgb_res %>%
  select_best(metric = "rmse")

rf_final_wf <- rf_wf %>%
  finalize_workflow(rf_best)

xgb_final_wf <- xgb_wf %>%
  finalize_workflow(xgb_best)
7 Model Evaluation: Variable Importance
Before running the models on the test data, let’s compare variable importance across our models. It can be useful to know which, if any, of the predictors in a fitted model are relatively influential on the predicted outcome.
library(vip)       #variable importance plots
library(patchwork) #combining plots

p1 <- lasso_final_wf %>%
  fit(coffee_train) %>%
  extract_fit_parsnip() %>%
  vip(geom = "point") +
  ggtitle("LASSO")

p2 <- rf_final_wf %>%
  fit(coffee_train) %>%
  extract_fit_parsnip() %>%
  vip(geom = "point") +
  ggtitle("Random Forest")

p3 <- xgb_final_wf %>%
  fit(coffee_train) %>%
  extract_fit_parsnip() %>%
  vip(geom = "point") +
  ggtitle("XGBoost")

p1 + p2 / p3 + plot_annotation(
  title = 'Assessing Variable Importance Across Models',
  subtitle = 'The engineered feature "meters" is heavily favoured by the tree-based models',
  theme = theme(
    plot.title = element_text(size = 18),
    plot.subtitle = element_text(size = 14)
  )
)
Earlier, Ethiopia stood out as the country with the highest mean scores when we visualized coffee ratings across countries. Here, it makes sense to see country_of_origin_Ethiopia with relatively high variable importance for the LASSO and XGBoost models. The engineered feature, meters, also makes a significant variable importance contribution for our tree-based models.
8 Final Model Selection
As a recap, here are the corresponding RMSE values for the sets of hyperparameters selected for each model.
lasso_res %>%
  show_best(metric = "rmse") %>%
  mutate(model = "lasso") %>%
  bind_rows(rf_res %>%
              show_best(metric = "rmse") %>%
              mutate(model = "randomforest")) %>%
  bind_rows(xgb_res %>%
              show_best(metric = "rmse") %>%
              mutate(model = "xgboost")) %>%
  group_by(model) %>%
  summarise(lowest_training_rmse = round(min(mean), 2))
# A tibble: 3 × 2
model lowest_training_rmse
<chr> <dbl>
1 lasso 2.42
2 randomforest 2.39
3 xgboost 2.42
As the final step, for all three models, we perform a last_fit() using the split object, coffee_split. This emulates the process where, after determining the best model, a final fit on the entire training set is evaluated against the test set, coffee_test (which has not been touched since the initial split).
lasso_final_wf %>%
  last_fit(coffee_split) %>%
  collect_metrics() %>%
  mutate(model = "lasso") %>%
  bind_rows(
    rf_final_wf %>%
      last_fit(coffee_split) %>%
      collect_metrics() %>%
      mutate(model = "randomforest")
  ) %>%
  bind_rows(
    xgb_final_wf %>%
      last_fit(coffee_split) %>%
      collect_metrics() %>%
      mutate(model = "xgboost")
  ) %>%
  filter(.metric == "rmse")
# A tibble: 3 × 5
.metric .estimator .estimate .config model
<chr> <chr> <dbl> <chr> <chr>
1 rmse standard 2.58 Preprocessor1_Model1 lasso
2 rmse standard 2.52 Preprocessor1_Model1 randomforest
3 rmse standard 2.48 Preprocessor1_Model1 xgboost
9 Conclusion
Test RMSEs for all three models are slightly higher than their training scores, which might indicate some overfitting. Both tree-based models also performed slightly better than the LASSO, which could mean there are interaction effects at play.
While I’d still go with either tree-based model, the differences in test RMSE are so small that I’d be inclined to compare all three models again in future work.