BERT Model
BERT (Bidirectional Encoder Representations from Transformers) introduced deep bidirectional pre-training for language representations and, at its release in 2018, achieved state-of-the-art results on 11 NLP tasks.
BERT Variants
| Model | Layers | Hidden size | Attention heads | Parameters | Training hardware |
|---|---|---|---|---|---|
| BERT-Base | 12 | 768 | 12 | 110M | 16 TPUs |
| BERT-Large | 24 | 1024 | 16 | 340M | 64 TPUs |
| RoBERTa-Base | 12 | 768 | 12 | 125M | 1024 GPUs |
| RoBERTa-Large | 24 | 1024 | 16 | 355M | 1024 GPUs |
| ALBERT-xxlarge | 12 | 4096 | 64 | 235M | not stated |
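The parameter counts for the BERT rows follow largely from the layer count and hidden size. As a rough sketch, the function below tallies the weights of a standard BERT encoder plus pooler. It assumes the 30,522-token uncased WordPiece vocabulary, 512 positions, 2 segment types, and a feed-forward width of four times the hidden size; the function name and its defaults are illustrative and not part of any library. It does not apply to ALBERT, which shares parameters across layers and factorizes its embeddings.

```python
def count_bert_params(num_layers, hidden, vocab_size=30522,
                      max_positions=512, type_vocab=2):
    """Rough parameter count for a BERT encoder plus pooler (no task head)."""
    ffn = 4 * hidden          # assumed feed-forward width
    layer_norm = 2 * hidden   # scale and bias vectors

    # Token, position, and segment embeddings, followed by one LayerNorm
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm

    # Q, K, V, and output projections, each with a bias
    attention = 4 * (hidden * hidden + hidden)
    # Two dense layers of the feed-forward block, each with a bias
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)
    # Each layer has two LayerNorms (after attention and after feed-forward)
    per_layer = attention + feed_forward + 2 * layer_norm

    pooler = hidden * hidden + hidden
    return embeddings + num_layers * per_layer + pooler


print(f"BERT-Base:  {count_bert_params(12, 768):,}")
print(f"BERT-Large: {count_bert_params(24, 1024):,}")
```

Under these assumptions the sketch gives roughly 109.5M parameters for BERT-Base and 335.1M for BERT-Large, close to the rounded 110M and 340M figures in the table.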
Pre-Training Objectives
Masked Language Modeling (MLM)
BERT randomly selects 15% of input tokens for prediction and trains the model to recover the original tokens from the surrounding context.
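Masked Language Modeling Objective:

$$\mathcal{L}_{\text{MLM}} = -\sum_{i \in M} \log P\left(x_i \mid x_{\setminus M}\right)$$

where $M$ is the set of masked positions and $x_{\setminus M}$ represents the unmasked context.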
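To make the objective concrete, here is a minimal sketch that sums the negative log-probability of the true token over the masked positions only. The probability table is made-up toy data standing in for model output, and `mlm_loss` is an illustrative name rather than a library function.

```python
import math


def mlm_loss(predicted_probs, target_ids, masked_positions):
    """Negative log-likelihood summed over masked positions only.

    predicted_probs[i] is a probability distribution over the vocabulary
    at position i (toy values here, supplied by the caller).
    """
    return -sum(math.log(predicted_probs[i][target_ids[i]])
                for i in masked_positions)


# Toy vocabulary of 4 tokens, sequence of 3 positions; positions 0 and 2 are masked
probs = [
    [0.50, 0.25, 0.125, 0.125],
    [0.10, 0.70, 0.10, 0.10],   # unmasked position: does not contribute
    [0.25, 0.25, 0.25, 0.25],
]
targets = [0, 1, 3]
loss = mlm_loss(probs, targets, masked_positions={0, 2})
print(round(loss, 4))  # -(ln 0.5 + ln 0.25) = 3 ln 2, about 2.0794
```

Position 1 has no effect on the result, which mirrors the convention in training code of assigning an ignore label to unmasked positions.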
Masking strategy:
- 80% of selected tokens are replaced with [MASK]
- 10% are replaced with a random token
- 10% remain unchanged
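Combining the 15% selection rate with the split above gives the share of all input tokens that falls into each outcome. The short calculation below is only a worked illustration; the dictionary keys are descriptive labels chosen here, not names from any library.

```python
from fractions import Fraction

selection_rate = Fraction(15, 100)  # share of tokens selected for prediction

per_token = {
    "replaced with [MASK]": selection_rate * Fraction(80, 100),
    "replaced with random token": selection_rate * Fraction(10, 100),
    "selected but unchanged": selection_rate * Fraction(10, 100),
    "not selected": 1 - selection_rate,
}

for outcome, p in per_token.items():
    print(f"{outcome}: {float(p):.1%}")

# The four outcomes are exhaustive
assert sum(per_token.values()) == 1
```

So about 12% of all input tokens become [MASK], 1.5% become a random token, 1.5% are predicted while left unchanged, and 85% are not part of the loss at all.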
```python
import torch

def create_mlm_batch(input_ids, vocab_size, mask_token_id=103):
    """Apply BERT-style masking and return (masked_input_ids, labels)."""
    input_ids = input_ids.clone()  # avoid mutating the caller's tensor
    labels = input_ids.clone()
    probability_matrix = torch.full(input_ids.shape, 0.15)

    # Don't mask special tokens (bert-base-uncased ids: [CLS]=101, [SEP]=102, [PAD]=0)
    special_tokens_mask = (input_ids == 101) | (input_ids == 102) | (input_ids == 0)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # Only compute loss on masked tokens

    # 80% of selected positions: replace with [MASK]
    indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    input_ids[indices_replaced] = mask_token_id

    # Half of the remaining 20% (10% overall): replace with a random token
    indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, input_ids.shape, dtype=input_ids.dtype)
    input_ids[indices_random] = random_words[indices_random]

    # The final 10% keep their original token but are still predicted
    return input_ids, labels
```
Next Sentence Prediction (NSP)
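Next Sentence Prediction: the model learns to predict whether sentence B actually follows sentence A in the original text. During pre-training, B is the true next sentence in 50% of examples and a random sentence from the corpus in the other 50%.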
```python
# NSP training data preparation
def create_nsp_example(sentence_a, sentence_b, is_next=True):
    # sentence_a and sentence_b are lists of tokens
    tokens = ['[CLS]'] + sentence_a + ['[SEP]'] + sentence_b + ['[SEP]']
    # Segment 0 covers [CLS] + A + [SEP]; segment 1 covers B + [SEP]
    segment_ids = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1)
    label = 1 if is_next else 0
    return tokens, segment_ids, label
```
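The function above takes the sentence pair and its label as given. As a sketch of how such pairs could be drawn from a corpus, the hypothetical helper below picks the true next sentence half of the time and a random sentence from a different document otherwise, assuming the 50/50 split of the original BERT setup. It also assumes at least two documents, each with at least two sentences; `sample_nsp_pair` and the toy corpus are illustrative only.

```python
import random


def sample_nsp_pair(documents, rng):
    """Sample (sentence_a, sentence_b, is_next) from a list of documents.

    documents: list of documents, each a list of tokenized sentences.
    rng: a random.Random instance, passed in so sampling is reproducible.
    """
    doc_idx = rng.randrange(len(documents))
    doc = documents[doc_idx]
    # Stop one short of the end so a true next sentence always exists
    sent_idx = rng.randrange(len(doc) - 1)
    sentence_a = doc[sent_idx]

    if rng.random() < 0.5:
        # Positive example: the sentence that actually follows A
        return sentence_a, doc[sent_idx + 1], True

    # Negative example: a random sentence from a different document
    other_idx = rng.choice([i for i in range(len(documents)) if i != doc_idx])
    return sentence_a, rng.choice(documents[other_idx]), False


corpus = [
    [["the", "cat", "sat"], ["it", "purred"], ["then", "it", "slept"]],
    [["stocks", "fell"], ["bonds", "rose"]],
]
rng = random.Random(0)
sentence_a, sentence_b, is_next = sample_nsp_pair(corpus, rng)
print(sentence_a, sentence_b, is_next)
```

The resulting triple can be passed straight to `create_nsp_example(sentence_a, sentence_b, is_next)`.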
Fine-Tuning for Downstream Tasks
Token Classification (NER, POS)
```python
from transformers import BertForTokenClassification, BertTokenizerFast
import torch

# word_ids() is only available on fast tokenizers
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=9)

def tokenize_and_align_labels(text, labels, max_length=128):
    # text is a list of words; labels holds one label id per word
    tokenized = tokenizer(text, is_split_into_words=True,
                          padding='max_length', truncation=True,
                          max_length=max_length, return_tensors='pt')
    aligned_labels = []
    word_ids = tokenized.word_ids()
    previous_word_id = None
    for word_id in word_ids:
        if word_id is None:
            aligned_labels.append(-100)  # Special or padding token
        elif word_id != previous_word_id:
            aligned_labels.append(labels[word_id])  # First subword of a word
        else:
            aligned_labels.append(-100)  # Subword continuation
        previous_word_id = word_id
    tokenized['labels'] = torch.tensor([aligned_labels])
    return tokenized
```
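The alignment rule is easier to follow in isolation. The sketch below applies the same logic to a hand-written `word_ids` list, so it runs without downloading a tokenizer. The subword split and the label ids are hypothetical and chosen only for illustration.

```python
def align_labels(word_ids, word_labels, ignore_index=-100):
    """Give each word's label to its first subword; ignore everything else."""
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None or word_id == previous:
            # Special token, padding, or a continuation subword
            aligned.append(ignore_index)
        else:
            aligned.append(word_labels[word_id])
        previous = word_id
    return aligned


# Hypothetical split of ["Angela", "Merkel", "visited", "Paris"] into
# [CLS] angela mer ##kel visited paris [SEP]
word_ids = [None, 0, 1, 1, 2, 3, None]
word_labels = [1, 2, 0, 5]  # assumed scheme: B-PER, I-PER, O, B-LOC

print(align_labels(word_ids, word_labels))
# [-100, 1, 2, -100, 0, 5, -100]
```

Only the first subword of "Merkel" carries a label, so the loss counts each word once regardless of how many pieces the tokenizer splits it into.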
Sequence Classification
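```python
import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# The classification head is randomly initialized until the model is fine-tuned
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
)

# Forward pass; passing labels makes the model return a cross-entropy loss
inputs = tokenizer("This movie is great!", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([1]))
loss = outputs.loss
logits = outputs.logits  # shape: (batch_size, num_labels)
```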
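The logits are unnormalized scores, one per class. A minimal sketch of turning them into class probabilities and a predicted label is shown below in plain Python; the logit values are made up for illustration and do not come from an actual model run.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [-1.2, 2.3]  # hypothetical scores for classes 0 and 1
probs = softmax(example_logits)
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted_class)
```

With tensors, the equivalent operations would be `torch.softmax(logits, dim=-1)` followed by `argmax(dim=-1)`.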
Question Answering
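```python
import torch
from transformers import BertForQuestionAnswering, BertTokenizer

# bert-base-uncased ships without a trained QA head, so a SQuAD-fine-tuned
# checkpoint is used here to get meaningful start/end logits
checkpoint = 'bert-large-uncased-whole-word-masking-finetuned-squad'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertForQuestionAnswering.from_pretrained(checkpoint)

context = "BERT was developed by Google. It was released in 2018."
question = "Who developed BERT?"
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
start_idx = torch.argmax(start_logits)
end_idx = torch.argmax(end_logits) + 1  # +1 because the slice end is exclusive
answer = tokenizer.decode(inputs['input_ids'][0][start_idx:end_idx])
print(f"Answer: {answer}")  # e.g. "google" (lowercased by the uncased tokenizer)
```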
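Taking the two argmaxes independently, as above, can produce an end position that comes before the start position. One common remedy is to search for the highest-scoring valid span instead. The sketch below does this in plain Python over lists of logits; `best_answer_span`, its `max_answer_len` default, and the toy logits are assumptions made for illustration.

```python
def best_answer_span(start_logits, end_logits, max_answer_len=30):
    """Return the (start, end) pair, end inclusive, with the highest
    start_logit + end_logit subject to start <= end < start + max_answer_len."""
    best_score = float("-inf")
    best_span = None
    for i, start_score in enumerate(start_logits):
        last = min(i + max_answer_len, len(end_logits))
        for j in range(i, last):
            score = start_score + end_logits[j]
            if score > best_score:
                best_score, best_span = score, (i, j)
    return best_span


# Toy logits where independent argmax would give start=1, end=0 (invalid)
start = [0.1, 5.0, 0.2]
end = [6.0, 0.1, 4.0]
print(best_answer_span(start, end))  # (1, 2)
```

Because the returned end index is inclusive, the answer tokens would be sliced as `input_ids[start:end + 1]`.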
BERT vs GPT vs ELMo
| Feature | BERT | GPT | ELMo |
|---|---|---|---|
| Architecture | Encoder only | Decoder only | BiLSTM |
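| Direction | Bidirectional | Left-to-right | Bidirectional (separate forward and backward LSTMs) |
| Pre-training | MLM + NSP | Causal language modeling | Bidirectional language modeling |
| Downstream use | Fine-tuning with a task head | Fine-tuning (GPT-1) or prompting (later versions) | Feature extraction |
| Context window | 512 tokens | Varies by version (512 to 2048 tokens in GPT-1 through GPT-3) | No fixed limit (recurrent) |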
| Best for | Understanding | Generation | Feature extraction |
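Because every token in BERT attends to both its left and right context, the model is poorly suited to autoregressive text generation, which requires predicting each token from the left context alone. The same property makes it strong on understanding tasks, where the full sentence is available up front.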
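The difference between the two directions comes down to which positions each token is allowed to attend to. The following sketch builds both attention masks as nested lists, where 1 means "may attend" and 0 means "blocked"; the function names are illustrative and not taken from any library.

```python
def bidirectional_mask(seq_len):
    """BERT-style mask: every position may attend to every other position."""
    return [[1] * seq_len for _ in range(seq_len)]


def causal_mask(seq_len):
    """GPT-style mask: position q may attend only to positions k <= q."""
    return [[1 if k <= q else 0 for k in range(seq_len)]
            for q in range(seq_len)]


print("bidirectional:")
for row in bidirectional_mask(4):
    print(row)

print("causal:")
for row in causal_mask(4):
    print(row)
```

The causal mask is lower-triangular, so a token's representation never depends on tokens to its right. That is what allows left-to-right generation, and it is the constraint BERT drops.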
Attention Patterns in BERT
```python
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from transformers import BertTokenizer, BertModel

def visualize_attention(text, layer=0, head=0):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    # outputs.attentions holds one tensor per layer,
    # each of shape (batch, heads, seq_len, seq_len)
    attention = outputs.attentions[layer][0, head].numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    plt.figure(figsize=(10, 8))
    sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens, cmap='viridis')
    plt.title(f"Layer {layer}, Head {head}")
    plt.xlabel("Key (attended-to token)")
    plt.ylabel("Query token")
    plt.show()

visualize_attention("The cat sat on the mat")
```