NLP Data Augmentation

Data augmentation increases training data diversity without collecting new examples, improving model generalization and robustness.

Augmentation Strategy Overview

| Method | Type | Diversity | Risk | Best For |
|---|---|---|---|---|
| Synonym replacement | Lexical | Low | Low | Classification |
| Random insertion | Lexical | Medium | Medium | Robustness |
| Back-translation | Contextual | High | Low | General |
| LLM generation | Contextual | Very high | Medium | Low-resource |
| Easy data mixing | Interpolation | High | Low | Embedding models |

Back-Translation

Back-translation translates text into a pivot language and then back into the source language, producing paraphrases of the original.

from transformers import MarianMTModel, MarianTokenizer
import random

class BackTranslator:
    def __init__(self, pivot_lang="de"):
        self.pivot_lang = pivot_lang

        # Forward translation (en -> pivot)
        fwd_model_name = f"Helsinki-NLP/opus-mt-en-{pivot_lang}"
        self.fwd_tokenizer = MarianTokenizer.from_pretrained(fwd_model_name)
        self.fwd_model = MarianMTModel.from_pretrained(fwd_model_name)

        # Back translation (pivot -> en)
        bwd_model_name = f"Helsinki-NLP/opus-mt-{pivot_lang}-en"
        self.bwd_tokenizer = MarianTokenizer.from_pretrained(bwd_model_name)
        self.bwd_model = MarianMTModel.from_pretrained(bwd_model_name)

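    def translate(self, text, tokenizer, model, sample=False):
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Default decoding is deterministic, so repeated calls return the same text.
        # Sampling is what makes repeated augmentations of one input differ.
        gen_kwargs = {"do_sample": True, "num_beams": 1, "top_p": 0.9} if sample else {}
        translated = model.generate(**inputs, **gen_kwargs)
        return tokenizer.decode(translated[0], skip_special_tokens=True)

    def augment(self, text, sample=False):
        """Translate to pivot language and back."""
        translated = self.translate(text, self.fwd_tokenizer, self.fwd_model, sample=sample)
        back_translated = self.translate(translated, self.bwd_tokenizer, self.bwd_model, sample=sample)
        return back_translated

    def augment_dataset(self, texts, num_augmentations=2):
        """Augment a dataset using back-translation."""
        augmented_texts = []
        for text in texts:
            augmented_texts.append(text)  # Keep original
            seen = {text.lower()}  # Skip copies of the original and repeated paraphrases
            for _ in range(num_augmentations):
                aug_text = self.augment(text, sample=True)
                if aug_text.lower() not in seen:
                    seen.add(aug_text.lower())
                    augmented_texts.append(aug_text)
        return augmented_texts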

# Usage
bt = BackTranslator(pivot_lang="de")
original = "The quick brown fox jumps over the lazy dog."
augmented = bt.augment(original)
print(f"Original: {original}")
print(f"Augmented: {augmented}")

Lexical Augmentation

Word-level transformations that aim to preserve meaning while increasing surface diversity.

import random
import nltk
from nltk.corpus import wordnet

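nltk.download("wordnet", quiet=True)
# Some NLTK versions also need the Open Multilingual WordNet data for WordNet lookups
nltk.download("omw-1.4", quiet=True)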

class LexicalAugmenter:
    def __init__(self, synonym_prob=0.1, insert_prob=0.1, swap_prob=0.1, delete_prob=0.1):
        self.synonym_prob = synonym_prob
        self.insert_prob = insert_prob
        self.swap_prob = swap_prob
        self.delete_prob = delete_prob

    def get_synonyms(self, word):
        synonyms = set()
        for syn in wordnet.synsets(word):
            for lemma in syn.lemmas():
                if lemma.name().lower() != word.lower():
                    synonyms.add(lemma.name().replace("_", " "))
        return list(synonyms)

    def synonym_replacement(self, text):
        words = text.split()
        if not words:
            return text  # Nothing to replace in an empty string
        new_words = words.copy()
        num_replace = max(1, int(len(words) * self.synonym_prob))
        random_indices = random.sample(range(len(words)), num_replace)

        for idx in random_indices:
            synonyms = self.get_synonyms(words[idx])
            if synonyms:
                new_words[idx] = random.choice(synonyms)

        return " ".join(new_words)

    def random_insertion(self, text):
        words = text.split()
        if not words:
            return text  # random.choice would fail on an empty list
        num_insert = max(1, int(len(words) * self.insert_prob))

        for _ in range(num_insert):
            random_word = random.choice(words)
            synonyms = self.get_synonyms(random_word)
            if synonyms:
                insert_word = random.choice(synonyms)
                insert_pos = random.randint(0, len(words))
                words.insert(insert_pos, insert_word)

        return " ".join(words)

    def random_swap(self, text):
        words = text.split()
        num_swaps = max(1, int(len(words) * self.swap_prob))

        for _ in range(num_swaps):
            if len(words) >= 2:
                idx1, idx2 = random.sample(range(len(words)), 2)
                words[idx1], words[idx2] = words[idx2], words[idx1]

        return " ".join(words)

    def random_deletion(self, text):
        words = text.split()
        if len(words) <= 1:
            return text

        new_words = [w for w in words if random.random() > self.delete_prob]
        return " ".join(new_words) if new_words else random.choice(words)

    def augment(self, text, num_augmentations=4):
        methods = [
            self.synonym_replacement,
            self.random_insertion,
            self.random_swap,
            self.random_deletion,
        ]

        augmented = [text]
        for _ in range(num_augmentations):
            method = random.choice(methods)
            aug_text = method(text)
            if aug_text != text:
                augmented.append(aug_text)

        return augmented

# Usage
augmenter = LexicalAugmenter()
text = "The movie was absolutely fantastic and entertaining."
augmented_texts = augmenter.augment(text, num_augmentations=3)
for t in augmented_texts:
    print(f"  - {t}")
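
Because the augmenter splits on whitespace, a token such as "entertaining." keeps its trailing period and is unlikely to match any WordNet entry. The sketch below shows one way to separate punctuation before the lookup and reattach it afterwards. The helper names `split_punctuation` and `replace_with_synonym` are hypothetical, and `get_synonyms` is assumed to be any callable with the same contract as `LexicalAugmenter.get_synonyms`.

```python
import random
import string


def split_punctuation(token):
    """Split a whitespace token into (prefix, core, suffix) around its punctuation."""
    left_stripped = token.lstrip(string.punctuation)
    start = len(token) - len(left_stripped)
    core = left_stripped.rstrip(string.punctuation)
    end = start + len(core)
    return token[:start], core, token[end:]


def replace_with_synonym(token, get_synonyms, rng=random):
    """Replace the core word with a synonym, keeping surrounding punctuation.

    get_synonyms is an assumed callable: word -> list of synonym strings.
    The token is returned unchanged when no synonym is found.
    """
    prefix, core, suffix = split_punctuation(token)
    synonyms = get_synonyms(core) if core else []
    if not synonyms:
        return token
    return prefix + rng.choice(synonyms) + suffix


# Usage with a stub lookup standing in for WordNet
stub = {"fantastic": ["wonderful"]}
print(replace_with_synonym("fantastic.", lambda w: stub.get(w.lower(), [])))
```

The same helper could be called from `synonym_replacement` and `random_insertion` in place of the direct `get_synonyms(words[idx])` lookup.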

LLM-Based Augmentation

Using large language models to generate diverse, high-quality augmented examples.

Definition: LLM Augmentation Quality

The quality of LLM-augmented data depends on:

$$Q(a) = \alpha \cdot \text{Fluency}(a) + \beta \cdot \text{Consistency}(a, l) + \gamma \cdot \text{Diversity}(a, D)$$

where $a$ is the augmented sample, $l$ is the label, and $D$ is the existing dataset.
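
As a rough illustration of how such a score might be computed, the sketch below combines the three terms as a weighted sum. Everything concrete here is an assumption: the equal default weights, the token-overlap proxy for diversity, and the `fluency_fn` and `consistency_fn` callables, which the caller would supply (for example from a language model score and a classifier's confidence in the label).

```python
def jaccard(a, b):
    """Token-set overlap between two strings, in [0, 1]."""
    set_a, set_b = set(a.lower().split()), set(b.lower().split())
    if not set_a and not set_b:
        return 1.0
    return len(set_a & set_b) / len(set_a | set_b)


def diversity(sample, dataset):
    """Assumed proxy: 1 minus the highest overlap with any existing example."""
    if not dataset:
        return 1.0
    return 1.0 - max(jaccard(sample, existing) for existing in dataset)


def quality(sample, label, dataset, fluency_fn, consistency_fn,
            alpha=1 / 3, beta=1 / 3, gamma=1 / 3):
    """Weighted sum of fluency, label consistency, and diversity.

    fluency_fn(sample) and consistency_fn(sample, label) are assumed to
    return scores in [0, 1]; the weights are illustrative defaults.
    """
    return (alpha * fluency_fn(sample)
            + beta * consistency_fn(sample, label)
            + gamma * diversity(sample, dataset))


# Usage with constant stand-ins for the fluency and consistency scorers
score = quality(
    "the film was great", "positive", ["the movie was great"],
    fluency_fn=lambda s: 1.0,
    consistency_fn=lambda s, l: 1.0,
    alpha=0.5, beta=0.25, gamma=0.25,
)
print(round(score, 2))
```

Samples scoring below a chosen threshold could then be dropped before training.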

import re

from openai import OpenAI

class LLMAugmenter:
    def __init__(self, model="gpt-4", temperature=0.8):
        self.model = model
        self.temperature = temperature
        # Client for openai>=1.0; reads the API key from OPENAI_API_KEY
        self.client = OpenAI()

    def _chat(self, prompt, temperature=None):
        """Send a single-turn prompt and return the reply text."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.temperature if temperature is None else temperature,
        )
        return response.choices[0].message.content.strip()

    @staticmethod
    def _parse_lines(reply):
        """Split a reply into lines, stripping list markers such as '1.' or '-'."""
        # Only a leading marker is removed, so text that starts with a number survives
        lines = (re.sub(r"^\s*(?:\d+[.)]|[-*])\s+", "", line).strip() for line in reply.split("\n"))
        return [line for line in lines if line]

    def generate_paraphrases(self, text, n=5):
        prompt = f"""Generate {n} diverse paraphrases of the following text while preserving its meaning:

Original: {text}

Paraphrases:"""

        return self._parse_lines(self._chat(prompt))

    def generate_variations(self, text, label, n=5):
        prompt = f"""Given a text and its label, generate {n} variations with the same label:

Text: {text}
Label: {label}

Generate variations that are:
1. Semantically similar but use different wording
2. Slightly different in length
3. Diverse in expression style

Variations:"""

        # Every variation inherits the label of the source text
        return [{"text": v, "label": label} for v in self._parse_lines(self._chat(prompt))]

    def augment_few_shot(self, examples_per_class, target_class):
        prompt = f"""Generate {examples_per_class} training examples for class "{target_class}":

Examples:"""

        # Higher temperature than the default to encourage varied examples
        return self._chat(prompt, temperature=0.9)

# Usage
augmenter = LLMAugmenter()
original = "This product exceeded my expectations in every way."
paraphrases = augmenter.generate_paraphrases(original, n=3)
print("Original:", original)
for i, p in enumerate(paraphrases, 1):
    print(f"Paraphrase {i}: {p}")

Text Mixing Strategies

Interpolation between training examples in embedding space.

Definition: Mixup for Text

Given two text examples $(x_i, y_i)$ and $(x_j, y_j)$, Mixup creates:

$$\tilde{x} = \lambda \cdot \text{Embed}(x_i) + (1-\lambda) \cdot \text{Embed}(x_j)$$
$$\tilde{y} = \lambda \cdot y_i + (1-\lambda) \cdot y_j$$

where $\lambda \sim \text{Beta}(\alpha, \alpha)$ controls the interpolation ratio.
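
A small worked example may make the two formulas concrete. In this sketch, plain Python lists stand in for $\text{Embed}(x)$ and for one-hot labels, and the helper names `mixup` and `sample_lambda` are hypothetical. With $\lambda = 0.75$, each mixed value is three quarters of the first example plus one quarter of the second.

```python
import random


def mixup(x_i, y_i, x_j, y_j, lam):
    """Interpolate two embedding vectors and their label vectors with weight lam."""
    x_mix = [lam * a + (1 - lam) * b for a, b in zip(x_i, x_j)]
    y_mix = [lam * a + (1 - lam) * b for a, b in zip(y_i, y_j)]
    return x_mix, y_mix


def sample_lambda(alpha=0.2, rng=random):
    """Draw lambda from Beta(alpha, alpha); small alpha favors values near 0 or 1."""
    return rng.betavariate(alpha, alpha)


# Toy 2-d "embeddings" with one-hot labels for two classes
x_mix, y_mix = mixup([1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 1.0], lam=0.75)
print(x_mix)  # [0.75, 0.5]
print(y_mix)  # [0.75, 0.25]
```

The mixed label is a soft label, so training on it needs a loss that accepts class probabilities as targets.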

import random

import torch

class TextMixup:
    def __init__(self, tokenizer, model, alpha=0.2):
        self.tokenizer = tokenizer
        self.model = model
        self.alpha = alpha

    def get_embeddings(self, text):
        """Return the mean-pooled input embedding, shape (1, hidden_dim)."""
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            # The embedding layer takes token IDs only, not the full tokenizer output
            token_embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
        return token_embeddings.mean(dim=1)

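    def mixup_batch(self, batch_texts, batch_labels):
        """Apply Mixup to a batch of texts.

        batch_labels should be one-hot or soft labels of shape (batch, num_classes);
        interpolating integer class IDs would not give meaningful targets.
        """
        embeddings = torch.stack([self.get_embeddings(t).squeeze(0) for t in batch_texts])
        batch_labels = torch.as_tensor(batch_labels, dtype=torch.float)

        # One lambda is shared by the whole batch; each example is mixed with a shuffled partner
        lam = torch.distributions.Beta(self.alpha, self.alpha).sample()
        indices = torch.randperm(len(batch_texts))

        mixed_embeddings = lam * embeddings + (1 - lam) * embeddings[indices]
        mixed_labels = lam * batch_labels + (1 - lam) * batch_labels[indices]

        return mixed_embeddings, mixed_labels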

    def cutout_tokens(self, text, num_tokens=2):
        """Randomly remove tokens from the text."""
        tokens = text.split()
        if len(tokens) <= num_tokens:
            return text

        cut_indices = sorted(random.sample(range(len(tokens)), num_tokens), reverse=True)
        for idx in cut_indices:
            tokens.pop(idx)

        return " ".join(tokens)

    def token_replacement(self, text, vocab, replacement_prob=0.15):
        """Replace tokens with [MASK] or a random token drawn from vocab.

        vocab is a list of candidate tokens, e.g. built once from the training set:
        sorted({t for example in training_data for t in example.split()})
        """
        tokens = text.split()

        new_tokens = []
        for token in tokens:
            if random.random() < replacement_prob:
                if random.random() < 0.5:
                    new_tokens.append("[MASK]")
                else:
                    new_tokens.append(random.choice(vocab))
            else:
                new_tokens.append(token)

        return " ".join(new_tokens)

Augmentation Impact

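| Task | Baseline | + Back-Translation | + LLM Aug | + Mixup |
|---|---|---|---|---|
| Sentiment (SST-2) | 93.1 | 93.8 | 94.2 | 94.0 |
| Topic (AG News) | 94.5 | 95.1 | 95.4 | 95.2 |
| NLI (SNLI) | 88.7 | 89.4 | 89.8 | 89.1 |
| Low-resource (100 samples) | 72.3 | 78.6 | 81.2 | 79.4 |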
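
The table is easier to compare as gains over the baseline. The short sketch below recomputes those differences from the numbers above; the dictionary keys and the `gains` helper are hypothetical names chosen for this illustration.

```python
# Scores copied from the table above
results = {
    "Sentiment (SST-2)": {"baseline": 93.1, "back_translation": 93.8, "llm": 94.2, "mixup": 94.0},
    "Topic (AG News)": {"baseline": 94.5, "back_translation": 95.1, "llm": 95.4, "mixup": 95.2},
    "NLI (SNLI)": {"baseline": 88.7, "back_translation": 89.4, "llm": 89.8, "mixup": 89.1},
    "Low-resource (100 samples)": {"baseline": 72.3, "back_translation": 78.6, "llm": 81.2, "mixup": 79.4},
}


def gains(row):
    """Difference between each augmented score and the baseline, rounded to 0.1."""
    return {method: round(score - row["baseline"], 1)
            for method, score in row.items() if method != "baseline"}


for task, row in results.items():
    task_gains = gains(row)
    best = max(task_gains, key=task_gains.get)
    print(f"{task}: {task_gains} (largest gain: {best})")
```

On the three full-size benchmarks the gains stay at or below 1.1 points, while in the 100-sample setting they range from 6.3 to 8.9 points.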

Best Practices

  1. Validate augmented quality - Ensure augmented samples are correct and diverse
  2. Balance augmentation ratio - Typically 2-5x augmentation works well
  3. Task-appropriate methods - Back-translation for general, LLM for low-resource
  4. Monitor for noise - Aggressive augmentation can hurt performance
  5. Combine strategies - Multiple augmentation methods often complement each other
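
The first, second, and fourth practices above can be approximated with a simple filter applied to each batch of candidates. The sketch below is one possible version under stated assumptions: token overlap is only a crude proxy for meaning preservation, the thresholds `min_overlap` and `max_overlap` are illustrative values, and `filter_augmented` is a hypothetical helper.

```python
def token_overlap(a, b):
    """Jaccard overlap between the lowercase token sets of two strings."""
    set_a, set_b = set(a.lower().split()), set(b.lower().split())
    if not set_a or not set_b:
        return 0.0
    return len(set_a & set_b) / len(set_a | set_b)


def filter_augmented(original, candidates, min_overlap=0.2, max_overlap=0.95, max_ratio=5):
    """Keep candidates that are neither near-copies nor unrelated to the original.

    - duplicates (case-insensitive) are dropped
    - overlap above max_overlap is treated as adding no diversity
    - overlap below min_overlap is treated as likely noise
    - at most max_ratio candidates are kept per original
    """
    kept = []
    seen = {original.lower().strip()}
    for candidate in candidates:
        key = candidate.lower().strip()
        if not key or key in seen:
            continue
        overlap = token_overlap(original, candidate)
        if overlap < min_overlap or overlap > max_overlap:
            continue
        seen.add(key)
        kept.append(candidate)
        if len(kept) >= max_ratio:
            break
    return kept


# Usage
candidates = [
    "The movie was great",   # copy of the original, dropped
    "the film was great",    # paraphrase, kept
    "bananas are yellow",    # unrelated, dropped
    "the film was great",    # repeated paraphrase, dropped
]
print(filter_augmented("the movie was great", candidates))
```

A filter like this catches only surface-level problems; checking that the label still holds would need a classifier or a manual review of a sample.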

Key Takeaways

  • Back-translation is the most reliable augmentation method across tasks
  • LLM-based augmentation provides highest diversity but requires quality control
  • Lexical methods are simple and fast but limited in diversity
  • Text mixing improves calibration and robustness
  • Augmentation is most impactful in low-resource settings