--- title: "Using a Bert model to get the predictability of words in their context" bibliography: '`r system.file("REFERENCES.bib", package="pangoling")`' vignette: > %\VignetteIndexEntry{Using a Bert model to get the predictability of words in their context} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- <!-- https://ropensci.org/blog/2019/12/08/precompute-vignettes/ --> Whereas [the vignette about GPT-2](intro-gpt2.html) presents a very popular way to calculate word probabilities using GPT-like models, masked models present an alternative, especially, when we just care about the final word following a certain context. A masked language model (also called BERT-like, or encoder model) is a type of large language model that can be used to predict the content of a mask in a sentence. BERT is an example of a masked language model [see also @Devlinetal2018]. First load the following packages: ``` r library(pangoling) library(tidytable) # fast alternative to dplyr ``` Notice the following potential pitfall. This would be a **bad** approach for making predictions in a masked model: ``` r masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK]") #> Processing using masked model 'bert-base-uncased/' ... #> # A tidytable: 30,522 × 4 #> masked_sentence token pred mask_n #> <chr> <chr> <dbl> <int> #> 1 The apple doesn't fall far from the [MASK] . -0.0579 1 #> 2 The apple doesn't fall far from the [MASK] ; -3.21 1 #> 3 The apple doesn't fall far from the [MASK] ! -4.83 1 #> 4 The apple doesn't fall far from the [MASK] ? -5.33 1 #> 5 The apple doesn't fall far from the [MASK] ... -7.84 1 #> 6 The apple doesn't fall far from the [MASK] | -8.11 1 #> 7 The apple doesn't fall far from the [MASK] tree -8.76 1 #> 8 The apple doesn't fall far from the [MASK] - -9.69 1 #> 9 The apple doesn't fall far from the [MASK] ' -9.87 1 #> 10 The apple doesn't fall far from the [MASK] : -10.5 1 #> # ℹ 30,512 more rows ``` (The pretrained models and tokenizers will be downloaded from https://huggingface.co/ the first time they are used.) The most common predictions are punctuation marks, because BERT uses the left *and* right context. In this case, the right context indicates that the mask is the final *token* of the sentence. More expected results are obtained in the following way: ``` r masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK].") #> Processing using masked model 'bert-base-uncased/' ... #> # A tidytable: 30,522 × 4 #> masked_sentence token pred mask_n #> <chr> <chr> <dbl> <int> #> 1 The apple doesn't fall far from the [MASK]. tree -0.691 1 #> 2 The apple doesn't fall far from the [MASK]. ground -1.98 1 #> 3 The apple doesn't fall far from the [MASK]. sky -2.13 1 #> 4 The apple doesn't fall far from the [MASK]. table -4.02 1 #> 5 The apple doesn't fall far from the [MASK]. floor -4.31 1 #> 6 The apple doesn't fall far from the [MASK]. top -4.48 1 #> 7 The apple doesn't fall far from the [MASK]. ceiling -4.62 1 #> 8 The apple doesn't fall far from the [MASK]. window -4.87 1 #> 9 The apple doesn't fall far from the [MASK]. trees -4.94 1 #> 10 The apple doesn't fall far from the [MASK]. apple -4.95 1 #> # ℹ 30,512 more rows ``` We can mask several tokens as well (but bear in mind that this type of models are trained with only 10-15% of masks): ``` r df_masks <- masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK][MASK]") #> Processing using masked model 'bert-base-uncased/' ... 
We can also mask several tokens (but bear in mind that this type of model is trained with only 10-15% of the tokens masked):

``` r
df_masks <- masked_tokens_pred_tbl("The apple doesn't fall far from the [MASK][MASK]")
#> Processing using masked model 'bert-base-uncased/' ...
df_masks |> filter(mask_n == 1)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                    token     pred mask_n
#>    <chr>                                              <chr>    <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK]   tree    -0.738      1
#>  2 The apple doesn't fall far from the [MASK][MASK]   ground  -1.72       1
#>  3 The apple doesn't fall far from the [MASK][MASK]   sky     -2.31       1
#>  4 The apple doesn't fall far from the [MASK][MASK]   table   -3.67       1
#>  5 The apple doesn't fall far from the [MASK][MASK]   floor   -4.47       1
#>  6 The apple doesn't fall far from the [MASK][MASK]   top     -4.67       1
#>  7 The apple doesn't fall far from the [MASK][MASK]   ceiling -4.89       1
#>  8 The apple doesn't fall far from the [MASK][MASK]   window  -5.02       1
#>  9 The apple doesn't fall far from the [MASK][MASK]   bush    -5.02       1
#> 10 The apple doesn't fall far from the [MASK][MASK]   vine    -5.03       1
#> # ℹ 30,512 more rows
df_masks |> filter(mask_n == 2)
#> # A tidytable: 30,522 × 4
#>    masked_sentence                                    token      pred mask_n
#>    <chr>                                              <chr>     <dbl>  <int>
#>  1 The apple doesn't fall far from the [MASK][MASK]   .       -0.0570      2
#>  2 The apple doesn't fall far from the [MASK][MASK]   ;       -2.91        2
#>  3 The apple doesn't fall far from the [MASK][MASK]   !       -7.33        2
#>  4 The apple doesn't fall far from the [MASK][MASK]   ?       -9.09        2
#>  5 The apple doesn't fall far from the [MASK][MASK]   ...    -11.9         2
#>  6 The apple doesn't fall far from the [MASK][MASK]   ,      -12.4         2
#>  7 The apple doesn't fall far from the [MASK][MASK]   -      -12.8         2
#>  8 The apple doesn't fall far from the [MASK][MASK]   |      -13.3         2
#>  9 The apple doesn't fall far from the [MASK][MASK]   so     -13.4         2
#> 10 The apple doesn't fall far from the [MASK][MASK]   :      -13.9         2
#> # ℹ 30,512 more rows
```

We can also use BERT to examine the predictability of words assuming that both the left and the right contexts are known:

``` r
(df_sent <- data.frame(
  left = c("The", "The"),
  critical = c("apple", "pear"),
  right = c(
    "doesn't fall far from the tree.",
    "doesn't fall far from the tree."
  )
))
#>   left critical                           right
#> 1  The    apple doesn't fall far from the tree.
#> 2  The     pear doesn't fall far from the tree.
```

The function `masked_targets_pred()` will give us the log-probability of the target word (and will take care of summing the log-probabilities when the target is composed of several tokens).

``` r
df_sent <- df_sent %>%
  mutate(lp = masked_targets_pred(
    prev_contexts = left,
    targets = critical,
    after_contexts = right
  ))
#> Processing using masked model 'bert-base-uncased/' ...
#> Processing 1 batch(es) of 13 tokens.
#> The [apple] doesn't fall far from the tree.
#> Processing 1 batch(es) of 13 tokens.
#> The [pear] doesn't fall far from the tree.
#> ***
df_sent
#> # A tidytable: 2 × 4
#>   left  critical right                              lp
#>   <chr> <chr>    <chr>                           <dbl>
#> 1 The   apple    doesn't fall far from the tree. -4.68
#> 2 The   pear     doesn't fall far from the tree. -8.60
```

As expected (given the popularity of the proverb), "apple" is a more likely target word than "pear".

# References