---
title: "Using a Bert model to get the predictability of words in their context"
bibliography: '`r system.file("REFERENCES.bib", package="pangoling")`'
vignette: >
  %\VignetteIndexEntry{Using a BERT model to get the predictability of words in their context}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---




<!-- https://ropensci.org/blog/2019/12/08/precompute-vignettes/ -->

Whereas [the vignette about GPT-2](intro-gpt2.html) presents a very popular way 
to calculate word probabilities using GPT-like (causal) models, masked models 
offer an alternative, especially when we only care about the final word 
following a certain context.

A masked language model (also called a BERT-like or encoder model) is a type of 
large language model that can be used to predict the content of a masked token 
in a sentence. BERT is an example of a masked language model 
[see also @Devlinetal2018].


First load the following packages:


``` r
library(pangoling)
library(tidytable) # fast alternative to dplyr
```


Notice the following potential pitfall. This would be a **bad** approach for
making predictions in a masked model:


``` r
masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK]")
#> Processing using masked model 'bert-base-uncased/' ...
#> # A tidytable: 30,522 × 4
#>    masked_sentence                            token     pred mask_n
#>    <chr>                                      <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK] .      -0.0579      1
#>  2 The apple doesn't fall far from the [MASK] ;      -3.21        1
#>  3 The apple doesn't fall far from the [MASK] !      -4.83        1
#>  4 The apple doesn't fall far from the [MASK] ?      -5.33        1
#>  5 The apple doesn't fall far from the [MASK] ...    -7.84        1
#>  6 The apple doesn't fall far from the [MASK] |      -8.11        1
#>  7 The apple doesn't fall far from the [MASK] tree   -8.76        1
#>  8 The apple doesn't fall far from the [MASK] -      -9.69        1
#>  9 The apple doesn't fall far from the [MASK] '      -9.87        1
#> 10 The apple doesn't fall far from the [MASK] :     -10.5         1
#> # ℹ 30,512 more rows
```
(The pretrained models and tokenizers will be downloaded from 
https://huggingface.co/ the first time they are used.)
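
If you want to trigger this download ahead of time, or use a different 
BERT-like model from the Hugging Face hub, something along the following lines 
should work. This is a minimal sketch: it assumes that `masked_preload()` and 
the `model` argument of `masked_tokens_pred_tbl()` are available; see the 
package documentation for the exact interface.


``` r
# Download/cache the default masked model ahead of time (assumed helper).
masked_preload(model = "bert-base-uncased")
# Use a non-default BERT-like model via the (assumed) `model` argument.
masked_tokens_pred_tbl(
  "The apple doesn't fall far from the [MASK].",
  model = "distilbert-base-uncased"
)
```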


The most common predictions are punctuation marks. This is because BERT uses 
the left *and* right context, and since nothing follows the mask, the model 
infers that the masked token is the final *token* of the sentence. More 
reasonable results are obtained in the following way:


``` r
masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK].")
#> Processing using masked model 'bert-base-uncased/' ...
#> # A tidytable: 30,522 × 4
#>    masked_sentence                             token     pred mask_n
#>    <chr>                                       <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK]. tree    -0.691      1
#>  2 The apple doesn't fall far from the [MASK]. ground  -1.98       1
#>  3 The apple doesn't fall far from the [MASK]. sky     -2.13       1
#>  4 The apple doesn't fall far from the [MASK]. table   -4.02       1
#>  5 The apple doesn't fall far from the [MASK]. floor   -4.31       1
#>  6 The apple doesn't fall far from the [MASK]. top     -4.48       1
#>  7 The apple doesn't fall far from the [MASK]. ceiling -4.62       1
#>  8 The apple doesn't fall far from the [MASK]. window  -4.87       1
#>  9 The apple doesn't fall far from the [MASK]. trees   -4.94       1
#> 10 The apple doesn't fall far from the [MASK]. apple   -4.95       1
#> # ℹ 30,512 more rows
```
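
Because the output is a regular tidytable, we can look up the log-probability 
of a specific candidate token directly. A minimal sketch, assuming that `pred` 
holds natural log-probabilities (the package default), so that `exp()` converts 
them back to probabilities:


``` r
# Store the predictions once, then inspect a specific candidate.
df_pred <-
  masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK].")
df_pred |>
  filter(token == "tree") |>
  # exp() assumes `pred` is a natural log-probability
  mutate(prob = exp(pred))
```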

We can also mask several tokens (but bear in mind that models of this type are 
trained with only 10-15% of tokens masked):


``` r
df_masks <- 
  masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK][MASK]")
#> Processing using masked model 'bert-base-uncased/' ...
df_masks |> filter(mask_n == 1)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                  token     pred mask_n
#>    <chr>                                            <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK] tree    -0.738      1
#>  2 The apple doesn't fall far from the [MASK][MASK] ground  -1.72       1
#>  3 The apple doesn't fall far from the [MASK][MASK] sky     -2.31       1
#>  4 The apple doesn't fall far from the [MASK][MASK] table   -3.67       1
#>  5 The apple doesn't fall far from the [MASK][MASK] floor   -4.47       1
#>  6 The apple doesn't fall far from the [MASK][MASK] top     -4.67       1
#>  7 The apple doesn't fall far from the [MASK][MASK] ceiling -4.89       1
#>  8 The apple doesn't fall far from the [MASK][MASK] window  -5.02       1
#>  9 The apple doesn't fall far from the [MASK][MASK] bush    -5.02       1
#> 10 The apple doesn't fall far from the [MASK][MASK] vine    -5.03       1
#> # ℹ 30,512 more rows
df_masks |> filter(mask_n == 2)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                  token     pred mask_n
#>    <chr>                                            <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK] .      -0.0570      2
#>  2 The apple doesn't fall far from the [MASK][MASK] ;      -2.91        2
#>  3 The apple doesn't fall far from the [MASK][MASK] !      -7.33        2
#>  4 The apple doesn't fall far from the [MASK][MASK] ?      -9.09        2
#>  5 The apple doesn't fall far from the [MASK][MASK] ...   -11.9         2
#>  6 The apple doesn't fall far from the [MASK][MASK] ,     -12.4         2
#>  7 The apple doesn't fall far from the [MASK][MASK] -     -12.8         2
#>  8 The apple doesn't fall far from the [MASK][MASK] |     -13.3         2
#>  9 The apple doesn't fall far from the [MASK][MASK] so    -13.4         2
#> 10 The apple doesn't fall far from the [MASK][MASK] :     -13.9         2
#> # ℹ 30,512 more rows
```
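
To see only the best guess for each mask position, we can keep the 
highest-scoring token per `mask_n`. A minimal sketch, assuming that the 
tidytable verbs behave like their dplyr counterparts:


``` r
# Keep the top-scoring token for each mask position.
df_masks |>
  slice_max(pred, n = 1, .by = mask_n)
```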

We can also use BERT to examine the predictability of words assuming that both 
the left and right contexts are known:


``` r
(df_sent <- data.frame(
  left = c("The", "The"),
  critical = c("apple", "pear"),
  right = c(
    "doesn't fall far from the tree.",
    "doesn't fall far from the tree."
  )
))
#>   left critical                           right
#> 1  The    apple doesn't fall far from the tree.
#> 2  The     pear doesn't fall far from the tree.
```

The function `masked_targets_pred()` gives us the log-probability of the 
target word (and takes care of summing the log-probabilities when the target 
consists of several tokens).


``` r
df_sent <- df_sent %>%
  mutate(lp = masked_targets_pred(
    prev_contexts = left,
    targets = critical,
    after_contexts = right
  ))
#> Processing using masked model 'bert-base-uncased/' ...
#> Processing 1 batch(es) of 13 tokens.
#> The [apple] doesn't fall far from the tree.
#> Processing 1 batch(es) of 13 tokens.
#> The [pear] doesn't fall far from the tree.
#> ***
df_sent
#> # A tidytable: 2 × 4
#>   left  critical right                              lp
#>   <chr> <chr>    <chr>                           <dbl>
#> 1 The   apple    doesn't fall far from the tree. -4.68
#> 2 The   pear     doesn't fall far from the tree. -8.60
```

As expected (given the popularity of the proverb), "apple" is a more likely 
target word than "pear".


# References