Text Generation
Text generation is the process of producing coherent text token-by-token using language models. The choice of decoding strategy significantly impacts output quality and diversity.
Decoding Strategies Overview
| Strategy | Diversity | Quality | Speed | Use Case |
|---|---|---|---|---|
| Greedy | None (deterministic) | High | Fast | Translation, factual QA |
| Beam search | Low | High | Medium | Summarization, translation |
| Top-k | Medium | Medium | Fast | Creative writing |
| Top-p | Medium | High | Fast | General generation |
| Temperature | Variable | Variable | Fast | Modifier combined with any sampling strategy |
Temperature Sampling
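Temperature controls the randomness of predictions by dividing the logits by a temperature $T$ before the softmax.

Definition: Temperature Scaling

$$P(x_i) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

where $z_i$ is the logit of token $x_i$ and $T$ is the temperature parameter: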
- T → 0: approaches greedy decoding (deterministic)
- T < 1: sharper, more confident distribution
- T = 1: standard softmax (no change)
- T > 1: flatter, more random distribution
```python
import torch
import torch.nn.functional as F

def temperature_sample(logits, temperature=1.0):
    # T = 0 would divide by zero, so treat it as greedy (argmax) decoding.
    # keepdim=True keeps the output shape consistent with torch.multinomial.
    if temperature == 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    scaled_logits = logits / temperature
    probs = F.softmax(scaled_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Example: how temperature reshapes the same four logits
logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
for temp in [0.1, 0.5, 1.0, 2.0]:
    probs = F.softmax(logits / temp, dim=-1)
    print(f"T={temp}: {probs.numpy().round(3)}")
# Expected values (rounded to 3 decimals):
# T=0.1: [1.000, 0.000, 0.000, 0.000]
# T=0.5: [0.828, 0.112, 0.041, 0.019]
# T=1.0: [0.575, 0.211, 0.128, 0.086]
# T=2.0: [0.406, 0.246, 0.192, 0.157]
```
Top-k Sampling
Top-k restricts sampling to the k most probable tokens.
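Definition: Top-k Sampling

$$P'(x_i) = \begin{cases} \dfrac{P(x_i)}{\sum_{x_j \in V^{(k)}} P(x_j)} & \text{if } x_i \in V^{(k)} \\ 0 & \text{otherwise} \end{cases}$$

where $V^{(k)}$ is the set of the $k$ tokens with the highest probability.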
```python
def top_k_sample(logits, k=50, temperature=1.0):
    # Apply temperature
    scaled_logits = logits / temperature
    # Guard against k exceeding the vocabulary size
    k = min(k, scaled_logits.size(-1))
    # Get top-k values and indices along the vocabulary dimension
    top_k_values, top_k_indices = torch.topk(scaled_logits, k, dim=-1)
    # Start from all -inf and copy back only the top-k logits
    filtered_logits = torch.full_like(scaled_logits, float('-inf'))
    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
    # Softmax assigns zero probability to every -inf entry
    probs = F.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Example
logits = torch.randn(1, 1000)  # batch of 1, 1000-token vocabulary
token = top_k_sample(logits, k=50, temperature=0.8)
print(f"Sampled token ID: {token.item()}")
```
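To see the renormalization step on concrete numbers, the sketch below applies top-k filtering to the same four logits used in the temperature example. It is written in plain Python rather than PyTorch so every step is visible, and `top_k_probs` is a hypothetical helper that returns the filtered distribution instead of sampling from it.

```python
import math

def top_k_probs(logits, k, temperature=1.0):
    """Hypothetical helper: top-k filtered, renormalized distribution (no sampling)."""
    scaled = [z / temperature for z in logits]
    # Indices of the k largest scaled logits
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    keep = set(ranked[:k])
    # Softmax over the kept tokens only (subtract the max for numerical stability)
    m = max(scaled[i] for i in keep)
    weights = [math.exp(scaled[i] - m) if i in keep else 0.0 for i in range(len(scaled))]
    total = sum(weights)
    return [w / total for w in weights]

probs = top_k_probs([2.0, 1.0, 0.5, 0.1], k=2)
print([round(q, 3) for q in probs])  # [0.731, 0.269, 0.0, 0.0]
```

The two tokens outside the top-k get exactly zero probability, and the two survivors are rescaled so they sum to 1.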
Top-p (Nucleus) Sampling
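Top-p dynamically selects the smallest set of most probable tokens whose cumulative probability is at least p, then renormalizes over that set. Unlike top-k, the size of the candidate set adapts to the shape of the distribution.

Definition: Top-p (Nucleus) Sampling

$$\sum_{x \in V^{(p)}} P(x) \ge p$$

where $V^{(p)}$ is the smallest set of highest-probability tokens satisfying this condition; tokens outside $V^{(p)}$ get probability 0 and the rest are rescaled to sum to 1.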
```python
def top_p_sample(logits, p=0.9, temperature=1.0):
    scaled_logits = logits / temperature
    # Sort logits so cumulative probabilities can be computed
    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token if the tokens ranked above it already reach p;
    # for p > 0 the most probable token is always kept (its preceding mass is 0)
    sorted_mask = (cumulative_probs - sorted_probs) >= p
    sorted_logits[sorted_mask] = float('-inf')
    probs = F.softmax(sorted_logits, dim=-1)
    # Sample a position in sorted order, then map it back to the vocabulary id
    sampled_positions = torch.multinomial(probs, num_samples=1)
    return sorted_indices.gather(-1, sampled_positions)

# Example
logits = torch.randn(1, 1000)
token = top_p_sample(logits, p=0.92, temperature=0.9)
print(f"Sampled token ID: {token.item()}")
```
Beam Search
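Beam search keeps the k highest-scoring partial sequences (k is the beam width, `num_beams`) at each step, extends each of them by every candidate token, and again keeps only the best k.

Definition: Beam Search Score

$$\text{score}(y) = \frac{1}{|y|^{\alpha}} \sum_{t=1}^{|y|} \log P(y_t \mid y_{<t}, x)$$

where $x$ is the prompt, $|y|$ is the length of the generated sequence, and $\alpha$ is a length normalization factor (the `length_penalty` argument in the code below).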
```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

def beam_search(model, tokenizer, prompt, num_beams=5, max_length=50, length_penalty=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    outputs = model.generate(
        input_ids,
        max_length=max_length,                # total length, prompt included
        num_beams=num_beams,                  # beam width k
        length_penalty=length_penalty,        # exponent alpha in the score
        early_stopping=True,                  # stop once num_beams hypotheses finish
        no_repeat_ngram_size=2,               # forbid any repeated 2-gram
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
        return_dict_in_generate=True,
    )
    # With the default num_return_sequences=1, sequences[0] is the best beam
    best_sequence = outputs.sequences[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)

# Usage
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
result = beam_search(model, tokenizer, "The future of AI is", num_beams=5)
print(result)
```
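`model.generate` hides the search loop. The following is a minimal, framework-free sketch of the same idea on a toy problem. It assumes a hypothetical `next_log_probs(prefix)` function (a hard-coded table standing in for a language model) and a hypothetical end-of-sequence token `"</s>"`; finished hypotheses simply leave the beam, a simplification of what production implementations do. Finished sequences are ranked by summed log-probability divided by `len(y) ** alpha`.

```python
import math

EOS = "</s>"  # hypothetical end-of-sequence token

def next_log_probs(prefix):
    """Hypothetical toy model: next-token log-probabilities for a given prefix."""
    table = {
        (): {"the": 0.5, "a": 0.4, EOS: 0.1},
        ("the",): {"cat": 0.4, "dog": 0.35, EOS: 0.25},
        ("a",): {"cat": 0.9, EOS: 0.1},
    }
    probs = table.get(tuple(prefix), {EOS: 1.0})  # unknown prefixes just end
    return {tok: math.log(p) for tok, p in probs.items()}

def beam_search_sketch(num_beams=2, max_len=5, alpha=1.0):
    beams = [([], 0.0)]  # (tokens, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        # Extend every live beam by every possible next token
        candidates = [
            (tokens + [tok], logp + tok_logp)
            for tokens, logp in beams
            for tok, tok_logp in next_log_probs(tokens).items()
        ]
        # Keep the num_beams most probable extensions; finished ones leave the beam
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, logp in candidates[:num_beams]:
            if tokens[-1] == EOS:
                finished.append((tokens, logp))
            else:
                beams.append((tokens, logp))
        if not beams:
            break
    finished.extend(beams)  # hypotheses that hit max_len without EOS
    # Rank by length-normalized score: sum of log-probs / len(y) ** alpha
    return max(finished, key=lambda c: c[1] / len(c[0]) ** alpha)

tokens, logp = beam_search_sketch(num_beams=2)
print(tokens, round(math.exp(logp), 2))  # ['a', 'cat', '</s>'] 0.36
```

Greedy decoding (`num_beams=1`) commits to "the" because 0.5 > 0.4 and ends with probability 0.5 × 0.4 = 0.2, whereas a beam of width 2 keeps "a" alive and finds "a cat" with probability 0.4 × 0.9 = 0.36.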
Repetition Penalty
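Definition: Repetition Penalty

A repetition penalty $\theta > 1$ lowers the logits of tokens that already appear in the context $G$. In the form used by the Hugging Face `repetition_penalty` argument:

$$\tilde{z}_i = \begin{cases} z_i / \theta & \text{if } x_i \in G \text{ and } z_i > 0 \\ z_i \cdot \theta & \text{if } x_i \in G \text{ and } z_i \le 0 \\ z_i & \text{otherwise} \end{cases}$$

A value of $\theta = 1$ means no penalty.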
```python
def generate_with_penalty(model, tokenizer, prompt, repetition_penalty=1.2):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    output = model.generate(
        input_ids,
        max_length=100,
        do_sample=True,  # sampling must be on for temperature/top_p to apply
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,  # needed for models without a pad token, e.g. GPT-2
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
```
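What `repetition_penalty` does to the logits can be sketched in a few lines of plain Python. The rule below follows the divide-positive, multiply-negative convention used by Hugging Face's `RepetitionPenaltyLogitsProcessor`; `apply_repetition_penalty` is a hypothetical helper for illustration, not a library function, and token ids are assumed to be list indices.

```python
def apply_repetition_penalty(logits, previous_tokens, penalty=1.2):
    """Hypothetical helper: penalize every token id already in previous_tokens.

    Dividing a positive logit (or multiplying a negative one) by a penalty > 1
    always moves it down, so repeated tokens become less likely either way.
    """
    penalized = list(logits)
    for tok in set(previous_tokens):  # each distinct token is penalized once
        z = penalized[tok]
        penalized[tok] = z * penalty if z < 0 else z / penalty
    return penalized

logits = [2.4, -1.0, 0.5, 3.0]
print(apply_repetition_penalty(logits, previous_tokens=[0, 1, 1], penalty=1.2))
# [2.0, -1.2, 0.5, 3.0]
```

Token 1 appears twice in the context but is penalized only once; tokens 2 and 3 never appeared, so their logits are untouched.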
Decoding Strategy Comparison
| Strategy | Output Type | Best For | Drawback |
|---|---|---|---|
| Greedy | Deterministic | Factual tasks | Repetitive |
| Beam search | Near-deterministic | Translation, summarization | Can be generic |
| Top-k | Stochastic | Creative writing | Fixed candidate pool |
| Top-p | Stochastic | General generation | Pool can grow large on flat distributions |
| Typical | Stochastic | Diverse generation | Complex tuning |
| Min-p | Stochastic | Balanced quality/diversity | Newer method |
Top-p vs Top-k Comparison
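The practical difference between the two filters is how many candidates survive. The sketch below, in plain Python with a hypothetical helper `nucleus_size`, counts the top-p pool for a peaked and a flat six-token distribution at p = 0.9; top-k would keep exactly k tokens in both cases.

```python
def nucleus_size(probs, p=0.9):
    """Hypothetical helper: size of the smallest top-probability set with mass >= p."""
    total = 0.0
    for n, prob in enumerate(sorted(probs, reverse=True), start=1):
        total += prob
        if total >= p:
            return n
    return len(probs)  # guard against floating-point shortfall

peaked = [0.80, 0.12, 0.04, 0.02, 0.01, 0.01]
flat = [1 / 6] * 6
print(nucleus_size(peaked), nucleus_size(flat))  # 2 6
```

On the peaked distribution the nucleus shrinks to two tokens, while on the flat one it keeps all six; a fixed k = 2 would instead discard four equally likely tokens from the flat distribution.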
Speculative Decoding
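Speculative decoding uses a smaller "draft" model to propose several candidate tokens, then verifies all of them with the larger target model in a single forward pass, accepting or rejecting each one so that the output still follows the target model's distribution.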
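```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def speculative_decode(draft_model, target_model, input_ids, num_speculative=5):
    # One speculative step for a batch of size 1; both models must share a
    # tokenizer. The draft is greedy, so its distribution q is one-hot on each
    # drafted token and the acceptance test min(1, p/q) reduces to p(token).
    prompt_len = input_ids.shape[1]
    draft_ids = draft_model.generate(
        input_ids, max_new_tokens=num_speculative, do_sample=False
    )
    num_drafted = draft_ids.shape[1] - prompt_len  # may be fewer if EOS is hit
    # Verify all draft tokens in parallel with one target forward pass.
    # Row i of p is the target distribution for draft token i; the extra
    # last row is the distribution after the final draft token.
    logits = target_model(draft_ids).logits[0, prompt_len - 1:]
    p = F.softmax(logits, dim=-1)
    accepted = []
    for i in range(num_drafted):
        token = draft_ids[0, prompt_len + i]
        if torch.rand(1).item() < p[i, token].item():
            accepted.append(token.item())
        else:
            # Rejected: resample from max(0, p - q), i.e. p without this token
            residual = p[i].clone()
            residual[token] = 0.0
            accepted.append(torch.multinomial(residual, 1).item())
            return accepted
    # Every draft token accepted: take one bonus token from the last row
    accepted.append(torch.multinomial(p[num_drafted], 1).item())
    return accepted
```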
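Speculative decoding can achieve 2-3× speedups for autoregressive generation. The gain comes from verifying several draft tokens in one forward pass of the target model: token-by-token decoding is memory-bandwidth bound rather than compute bound, so scoring a handful of tokens at once costs little more than scoring one. How much is saved depends on how often the target model accepts the draft tokens.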
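A back-of-the-envelope estimate shows why the acceptance rate matters. Assuming (as a simplification) that each draft token is accepted independently with probability r, a step that drafts n tokens emits the accepted drafts plus one token from the target model, so the expected output per target forward pass is 1 + r + r² + ... + rⁿ = (1 − r^(n+1)) / (1 − r). The function name below is hypothetical.

```python
def expected_tokens_per_step(accept_rate, num_speculative):
    """Hypothetical helper: expected tokens emitted per target forward pass.

    Assumes independent acceptance with probability accept_rate; a step emits
    the accepted drafts plus one correction or bonus token from the target.
    """
    if accept_rate == 1.0:
        return num_speculative + 1
    return (1 - accept_rate ** (num_speculative + 1)) / (1 - accept_rate)

for r in (0.6, 0.8, 0.9):
    print(r, round(expected_tokens_per_step(r, num_speculative=5), 2))
# 0.6 2.38
# 0.8 3.69
# 0.9 4.69
```

These counts ignore the draft model's own forward passes, so the wall-clock speedup is lower than the tokens-per-pass figure.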