NLP Data Augmentation
Data augmentation increases training data diversity without collecting new examples, improving model generalization and robustness.
Augmentation Strategy Overview
| Method | Type | Diversity | Risk | Best For |
|---|---|---|---|---|
| Synonym replacement | Lexical | Low | Low | Classification |
| Random insertion | Lexical | Medium | Medium | Robustness |
| Back-translation | Contextual | High | Low | General |
| LLM generation | Contextual | Very high | Medium | Low-resource |
| Mixup (embedding interpolation) | Interpolation | High | Low | Embedding models |
Back-Translation
Back-translation translates text into a pivot language and then back into the source language, producing paraphrases that keep the meaning but vary the wording.
from transformers import MarianMTModel, MarianTokenizer
import random
class BackTranslator:
    """Paraphrase English text by round-tripping it through a pivot language.

    Two MarianMT checkpoints are loaded: ``Helsinki-NLP/opus-mt-en-{pivot}``
    for the forward pass and ``Helsinki-NLP/opus-mt-{pivot}-en`` for the way
    back.

    Args:
        pivot_lang: Language code used in both checkpoint names (default "de").
    """

    def __init__(self, pivot_lang="de"):
        self.pivot_lang = pivot_lang
        # Forward translation (en -> pivot)
        fwd_model_name = f"Helsinki-NLP/opus-mt-en-{pivot_lang}"
        self.fwd_tokenizer = MarianTokenizer.from_pretrained(fwd_model_name)
        self.fwd_model = MarianMTModel.from_pretrained(fwd_model_name)
        # Back translation (pivot -> en)
        bwd_model_name = f"Helsinki-NLP/opus-mt-{pivot_lang}-en"
        self.bwd_tokenizer = MarianTokenizer.from_pretrained(bwd_model_name)
        self.bwd_model = MarianMTModel.from_pretrained(bwd_model_name)

    def translate(self, text, tokenizer, model, **generate_kwargs):
        """Translate ``text`` with ``model`` and return the first decoded output.

        Extra keyword arguments are forwarded to ``model.generate`` (for example
        ``do_sample=True``), so callers can trade fidelity for diversity.
        """
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        translated = model.generate(**inputs, **generate_kwargs)
        return tokenizer.decode(translated[0], skip_special_tokens=True)

    def augment(self, text, **generate_kwargs):
        """Translate to the pivot language and back.

        Keyword arguments are forwarded to both ``generate`` calls. With none
        given, the models' default decoding settings are used.
        """
        translated = self.translate(
            text, self.fwd_tokenizer, self.fwd_model, **generate_kwargs
        )
        return self.translate(
            translated, self.bwd_tokenizer, self.bwd_model, **generate_kwargs
        )

    def augment_dataset(self, texts, num_augmentations=2, do_sample=True):
        """Augment a dataset using back-translation.

        Args:
            texts: Iterable of source strings.
            num_augmentations: Number of round trips attempted per text.
            do_sample: If True (default), each round trip samples
                (``do_sample=True, num_beams=1``), so repeated attempts can
                yield different paraphrases. If False, the default decoding
                is used; it does not sample, so repeated attempts typically
                return the same string.

        Returns:
            A flat list where each original text is followed by its distinct
            paraphrases. Candidates that equal the original or an earlier
            candidate, ignoring case, are dropped. Fewer than
            ``num_augmentations`` paraphrases may therefore be added.
        """
        generate_kwargs = {"do_sample": True, "num_beams": 1} if do_sample else {}
        augmented_texts = []
        for text in texts:
            augmented_texts.append(text)  # Keep original
            # Per-text dedupe set: without it, non-sampling decoding would
            # append the same paraphrase num_augmentations times.
            seen = {text.lower()}
            for _ in range(num_augmentations):
                aug_text = self.augment(text, **generate_kwargs)
                key = aug_text.lower()
                if key not in seen:
                    seen.add(key)
                    augmented_texts.append(aug_text)
        return augmented_texts
# Usage: round-trip one sentence through German and show both versions.
bt = BackTranslator(pivot_lang="de")
original = "The quick brown fox jumps over the lazy dog."
augmented = bt.augment(original)
for label, value in (("Original", original), ("Augmented", augmented)):
    print(f"{label}: {value}")
Lexical Augmentation
Word-level transformations that aim to preserve meaning while increasing diversity. Random swaps and deletions can still alter meaning, so spot-check their outputs.
import random
import nltk
from nltk.corpus import wordnet
# Fetch the WordNet corpus that LexicalAugmenter.get_synonyms queries through wordnet.synsets.
nltk.download("wordnet", quiet=True)
class LexicalAugmenter:
    """Word-level augmentations over whitespace-separated tokens.

    Four operations are available: synonym replacement, random insertion,
    random swap and random deletion.

    Args:
        synonym_prob: Fraction of words to replace with a WordNet synonym.
        insert_prob: Fraction of the word count used as the number of
            synonym insertions.
        swap_prob: Fraction of the word count used as the number of
            pairwise swaps.
        delete_prob: Independent per-word deletion probability.

    On non-empty text, replacement, insertion and swap always attempt at
    least one operation.
    """

    def __init__(self, synonym_prob=0.1, insert_prob=0.1, swap_prob=0.1, delete_prob=0.1):
        self.synonym_prob = synonym_prob
        self.insert_prob = insert_prob
        self.swap_prob = swap_prob
        self.delete_prob = delete_prob

    def get_synonyms(self, word):
        """Return WordNet synonyms of ``word``, excluding the word itself.

        Underscores in multi-word lemmas become spaces. The result is sorted
        so that it does not depend on string-hash ordering. A fixed
        ``random.seed`` then gives reproducible augmentations.
        """
        synonyms = {
            lemma.name().replace("_", " ")
            for syn in wordnet.synsets(word)
            for lemma in syn.lemmas()
            if lemma.name().lower() != word.lower()
        }
        return sorted(synonyms)

    def synonym_replacement(self, text):
        """Replace randomly chosen words with a random synonym.

        Words without synonyms are kept. Text with no words is returned
        unchanged.
        """
        words = text.split()
        if not words:
            # random.sample would raise on an empty population.
            return text
        new_words = words.copy()
        # Clamp so that synonym_prob > 1 cannot request more positions than exist.
        num_replace = min(len(words), max(1, int(len(words) * self.synonym_prob)))
        for idx in random.sample(range(len(words)), num_replace):
            synonyms = self.get_synonyms(words[idx])
            if synonyms:
                new_words[idx] = random.choice(synonyms)
        return " ".join(new_words)

    def random_insertion(self, text):
        """Insert synonyms of random words at random positions.

        Text with no words is returned unchanged.
        """
        words = text.split()
        if not words:
            # random.choice would raise on an empty list.
            return text
        num_insert = max(1, int(len(words) * self.insert_prob))
        for _ in range(num_insert):
            # Drawn from the growing list, so inserted synonyms may seed later insertions.
            random_word = random.choice(words)
            synonyms = self.get_synonyms(random_word)
            if synonyms:
                insert_word = random.choice(synonyms)
                insert_pos = random.randint(0, len(words))
                words.insert(insert_pos, insert_word)
        return " ".join(words)

    def random_swap(self, text):
        """Swap pairs of randomly chosen words. Fewer than two words are left as-is."""
        words = text.split()
        num_swaps = max(1, int(len(words) * self.swap_prob))
        for _ in range(num_swaps):
            if len(words) >= 2:
                idx1, idx2 = random.sample(range(len(words)), 2)
                words[idx1], words[idx2] = words[idx2], words[idx1]
        return " ".join(words)

    def random_deletion(self, text):
        """Drop each word with probability ``delete_prob``.

        At least one word is always kept.
        """
        words = text.split()
        if len(words) <= 1:
            return text
        new_words = [w for w in words if random.random() > self.delete_prob]
        # If everything was deleted, fall back to a single random original word.
        return " ".join(new_words) if new_words else random.choice(words)

    def augment(self, text, num_augmentations=4):
        """Return ``[text]`` plus up to ``num_augmentations`` randomly augmented variants.

        Each variant applies one randomly chosen operation to the original
        text. Variants identical to the input are skipped.
        """
        methods = [
            self.synonym_replacement,
            self.random_insertion,
            self.random_swap,
            self.random_deletion,
        ]
        augmented = [text]
        for _ in range(num_augmentations):
            method = random.choice(methods)
            aug_text = method(text)
            if aug_text != text:
                augmented.append(aug_text)
        return augmented
# Usage: print the original review followed by up to three lexical variants.
augmenter = LexicalAugmenter()
text = "The movie was absolutely fantastic and entertaining."
augmented_texts = augmenter.augment(text, num_augmentations=3)
for variant in augmented_texts:
    print(" - " + variant)
LLM-Based Augmentation
Using large language models to generate diverse, high-quality augmented examples.
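Definition: LLM Augmentation Quality
The quality of LLM-augmented data depends on two factors: how faithfully each generated sample keeps its label, and how much new information it adds relative to the existing data:
$$Q(\tilde{x}, y) = P(y \mid \tilde{x}) \cdot \left(1 - \max_{x \in \mathcal{D}} \operatorname{sim}(\tilde{x}, x)\right)$$
where $\tilde{x}$ is the augmented sample, $y$ is its label, and $\mathcal{D}$ is the existing dataset.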
import openai
class LLMAugmenter:
    """Generate paraphrases and label-preserving variations with an OpenAI chat model.

    Args:
        model: Chat model name sent with each request.
        temperature: Sampling temperature for paraphrases and variations.
            ``augment_few_shot`` always uses 0.9.
    """

    def __init__(self, model="gpt-4", temperature=0.8):
        self.model = model
        self.temperature = temperature

    def _complete(self, prompt, temperature):
        """Send a single-turn chat request and return the reply text, stripped.

        Works with both the ``openai>=1.0`` module-level client
        (``openai.chat.completions.create``) and the legacy ``openai<1.0``
        ``openai.ChatCompletion.create`` API, which was removed in 1.0.
        """
        messages = [{"role": "user", "content": prompt}]
        if hasattr(openai, "chat"):
            create = openai.chat.completions.create
        else:
            create = openai.ChatCompletion.create
        response = create(model=self.model, messages=messages, temperature=temperature)
        content = response.choices[0].message.content
        # content can be None (e.g. a refusal); treat that as an empty reply.
        return (content or "").strip()

    @staticmethod
    def _parse_list(content):
        """Split a reply into non-empty lines, removing one leading list marker.

        Markers such as "1.", "2)", "3:", "-", "*" or a bullet are removed
        only when followed by whitespace. Content that merely starts with
        digits ("2024 was…", "3.5 stars") is kept intact.
        """
        import re

        marker = re.compile(r"^\s*(?:\d+[.):]|[-*\u2022])\s+")
        items = []
        for line in content.splitlines():
            item = marker.sub("", line, count=1).strip()
            if item:
                items.append(item)
        return items

    def generate_paraphrases(self, text, n=5):
        """Ask the model for ``n`` paraphrases of ``text`` and return them as a list of strings."""
        prompt = (
            f"Generate {n} diverse paraphrases of the following text while preserving its meaning:\n"
            f"Original: {text}\n"
            "Paraphrases:"
        )
        return self._parse_list(self._complete(prompt, self.temperature))

    def generate_variations(self, text, label, n=5):
        """Ask for ``n`` variations of ``text`` and return ``{"text", "label"}`` dicts.

        Every returned dict carries the given ``label``.
        """
        prompt = (
            f"Given a text and its label, generate {n} variations with the same label:\n"
            f"Text: {text}\n"
            f"Label: {label}\n"
            "Generate variations that are:\n"
            "1. Semantically similar but use different wording\n"
            "2. Slightly different in length\n"
            "3. Diverse in expression style\n"
            "Variations:"
        )
        variations = self._parse_list(self._complete(prompt, self.temperature))
        return [{"text": v, "label": label} for v in variations]

    def augment_few_shot(self, examples_per_class, target_class):
        """Ask for ``examples_per_class`` examples of ``target_class``; return the raw reply text."""
        prompt = (
            f'Generate {examples_per_class} training examples for class "{target_class}":\n'
            "Examples:"
        )
        return self._complete(prompt, 0.9)
# Usage: request three paraphrases of a product review and print them numbered from 1.
augmenter = LLMAugmenter()
original = "This product exceeded my expectations in every way."
paraphrases = augmenter.generate_paraphrases(original, n=3)
print("Original:", original)
for rank, paraphrase in enumerate(paraphrases, start=1):
    print(f"Paraphrase {rank}: {paraphrase}")
Text Mixing Strategies
Interpolation between training examples in embedding space.
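Definition: Mixup for Text
Given two labeled text examples $(x_i, y_i)$ and $(x_j, y_j)$ with sentence embeddings $e(x_i)$ and $e(x_j)$, Mixup creates:
$$\tilde{e} = \lambda\, e(x_i) + (1 - \lambda)\, e(x_j), \qquad \tilde{y} = \lambda\, y_i + (1 - \lambda)\, y_j, \qquad \lambda \sim \mathrm{Beta}(\alpha, \alpha)$$
where $\lambda \in [0, 1]$ controls the interpolation ratio.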
import torch
import torch.nn.functional as F
class TextMixup:
    """Embedding-space Mixup plus two token-level perturbations (cutout, replacement).

    Assumes a Hugging Face-style interface (TODO confirm for your encoder):
    - ``tokenizer(text, return_tensors="pt", ...)`` returns a mapping that
      includes ``attention_mask``;
    - ``model(**inputs)`` returns an object with ``last_hidden_state``.

    Args:
        tokenizer: Tokenizer used by ``get_embeddings``.
        model: Encoder used by ``get_embeddings``.
        alpha: Concentration of the symmetric Beta(alpha, alpha) prior from
            which the mixing coefficient is drawn.
    """

    def __init__(self, tokenizer, model, alpha=0.2):
        self.tokenizer = tokenizer
        self.model = model
        self.alpha = alpha

    def get_embeddings(self, text):
        """Return the attention-masked mean of the encoder's last hidden state.

        The hidden state is presumably (batch, seq, hidden), giving a
        (batch, hidden) result. Padding positions are excluded from the mean.
        Runs under ``torch.no_grad()``, so no gradients flow back into the
        encoder.
        """
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        hidden = outputs.last_hidden_state
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        # clamp avoids division by zero for an all-padding row.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def mixup_batch(self, batch_texts, batch_labels):
        """Apply Mixup to a batch of texts.

        Args:
            batch_texts: Sequence of strings, each embedded separately.
            batch_labels: One-hot or soft label rows, as a tensor or nested
                list. Mixing integer class ids would not be meaningful.

        Returns:
            ``(mixed_embeddings, mixed_labels)``. Each row is
            ``lam * row + (1 - lam) * permuted_row`` with a single
            ``lam ~ Beta(alpha, alpha)`` shared by the whole batch.
        """
        # squeeze(0) drops only the batch axis of a single-text embedding.
        embeddings = torch.stack([self.get_embeddings(t).squeeze(0) for t in batch_texts])
        labels = torch.as_tensor(batch_labels, dtype=embeddings.dtype)
        lam = torch.distributions.Beta(self.alpha, self.alpha).sample()
        indices = torch.randperm(len(batch_texts))
        mixed_embeddings = lam * embeddings + (1 - lam) * embeddings[indices]
        mixed_labels = lam * labels + (1 - lam) * labels[indices]
        return mixed_embeddings, mixed_labels

    def cutout_tokens(self, text, num_tokens=2):
        """Randomly remove ``num_tokens`` whitespace tokens, keeping the rest in order.

        Text with ``num_tokens`` or fewer tokens is returned unchanged.
        """
        tokens = text.split()
        if len(tokens) <= num_tokens:
            return text
        # Pop from the highest index down so earlier indices stay valid.
        cut_indices = sorted(random.sample(range(len(tokens)), num_tokens), reverse=True)
        for idx in cut_indices:
            tokens.pop(idx)
        return " ".join(tokens)

    def token_replacement(self, text, replacement_prob=0.15, vocab=None, mask_token="[MASK]"):
        """Replace tokens with ``mask_token`` or a random vocabulary token.

        Args:
            text: Input string, split on whitespace.
            replacement_prob: Per-token probability of being replaced. A
                replaced token becomes ``mask_token`` or a draw from
                ``vocab`` with equal odds.
            vocab: Candidate replacement tokens. Pass a list rather than a set
                for reproducible draws. Defaults to the tokens of ``text``
                itself. If it is empty, ``mask_token`` is always used.
            mask_token: Token used for the masking branch.
        """
        tokens = text.split()
        pool = list(vocab) if vocab is not None else list(tokens)
        new_tokens = []
        for token in tokens:
            if random.random() < replacement_prob:
                if random.random() < 0.5 or not pool:
                    new_tokens.append(mask_token)
                else:
                    new_tokens.append(random.choice(pool))
            else:
                new_tokens.append(token)
        return " ".join(new_tokens)
Augmentation Impact
| Task | Baseline | + Back-Translation | + LLM Aug | + Mixup |
|---|---|---|---|---|
| Sentiment (SST-2) | 93.1 | 93.8 | 94.2 | 94.0 |
| Topic (AG News) | 94.5 | 95.1 | 95.4 | 95.2 |
| NLI (SNLI) | 88.7 | 89.4 | 89.8 | 89.1 |
| Low-resource (100 samples) | 72.3 | 78.6 | 81.2 | 79.4 |
Best Practices
- Validate augmented quality - Ensure augmented samples are correct and diverse
- Balance augmentation ratio - Typically 2-5x augmentation works well
- Task-appropriate methods - Back-translation for general, LLM for low-resource
- Monitor for noise - Aggressive augmentation can hurt performance
- Combine strategies - Multiple augmentation methods often complement each other
Key Takeaways
- Back-translation is the most reliable augmentation method across tasks
- LLM-based augmentation provides highest diversity but requires quality control
- Lexical methods are simple and fast but limited in diversity
- Text mixing improves calibration and robustness
- Augmentation is most impactful in low-resource settings