
Trustworthy Models in Practice: a Simple Linear Approach

Last time, we began to talk about how to build models worthy of our users' trust. As a refresher, we said that trustworthy models require at least three things:

  1. Prediction -- An estimate for some unknown value
  2. Confidence -- A description of how uncertain the model is about the prediction
  3. Explanation -- The reasoning behind the model's prediction

Today, we'll take a pass at actually implementing such a model.

Dataset

For pedagogical reasons, we'll use a dataset of fish sold at a fish market. Here are a few rows from it:

| Species | Weight | Length1 | Length2 | Length3 | Height  | Width  |
|---------|--------|---------|---------|---------|---------|--------|
| Perch   | 250.0  | 25.9    | 28.0    | 29.4    | 7.8204  | 4.2042 |
| Bream   | 714.0  | 32.7    | 36.0    | 41.5    | 16.517  | 5.8515 |
| Perch   | 145.0  | 22.0    | 24.0    | 25.5    | 6.375   | 3.825  |
| Perch   | 145.0  | 20.7    | 22.7    | 24.2    | 5.9532  | 3.63   |
| Bream   | 975.0  | 37.4    | 41.0    | 45.9    | 18.6354 | 6.7473 |

The first step, of course, is to load it up!

import os
import pandas as pd

fish = pd.read_csv(os.path.expanduser("~/Downloads/Fish.csv"))
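Before modeling, notice that Length1, Length2, and Length3 move almost in lockstep in the sample rows above. A quick, standard-library-only check of their Pearson correlations on just those five rows makes the point. The `pearson` helper is my own, and five rows are only suggestive; on the full data, `fish[["Length1", "Length2", "Length3"]].corr()` would be the pandas way to check.

```python
from math import sqrt
from statistics import fmean


def pearson(xs: list[float], ys: list[float]) -> float:
    """Pearson correlation coefficient of two equal-length samples."""
    mx, my = fmean(xs), fmean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sqrt(sum((x - mx) ** 2 for x in xs))
    sy = sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


# Length columns from the five sample rows in the table above.
length1 = [25.9, 32.7, 22.0, 20.7, 37.4]
length2 = [28.0, 36.0, 24.0, 22.7, 41.0]
length3 = [29.4, 41.5, 25.5, 24.2, 45.9]

pairs = {
    "Length1 vs Length2": (length1, length2),
    "Length1 vs Length3": (length1, length3),
    "Length2 vs Length3": (length2, length3),
}
for label, (a, b) in pairs.items():
    print(f"{label}: {pearson(a, b):.3f}")
```

All three pairs come out above 0.99. Nearly redundant columns like these make individual linear-model coefficients hard to interpret on their own, which is worth keeping in mind when reading the explanation later.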

Building a model

For our exercise today, let's see if we can predict Weight given the values of the other columns. We're going to use statsmodels to build a simple linear model.

import statsmodels.formula.api as smf

model = smf.ols(
    # C(...) marks Species as categorical; the other columns enter as numbers.
    formula="Weight ~ C(Species) + Length1 + Length2 + Length3 + Height + Width",
    data=fish,
).fit()

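If you've never used statsmodels before, think of this as fitting a linear model in which Species is dummy-coded: every species except a baseline level (Bream, the first alphabetically) gets its own 0/1 indicator column. statsmodels also has a handy way of getting basic information about the model: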

model.summary()
                            OLS Regression Results
==============================================================================
Dep. Variable:                 Weight   R-squared:                       0.936
Model:                            OLS   Adj. R-squared:                  0.931
Method:                 Least Squares   F-statistic:                     195.7
Date:                Sun, 14 Jun 2020   Prob (F-statistic):           6.85e-82
Time:                        15:00:23   Log-Likelihood:                -941.46
No. Observations:                 159   AIC:                             1907.
Df Residuals:                     147   BIC:                             1944.
Df Model:                          11
Covariance Type:            nonrobust
===========================================================================================
                              coef    std err          t      P>|t|      [0.025      0.975]
-------------------------------------------------------------------------------------------
Intercept                -918.3321    127.083     -7.226      0.000   -1169.478    -667.186
C(Species)[T.Parkki]      164.7227     75.699      2.176      0.031      15.123     314.322
C(Species)[T.Perch]       137.9489    120.314      1.147      0.253     -99.819     375.717
C(Species)[T.Pike]       -208.4294    135.306     -1.540      0.126    -475.826      58.968
C(Species)[T.Roach]       103.0400     91.308      1.128      0.261     -77.407     283.487
C(Species)[T.Smelt]       446.0733    119.430      3.735      0.000     210.051     682.095
C(Species)[T.Whitefish]    93.8742     96.658      0.971      0.333     -97.145     284.893
Length1                   -80.3030     36.279     -2.214      0.028    -151.998      -8.608
Length2                    79.8886     45.718      1.747      0.083     -10.461     170.238
Length3                    32.5354     29.300      1.110      0.269     -25.369      90.439
Height                      5.2510     13.056      0.402      0.688     -20.551      31.053
Width                      -0.5154     23.913     -0.022      0.983     -47.773      46.742
==============================================================================
Omnibus:                       43.558   Durbin-Watson:                   0.973
Prob(Omnibus):                  0.000   Jarque-Bera (JB):               97.422
Skew:                           1.184   Prob(JB):                     7.00e-22
Kurtosis:                       6.016   Cond. No.                     2.03e+03
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
[2] The condition number is large, 2.03e+03. This might indicate that there are
strong multicollinearity or other numerical problems.
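The summary lists a coefficient for every species except Bream. That's because the formula dummy-codes Species against a baseline level, and by default the first level in sorted order (here Bream) becomes that baseline, with its effect folded into the Intercept. Here is a minimal pure-Python sketch of that encoding. The `treatment_code` function is a hypothetical illustration that only mimics the summary's column names; it is not how patsy or statsmodels implement it.

```python
def treatment_code(values: list[str], name: str = "Species") -> tuple[str, list[dict[str, int]]]:
    """Dummy-code a categorical column against its alphabetically first level (a sketch)."""
    levels = sorted(set(values))
    baseline, others = levels[0], levels[1:]
    # One 0/1 indicator per non-baseline level; a baseline row is all zeros.
    rows = [
        {f"C({name})[T.{level}]": int(value == level) for level in others}
        for value in values
    ]
    return baseline, rows


baseline, rows = treatment_code(["Perch", "Bream", "Smelt"])
print(baseline)  # Bream
for row in rows:
    print(row)
```

A Bream row comes out as all zeros, so its prediction uses only the intercept and the numeric columns.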

At this point, we can achieve our first objective: to provide a prediction!

new_fish = pd.DataFrame(
    [
        {
            "Species": "Bream",
            "Weight": -1,  # placeholder for the unknown target; predict() ignores it
            "Length1": 31.3,
            "Length2": 34,
            "Length3": 39.5,
            "Height": 15.1285,
            "Width": 5.5695,
        }
    ]
)
model.predict(new_fish)

The model predicts that this fish weighs about 646 grams.
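Under the hood, a linear model's prediction is just the dot product of its coefficients with the new fish's features. As a sanity check, here is that arithmetic in plain Python, using coefficients copied (rounded) from the summary above.

```python
# Coefficients copied from the summary above (rounded to four decimals).
coefs = {
    "Intercept": -918.3321,
    "Length1": -80.3030,
    "Length2": 79.8886,
    "Length3": 32.5354,
    "Height": 5.2510,
    "Width": -0.5154,
}
# The new fish is a Bream, the baseline species, so every C(Species)[T.*]
# indicator is 0 and no species coefficient enters the sum.
features = {
    "Intercept": 1.0,
    "Length1": 31.3,
    "Length2": 34.0,
    "Length3": 39.5,
    "Height": 15.1285,
    "Width": 5.5695,
}
prediction = sum(coefs[name] * features[name] for name in coefs)
print(f"{prediction:.2f}")
```

This comes out within a few hundredths of a gram of 646.12, the full-precision prediction; the small gap comes from rounding the coefficients to four decimals.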

Providing uncertainty

The main reason I chose statsmodels (rather than scikit-learn) is its built-in support for prediction intervals. Take a look:

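frame = model.get_prediction(new_fish).summary_frame(alpha=0.95)  # alpha is the significance level
frame.round(2)
| mean   | mean_se | mean_ci_lower | mean_ci_upper | obs_ci_lower | obs_ci_upper |
|--------|---------|---------------|---------------|--------------|--------------|
| 646.12 | 18.32   | 644.96        | 647.27        | 640.11       | 652.12       |

mean here is the prediction, and obs_ci_lower and obs_ci_upper bound a prediction interval for this particular fish's weight (the mean_ci columns bound the average weight of fish with these measurements, so they are narrower). There's a catch, though: alpha is the significance level, so the interval covers 1 - alpha of outcomes. Passing alpha=0.95 therefore gives only a 5% interval, which is why 640 to 652 grams is so tight. For a 95% prediction interval, call summary_frame(alpha=0.05) instead; rescaling the half-width above by the ratio of the corresponding t quantiles puts that interval at roughly 457 to 835 grams.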
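To see how much alpha matters, here is a rough sketch using only Python's standard library. It treats the interval as mean ± quantile × standard error, uses a normal approximation to the t distribution (an assumption; with 147 residual degrees of freedom the two are close), backs out the observation standard error from the obs_ci interval in the table, and rebuilds the interval for alpha=0.05. `half_width` is a hypothetical helper, not a statsmodels function.

```python
from statistics import NormalDist

STD_NORMAL = NormalDist()  # normal approximation to the t distribution (assumption)


def half_width(se: float, alpha: float) -> float:
    """Half-width of a two-sided interval covering 1 - alpha, given a standard error."""
    return STD_NORMAL.inv_cdf(1 - alpha / 2) * se


mean = 646.12
# The obs_ci interval from the table, computed with alpha=0.95 (only 5% coverage).
narrow_half = (652.12 - 640.11) / 2

# The standard error doesn't depend on alpha, so back it out of the narrow interval...
obs_se = narrow_half / STD_NORMAL.inv_cdf(1 - 0.95 / 2)
# ...and rebuild the interval with alpha=0.05 (95% coverage).
wide_half = half_width(obs_se, alpha=0.05)
lower, upper = mean - wide_half, mean + wide_half

print(f"implied observation standard error: {obs_se:.1f} g")
print(f"approximate 95% prediction interval: {lower:.0f} g to {upper:.0f} g")
```

Under this approximation the 95% interval comes out at roughly 458 to 834 grams. summary_frame(alpha=0.05) should give slightly wider bounds, since OLS intervals in statsmodels use t rather than normal quantiles.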

We're two thirds of the way there!

Providing an explanation

We can use the structure of the model to provide an explanation. Because this fish is a Bream (the baseline species), no species term appears, and the prediction is equal to:

  -918          (the intercept)
-   80.3 * 31.3 (Length1)
+   79.9 * 34   (Length2)
+   32.5 * 39.5 (Length3)
+    5.3 * 15.1 (Height)
-    0.5 *  5.6 (Width)
   ------------
   646.12

One way to display how each feature contributes to the overall prediction:

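def fish_to_feats(a_fish, model):
    """Build the feature columns that the model's coefficients expect for a_fish."""
    feats = a_fish.copy()
    feats["Intercept"] = 1.0
    # Recreate the dummy columns: "C(Species)[T.Perch]" is 1 if Species == "Perch".
    for species_feat in model.params.index:
        if not species_feat.startswith("C(Species)"):
            continue
        species = species_feat.split(".")[1].replace("]", "")  # "C(Species)[T.Perch]" -> "Perch" (ugly, but it works)
        feats[species_feat] = (feats["Species"] == species).astype(int)

    del feats["Species"]
    # Keep only the model's columns, in the same order as model.params.
    return feats[model.params.index]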


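contributions = fish_to_feats(new_fish, model) * model.params
# Print features from largest to smallest absolute contribution.
# (DataFrame.iteritems was removed in pandas 2.0; .items() works in old and new versions.)
for name, amount in sorted(
    contributions.round(1).items(), key=lambda t: -t[1].abs().iloc[0]
):
    # Skip features that contribute nothing, e.g. indicators for other species.
    if -1e-3 < amount.iloc[0] < 1e-3:
        continue
    print(f"{name}: {amount.iloc[0]}")

Which provides the following output:

Length2: 2716.2
Length1: -2513.5
Length3: 1285.1
Intercept: -918.3
Height: 79.4
Width: -2.9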

This could certainly be made more user-friendly, but it does explain why the model believes this fish weighs 646 grams.
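One step toward friendlier output is turning each contribution into a sentence. Below is a minimal sketch; `explain` is a hypothetical helper, and the values are this Bream's coefficient-times-feature products from the summary, rounded to one decimal. Dropping the intercept is my own design choice: it is the same for every Bream, so it says little about this particular fish.

```python
def explain(contributions: dict[str, float], top_k: int = 3) -> list[str]:
    """Describe the top_k largest contributions in plain language, biggest first."""
    # Ignore the intercept and anything that contributes (essentially) nothing.
    candidates = [
        (name, value)
        for name, value in contributions.items()
        if name != "Intercept" and abs(value) > 1e-3
    ]
    candidates.sort(key=lambda item: abs(item[1]), reverse=True)

    sentences = []
    for name, value in candidates[:top_k]:
        direction = "higher" if value > 0 else "lower"
        sentences.append(f"{name} pushes the prediction {direction} by about {abs(value):.0f} g")
    return sentences


# Coefficient * feature value for the new Bream, rounded to one decimal.
contributions = {
    "Length2": 2716.2,
    "Length1": -2513.5,
    "Length3": 1285.1,
    "Intercept": -918.3,
    "Height": 79.4,
    "Width": -2.9,
    "C(Species)[T.Perch]": 0.0,  # indicator is 0 for a Bream, so it contributes nothing
}
for sentence in explain(contributions):
    print(sentence)
```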

Conclusion

We've built a model that can provide trustworthy predictions. For example:

  1. My best guess at the weight of this Bream is 646g.
  2. A 95% prediction interval (summary_frame with alpha=0.05) puts the weight between roughly 457g and 835g.
  3. The biggest contributors to this prediction are Length2 (pushes the prediction higher), Length1 (pushes it lower), and Length3 (pushes it higher).
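If a UI needs all three pieces at once, one option is to bundle them in a small container. `TrustworthyPrediction` below is a hypothetical class of my own, not part of statsmodels, and the numbers in the usage lines are placeholders rather than model output.

```python
from dataclasses import dataclass, field


@dataclass
class TrustworthyPrediction:
    """A prediction bundled with its uncertainty and explanation (hypothetical helper)."""

    estimate: float  # point prediction, e.g. from model.predict
    lower: float  # lower bound of the prediction interval
    upper: float  # upper bound of the prediction interval
    coverage: float  # e.g. 0.95 for summary_frame(alpha=0.05)
    drivers: list[str] = field(default_factory=list)  # plain-language explanation

    def __post_init__(self) -> None:
        # A symmetric prediction interval should always contain its own point estimate.
        if not self.lower <= self.estimate <= self.upper:
            raise ValueError("estimate must lie inside [lower, upper]")

    def report(self, unit: str = "g") -> str:
        """Render all three pieces as short lines of text."""
        lines = [
            f"Best guess: {self.estimate:.0f}{unit}",
            f"{self.coverage:.0%} prediction interval: "
            f"{self.lower:.0f}{unit} to {self.upper:.0f}{unit}",
        ]
        lines += [f"- {driver}" for driver in self.drivers]
        return "\n".join(lines)


# Placeholder numbers for illustration only, not real model output.
example = TrustworthyPrediction(
    estimate=100.0,
    lower=80.0,
    upper=120.0,
    coverage=0.95,
    drivers=["Length2 pushes the prediction higher"],
)
print(example.report())
```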

I highly recommend attacking machine learning problems by starting with an incredibly simple model. Implementing it end to end lets you focus on the truly difficult parts of machine learning (i.e., not the ML bits). For some use cases, this post provides yet another reason to love linear models: they are trustworthy by default!


Comments? Questions? Concerns? Please tweet me @SamuelDataT or email me. Thanks!