Metric types

Davis Vaughan

2019-03-08

Metric types

There are three main metric types in yardstick: class, class probability, and numeric. Each type has a standardized argument syntax, and every metric returns the same kind of output: a tibble with three columns (.metric, .estimator, and .estimate), plus one column per grouping variable when the input is a grouped data frame. This standardization makes it easy to bundle metrics together and to use them with grouped data frames, for example to compute a metric on multiple resamples at once. The three types of metrics are listed below, along with the types of input they take.

  1. Class metrics (hard predictions)

    • truth - factor

    • estimate - factor

  2. Class probability metrics (soft predictions)

    • truth - factor

    • estimate / ... - multiple numeric columns containing class probabilities

  3. Numeric metrics

    • truth - numeric

    • estimate - numeric
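
As a rough illustration of the three calling conventions listed above, the sketch below applies one metric of each type to a small data frame. The `toy` data and its column names are made up for this example and are not part of yardstick. For the two-class case, it assumes that the first factor level ("yes") is treated as the event, which is the default.

```r
library(yardstick)

# Hypothetical toy data, invented purely for illustration
toy <- data.frame(
  truth_class = factor(c("yes", "yes", "no", "no"), levels = c("yes", "no")),
  pred_class  = factor(c("yes", "no", "no", "no"), levels = c("yes", "no")),
  prob_yes    = c(0.9, 0.4, 0.3, 0.1),  # predicted probability of "yes"
  truth_num   = c(1, 2, 3, 4),
  pred_num    = c(1.5, 2, 2.5, 4)
)

# 1. Class metric: factor truth, factor estimate (3 of 4 correct, so 0.75)
acc <- accuracy(toy, truth = truth_class, estimate = pred_class)

# 2. Class probability metric: factor truth, probability column(s) via `...`
#    (both "yes" rows score higher than both "no" rows, so the AUC is 1)
auc <- roc_auc(toy, truth = truth_class, prob_yes)

# 3. Numeric metric: numeric truth, numeric estimate (sqrt(0.125), about 0.354)
err <- rmse(toy, truth = truth_num, estimate = pred_num)
```

Each of `acc`, `auc`, and `err` is a one-row tibble with the columns .metric, .estimator, and .estimate.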

Example

The following example uses the hpc_cv data set. It contains class probabilities and class predictions from a linear discriminant analysis model fit to the HPC data set of Kuhn and Johnson (2013). The model was fit with 10-fold cross-validation, and the predictions for all folds are included.

library(yardstick)
library(dplyr)
data("hpc_cv")

hpc_cv %>%
  group_by(Resample) %>%
  slice(1:3)
#> # A tibble: 30 x 7
#> # Groups:   Resample [10]
#>    obs   pred     VF      F       M          L Resample
#>    <fct> <fct> <dbl>  <dbl>   <dbl>      <dbl> <chr>   
#>  1 VF    VF    0.914 0.0779 0.00848 0.0000199  Fold01  
#>  2 VF    VF    0.938 0.0571 0.00482 0.0000101  Fold01  
#>  3 VF    VF    0.947 0.0495 0.00316 0.00000500 Fold01  
#>  4 VF    VF    0.941 0.0544 0.00441 0.0000123  Fold02  
#>  5 VF    VF    0.948 0.0483 0.00347 0.00000792 Fold02  
#>  6 VF    VF    0.958 0.0395 0.00236 0.00000310 Fold02  
#>  7 VF    VF    0.939 0.0556 0.00513 0.00000790 Fold03  
#>  8 VF    VF    0.928 0.0642 0.00777 0.0000148  Fold03  
#>  9 VF    VF    0.927 0.0653 0.00786 0.0000150  Fold03  
#> 10 VF    VF    0.949 0.0469 0.00398 0.00000935 Fold04  
#> # … with 20 more rows

1 metric, 1 resample

hpc_cv %>%
  filter(Resample == "Fold01") %>%
  accuracy(obs, pred)
#> # A tibble: 1 x 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.726

1 metric, 10 resamples

hpc_cv %>%
  group_by(Resample) %>%
  accuracy(obs, pred)
#> # A tibble: 10 x 4
#>    Resample .metric  .estimator .estimate
#>    <chr>    <chr>    <chr>          <dbl>
#>  1 Fold01   accuracy multiclass     0.726
#>  2 Fold02   accuracy multiclass     0.712
#>  3 Fold03   accuracy multiclass     0.758
#>  4 Fold04   accuracy multiclass     0.712
#>  5 Fold05   accuracy multiclass     0.712
#>  6 Fold06   accuracy multiclass     0.697
#>  7 Fold07   accuracy multiclass     0.675
#>  8 Fold08   accuracy multiclass     0.721
#>  9 Fold09   accuracy multiclass     0.673
#> 10 Fold10   accuracy multiclass     0.699

2 metrics, 10 resamples

class_metrics <- metric_set(accuracy, kap)

hpc_cv %>%
  group_by(Resample) %>%
  class_metrics(obs, estimate = pred)
#> # A tibble: 20 x 4
#>    Resample .metric  .estimator .estimate
#>    <chr>    <chr>    <chr>          <dbl>
#>  1 Fold01   accuracy multiclass     0.726
#>  2 Fold02   accuracy multiclass     0.712
#>  3 Fold03   accuracy multiclass     0.758
#>  4 Fold04   accuracy multiclass     0.712
#>  5 Fold05   accuracy multiclass     0.712
#>  6 Fold06   accuracy multiclass     0.697
#>  7 Fold07   accuracy multiclass     0.675
#>  8 Fold08   accuracy multiclass     0.721
#>  9 Fold09   accuracy multiclass     0.673
#> 10 Fold10   accuracy multiclass     0.699
#> 11 Fold01   kap      multiclass     0.533
#> 12 Fold02   kap      multiclass     0.512
#> 13 Fold03   kap      multiclass     0.594
#> 14 Fold04   kap      multiclass     0.511
#> 15 Fold05   kap      multiclass     0.514
#> 16 Fold06   kap      multiclass     0.486
#> 17 Fold07   kap      multiclass     0.454
#> 18 Fold08   kap      multiclass     0.531
#> 19 Fold09   kap      multiclass     0.454
#> 20 Fold10   kap      multiclass     0.492
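
Class and class probability metrics can also be combined in a single metric set. The sketch below assumes the hpc_cv columns shown earlier (`VF`, `F`, `M`, and `L` hold the class probabilities) and follows the pattern from the `metric_set()` documentation: the probability columns are passed unnamed through `...`, while the hard predictions must be passed by name as `estimate`. The name `class_and_prob_metrics` is chosen here only for illustration.

```r
library(yardstick)
library(dplyr)
data("hpc_cv")

# Bundle one class metric and one class probability metric
class_and_prob_metrics <- metric_set(accuracy, roc_auc)

mixed_results <- hpc_cv %>%
  group_by(Resample) %>%
  # obs = truth, VF:L = class probability columns, pred = hard predictions
  class_and_prob_metrics(obs, VF:L, estimate = pred)
```

The result has one row per metric and resample (20 rows here), in the same format as the output above.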

Metrics

Below is a table of all of the metrics available in yardstick, grouped by type.

type          metric
class         accuracy()
              bal_accuracy()
              detection_prevalence()
              f_meas()
              j_index()
              kap()
              mcc()
              npv()
              ppv()
              precision()
              recall()
              sens()
              spec()
class prob    gain_capture()
              mn_log_loss()
              pr_auc()
              roc_auc()
numeric       ccc()
              huber_loss()
              huber_loss_pseudo()
              mae()
              mape()
              mase()
              rmse()
              rpd()
              rpiq()
              rsq()
              rsq_trad()
              smape()