4 Fitting basics
library(tidyverse)
library(rstanarm)
library(tidymodels)
tidymodels_prefer()
In Chapters 2 and 3, we used EDA to understand the diamonds dataset and to identify function families with which to model the relationship of price to the four Cs: carat, clarity, color, and cut. We’re now ready to fit models. To start, we’ll use the simplest function family we identified, the power law relationship of price vs. carat, which is expressed by the formula log(price) ~ log(carat).
In the next chapter, we’ll model this data with other function families and show you how to compare and evaluate different models.
We’ve made several modifications to the diamonds dataset over the past two chapters, including removing diamonds with impossible dimensions and filtering out very large diamonds. The following code repeats these manipulations and stores the result in a new tibble, df.
df <-
  diamonds %>%
  filter(x > 0, y > 0, z > 0) %>%
  filter(y < 20, z < 10) %>%
  filter(carat <= quantile(.$carat, probs = 0.99)) %>%
  mutate(
    color = fct_rev(color),
    color_combined = fct_collapse(color, "DEFG" = c("D", "E", "F", "G"))
  )
Going forward, we’ll use df instead of diamonds.
4.1 Fitting a model
In Chapter 1, we discussed classical and Bayesian modeling. We’ll now show you how to actually fit classical and Bayesian models, both directly and with the tidymodels modeling framework.
4.1.1 Fitting a classical model
The R function lm() fits linear models using classical least squares. Here’s how to use it to fit a model with our function family and data.
fit_lm <- lm(log(price) ~ log(carat), data = df)
You can learn about the fit by simply printing it.
fit_lm
#>
#> Call:
#> lm(formula = log(price) ~ log(carat), data = df)
#>
#> Coefficients:
#> (Intercept) log(carat)
#> 8.46 1.69
You can extract the coefficients for the model with coef().
coef(fit_lm)
#> (Intercept) log(carat)
#> 8.46 1.69
4.1.2 Fitting a Bayesian model
Stan is a platform for statistical modeling, including Bayesian modeling. rstanarm is an “R package that emulates other R model-fitting functions but uses Stan … for the back-end estimation.”
stan_glm() is an rstanarm function that emulates the R function glm() to fit generalized linear models. Here’s how to use it to fit a model with our function family and data.
fit_stan_glm <-
  stan_glm(log(price) ~ log(carat), data = df, refresh = 0, seed = 505)
stan_glm() uses a probabilistic algorithm. If you wish, you can use the seed argument to set the starting point of the algorithm’s random number generator so that it returns the same results every time it is called. The refresh = 0 argument suppresses the progress information that the algorithm would otherwise print while running.
You can learn about the fit by simply printing it.
fit_stan_glm
#> stan_glm
#> family: gaussian [identity]
#> formula: log(price) ~ log(carat)
#> observations: 53405
#> predictors: 2
#> ------
#> Median MAD_SD
#> (Intercept) 8.5 0.0
#> log(carat) 1.7 0.0
#>
#> Auxiliary parameter(s):
#> Median MAD_SD
#> sigma 0.3 0.0
#>
#> ------
#> * For help interpreting the printed output see ?print.stanreg
#> * For info on the priors used see ?prior_summary.stanreg
By default, the estimates are printed with only one decimal place of precision. Here’s how you can print more.
print(fit_stan_glm, digits = 3)
#> stan_glm
#> family: gaussian [identity]
#> formula: log(price) ~ log(carat)
#> observations: 53405
#> predictors: 2
#> ------
#> Median MAD_SD
#> (Intercept) 8.457 0.001
#> log(carat) 1.688 0.002
#>
#> Auxiliary parameter(s):
#> Median MAD_SD
#> sigma 0.260 0.001
#>
#> ------
#> * For help interpreting the printed output see ?print.stanreg
#> * For info on the priors used see ?prior_summary.stanreg
The parameter sigma is an estimate of the standard deviation of the residual distribution, which is assumed to be normal with mean 0. The MAD_SD column is the scaled median absolute deviation, which provides a robust measure of the variation of the parameter distributions.
An advantage of Bayesian modeling is that it provides not just point estimates of the parameters but also an estimate of their joint distribution, in the form of simulations. Here’s how you can access these simulations.
sims <-
  as.data.frame(fit_stan_glm) %>%
  as_tibble()
By default, stan_glm() performs 4000 simulations from the joint distribution of the parameters: 4 chains of 2000 iterations each, with the first half of each chain discarded as warmup.
sims
#> # A tibble: 4,000 × 3
#> `(Intercept)` `log(carat)` sigma
#> <dbl> <dbl> <dbl>
#> 1 8.46 1.68 0.260
#> 2 8.46 1.69 0.262
#> 3 8.46 1.69 0.262
#> 4 8.46 1.69 0.261
#> 5 8.46 1.69 0.261
#> 6 8.46 1.69 0.261
#> # … with 3,994 more rows
These simulations can be used to make inferences about the model, such as uncertainties in the parameters or in the model’s predictions.
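For example, here’s a sketch that should reproduce the Median and MAD_SD columns of the printed fit directly from the simulations. R’s mad() applies the usual 1.4826 scaling by default, which is what makes it comparable to a standard deviation.
sims %>%
  summarize(across(everything(), list(median = median, mad_sd = mad)))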
You can extract the coefficients for the model with coef().
coef(fit_stan_glm)
#> (Intercept) log(carat)
#> 8.46 1.69
The coefficients produced by lm() and stan_glm() are very close.
coef(fit_lm) - coef(fit_stan_glm)
#> (Intercept) log(carat)
#> -4.91e-05 -5.11e-05
4.1.3 Tidymodels
Tidymodels is an ecosystem of R packages for modeling. It provides a common interface to a growing number of model types and engines. And it supports a range of other modeling tasks, such as data preprocessing, resampling, and parameter tuning. You can learn more at Tidy Modeling with R.
Linear regression, the type of model we’ll be using here, is just one of several model types supported by tidymodels, and each model type may have multiple computational engines. Here are the engines currently available for linear regression.
show_engines("linear_reg")
#> # A tibble: 7 × 2
#> engine mode
#> <chr> <chr>
#> 1 lm regression
#> 2 glm regression
#> 3 glmnet regression
#> 4 stan regression
#> 5 spark regression
#> 6 keras regression
#> # … with 1 more row
Tidymodels provides a common interface for each type of model, so you don’t need to know the details and arguments for each computational engine. This makes it easier to model a dataset using different types of model, and using different computational engines for the same type of model.
Here’s how to fit our function family and data with tidymodels using lm().
fit_tm_lm <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(log(price) ~ log(carat), data = df)
fit_tm_lm
#> parsnip model object
#>
#>
#> Call:
#> stats::lm(formula = log(price) ~ log(carat), data = data)
#>
#> Coefficients:
#> (Intercept) log(carat)
#> 8.46 1.69
The tidymodels fit contains the lm() fit.
fit_tm_lm$fit
#>
#> Call:
#> stats::lm(formula = log(price) ~ log(carat), data = data)
#>
#> Coefficients:
#> (Intercept) log(carat)
#> 8.46 1.69
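If you prefer the coefficients and their standard errors as a tibble, one convenient alternative is broom’s tidy() (attached as part of tidymodels), which works on the underlying fit:
tidy(fit_tm_lm$fit)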
Tidymodels uses stan_glm() to fit linear regressions with the stan engine. Here’s how to fit our function family and data with tidymodels using stan_glm().
fit_tm_stan <-
  linear_reg() %>%
  set_engine("stan", seed = 505) %>%
  fit(log(price) ~ log(carat), data = df)
The seed argument isn’t required. We used it so the results would match the direct call of stan_glm() above. With tidymodels, refresh = 0 is the default.
Here’s information about the fit.
fit_tm_stan
#> parsnip model object
#>
#> stan_glm
#> family: gaussian [identity]
#> formula: log(price) ~ log(carat)
#> observations: 53405
#> predictors: 2
#> ------
#> Median MAD_SD
#> (Intercept) 8.5 0.0
#> log(carat) 1.7 0.0
#>
#> Auxiliary parameter(s):
#> Median MAD_SD
#> sigma 0.3 0.0
#>
#> ------
#> * For help interpreting the printed output see ?print.stanreg
#> * For info on the priors used see ?prior_summary.stanreg
The tidymodels fit contains the stan_glm() fit.
print(fit_tm_stan$fit, digits = 3)
#> stan_glm
#> family: gaussian [identity]
#> formula: log(price) ~ log(carat)
#> observations: 53405
#> predictors: 2
#> ------
#> Median MAD_SD
#> (Intercept) 8.457 0.001
#> log(carat) 1.688 0.002
#>
#> Auxiliary parameter(s):
#> Median MAD_SD
#> sigma 0.260 0.001
#>
#> ------
#> * For help interpreting the printed output see ?print.stanreg
#> * For info on the priors used see ?prior_summary.stanreg
4.2 Checking a model
In the last chapter, we used EDA to find promising function families with which to model the diamonds dataset. We can now use EDA to check how well a model fits the data.
Here’s a plot of the data we’re modeling and a smooth line.
df %>%
  ggplot(aes(carat, price)) +
  geom_point(alpha = 0.01) +
  geom_smooth()
Recall that the formula for our function family, log(price) ~ log(carat), represents functions of the form log(price) = a_0 + a_1 * log(carat) for parameters a_0 and a_1.
From the coefficients we saw above, the algorithms chose the function log(price) = 8.46 + 1.69 * log(carat) or, after we apply exp() to both sides, price = 4707 * carat^1.69.
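The constant comes from the unrounded intercept: exp(8.457) is about 4707, whereas exponentiating the rounded 8.46 would give roughly 4722. You can verify it from the fit directly:
exp(coef(fit_stan_glm)[["(Intercept)"]])  # approximately 4707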
The smooth line above curves upward, indicating growth in price that is greater than linear. The model exponent is 1.69. Since this is larger than 1, it likewise represents growth that is greater than linear.
This function predicts that a one-carat diamond would have a price of about 4707. Here’s the distribution of actual prices of one-carat diamonds.
df %>%
  filter(near(carat, 1)) %>%
  pull(price) %>%
  summary()
#> Min. 1st Qu. Median Mean 3rd Qu. Max.
#> 1681 4155 4864 5243 6079 16469
As you can see from the plot and the summary, there is a wide variation in prices for one-carat diamonds, but the prediction is roughly comparable with the median and mean prices.
With no glaring problems with the parameters, we’ll next look at predictions across the full range of carat.
4.2.1 Making predictions
Once you’ve fit a model, you can use the fit to make predictions. Here’s how to use the predict() function to make predictions for the lm() and stan_glm() models.
preds <-
  tibble(
    carat = seq(min(df$carat), max(df$carat), length.out = 801),
    pred_lm = predict(fit_lm, newdata = tibble(carat)) %>% exp(),
    pred_stan_glm = predict(fit_stan_glm, newdata = tibble(carat)) %>% exp()
  )
preds
#> # A tibble: 801 × 3
#> carat pred_lm pred_stan_glm
#> <dbl> <dbl> <dbl>
#> 1 0.2 311. 311.
#> 2 0.202 317. 317.
#> 3 0.205 324. 324.
#> 4 0.207 331. 331.
#> 5 0.210 337. 337.
#> 6 0.212 344. 344.
#> # … with 795 more rows
In the tibble above, carat contains equally spaced values. We then use the newdata argument of predict() to make predictions at these values using the models. predict() makes predictions of log(price), since this was the response in the formula defining the model function family. We therefore apply exp() to get predictions for price.
lm() is from the stats package, and stan_glm() is from the rstanarm package. These packages were written and are maintained by different people. You might wonder how the single function predict() can make predictions for such different models from two different packages. It does this by taking advantage of R’s object-oriented functionality.
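Here’s a minimal sketch of that mechanism, using a hypothetical generic describe() rather than anything from stats or rstanarm. UseMethod() walks the class vector of its argument and calls the first method that matches.
describe <- function(x) UseMethod("describe")
describe.lm <- function(x) "dispatched to describe.lm()"
describe.stanreg <- function(x) "dispatched to describe.stanreg()"
describe(fit_lm)       # fit_lm has class "lm", so describe.lm() runs
describe(fit_stan_glm) # "stanreg" comes first in its class vector, so describe.stanreg() runs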
fit_lm has class lm.
class(fit_lm)
#> [1] "lm"
When predict() encounters a model of this class, it calls predict.lm() from the stats package.
fit_stan_glm has class stanreg.
class(fit_stan_glm)
#> [1] "stanreg" "glm" "lm"
When predict() encounters a model of this class, it calls predict.stanreg() from the rstanarm package.
The tidymodels models fit_tm_lm and fit_tm_stan have class model_fit.
class(fit_tm_lm)
#> [1] "_lm" "model_fit"
class(fit_tm_stan)
#> [1] "_stanreg" "model_fit"
When predict() encounters a model of this class, it calls predict.model_fit() from the tidymodels parsnip package. We need to use slightly different arguments to make predictions with this function.
preds <-
  preds %>%
  mutate(
    pred_tm_lm =
      predict(fit_tm_lm, new_data = tibble(carat), type = "raw") %>% exp(),
    pred_tm_stan =
      predict(fit_tm_stan, new_data = tibble(carat), type = "raw") %>% exp()
  )
preds
#> # A tibble: 801 × 5
#> carat pred_lm pred_stan_glm pred_tm_lm pred_tm_stan
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 0.2 311. 311. 311. 311.
#> 2 0.202 317. 317. 317. 317.
#> 3 0.205 324. 324. 324. 324.
#> 4 0.207 331. 331. 331. 331.
#> 5 0.210 337. 337. 337. 337.
#> 6 0.212 344. 344. 344. 344.
#> # … with 795 more rows
The predictions from the models created with tidymodels are exactly the same as those created directly with lm() and stan_glm().
preds %>%
  summarize(
    diff_lm = max(abs(pred_tm_lm - pred_lm)),
    diff_stan_glm = max(abs(pred_tm_stan - pred_stan_glm))
  )
#> # A tibble: 1 × 2
#> diff_lm diff_stan_glm
#> <dbl> <dbl>
#> 1 0 0
And the predictions from the models created by calling lm() and stan_glm() directly are very close to each other.
v <-
  preds %>%
  summarize(
    rel_max = max(abs(pred_stan_glm - pred_lm) / pmin(pred_stan_glm, pred_lm))
  )
v
#> # A tibble: 1 × 1
#> rel_max
#> <dbl>
#> 1 0.0000659
The predictions are always within 0.0066% of each other.
Since the predictions of the different models are all very close, in the following we will only check the fit of stan_glm().
4.2.2 Checking predictions
Recall from Chapter 3 that log-log plots are an effective way to plot data from a power law. On a log-log plot, a power law function is linear.
Here’s our data on a log-log plot with a smooth line and our model’s predictions.
df %>%
  ggplot(aes(carat, price)) +
  geom_point(alpha = 0.01) +
  geom_smooth(aes(color = "Data")) +
  geom_line(aes(y = pred_stan_glm, color = "Model"), data = preds, size = 1) +
  scale_x_log10() +
  scale_y_log10() +
  scale_color_discrete(direction = -1) +
  theme(legend.position = "top") +
  labs(color = NULL)
For diamonds greater than or equal to 0.3 carat in size, the model predictions are very close to the smooth line. For diamonds smaller than 0.3 carat, the actual prices of diamonds are greater than the prices predicted by the model. However, there are relatively few diamonds of this size.
Since stan_glm() fits a Bayesian model, we can plot estimates of the uncertainty in its predictions.
# Augment .data with the model's point predictions and, for each probability
# in probs, the corresponding predictive interval from rstanarm.
predictive_intervals <- function(.data, fit, probs = c(0.5, 0.9)) {
  .data %>%
    mutate(.pred = predict(fit, newdata = .)) %>%
    bind_cols(
      map_dfc(
        probs,
        ~ predictive_interval(fit, prob = ., newdata = .data) %>%
          as_tibble()
      )
    )
}
v <-
  tibble(carat = seq(min(df$carat), max(df$carat), length.out = 801)) %>%
  predictive_intervals(fit = fit_stan_glm) %>%
  mutate(across(!carat, exp))
v %>%
  ggplot(aes(carat)) +
  geom_ribbon(aes(ymin = `5%`, ymax = `95%`), alpha = 0.25) +
  geom_ribbon(aes(ymin = `25%`, ymax = `75%`), alpha = 0.5) +
  geom_point(aes(y = price), data = df, alpha = 0.01) +
  geom_line(aes(y = .pred), color = "#f8766d", size = 1) +
  coord_cartesian(ylim = c(min(df$price), max(df$price))) +
  scale_x_log10() +
  scale_y_log10() +
  labs(
    title = "Bayesian linear regression with 50% and 90% predictive intervals"
  )
The 90% predictive interval contains the bulk of the data.
With no major discrepancies between the model’s predictions and the data, we’ll now turn to a more sophisticated technique from EDA – checking the model’s residuals.
4.2.3 Checking residuals
We encountered residuals in Chapter 1. For a point in a dataset, the residual at that point is the true value of the response variable minus the value predicted using the corresponding predictor variable(s).
Our formula uses logs, so the residual for a logged response is log(price) - log(predicted price) or, equivalently, log(price / predicted price).
Applying exp() to this residual turns an additive error into a multiplicative error. This is the natural way to study error for a power law, where the range of the response variable can vary greatly. For example, in our dataset price ranges from $326 to $18,818. We will therefore examine the ratio price / predicted price.
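To make this concrete, here’s a tiny worked example with made-up numbers: an actual price of 500 against a predicted price of 400.
log(500) - log(400)       # additive residual on the log scale, about 0.223
log(500 / 400)            # the same value, since log(a) - log(b) = log(a / b)
exp(log(500) - log(400))  # error ratio of 1.25: the price is 25% above the prediction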
If we call the predict() function without specifying new data, it will return the predictions for the dataset used to fit the model.
v <-
  df %>%
  select(price, carat) %>%
  mutate(
    pred = predict(fit_stan_glm) %>% exp(),
    error_ratio = price / pred
  )
v
#> # A tibble: 53,405 × 4
#> price carat pred error_ratio
#> <int> <dbl> <dbl> <dbl>
#> 1 326 0.23 394. 0.828
#> 2 326 0.21 338. 0.965
#> 3 327 0.23 394. 0.831
#> 4 334 0.29 582. 0.574
#> 5 335 0.31 652. 0.514
#> 6 336 0.24 423. 0.794
#> # … with 53,399 more rows
Let’s look at a plot of the error ratios.
v %>%
  ggplot(aes(carat, error_ratio)) +
  geom_point(alpha = 0.01) +
  geom_hline(yintercept = 1, size = 1, color = "white") +
  geom_smooth() +
  coord_cartesian(ylim = c(0, NA)) +
  scale_x_log10() +
  labs(y = "Error ratio")
As we saw above, the actual prices of diamonds smaller than 0.3 carat were higher than the model’s predictions. In addition, we can now see that diamonds of around 1.0 and 1.5 carats were also more expensive than predicted. Nevertheless, the smooth line for error_ratio remains fairly close to 1 for diamonds of size 0.3 carat and greater. A systematic divergence of this line away from 1 would indicate a problem with the fit.
Let’s look at the distribution of error_ratio.
quantile(v$error_ratio, probs = c(0.025, 0.05, 0.5, 0.95, 0.975))
#> 2.5% 5% 50% 95% 97.5%
#> 0.614 0.662 0.994 1.538 1.681
The median error_ratio is very close to 1. Approximately 95% of the diamonds have prices ranging from 39% below to 68% above the model’s predictions.
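Those percentages come straight from the 2.5% and 97.5% quantiles above:
1 - 0.614  # lower end, about 39% below the predictions
#> [1] 0.386
1.681 - 1  # upper end, about 68% above the predictions
#> [1] 0.681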
4.3 Summary
The process of fitting classical and Bayesian models is similar. Bayesian models have the advantage of enabling additional inferences about the model, such as uncertainties in the parameters or the model’s predictions. Tidymodels provides a common interface to a wide range of model types and engines, as well as other tools for the modeling process.
Once we have fit a model, we can check the fit by:
- Checking the model parameters against known features of the data.
- Checking the predictions of the model against the actual values of the response variable in the data.
- Checking the residuals to see if their smooth line is largely horizontal.
These checks only indicate how well the model fits the data. They are not an indication of how well the model would make predictions with new data. We will turn to that question in the next chapter, as well as the issue of how to evaluate and compare models from different function families.