Instruction Tuning
Instruction tuning fine-tunes language models on (instruction, output) pairs, enabling them to follow natural language instructions and generalize to unseen tasks.
Why Instruction Tuning?
Pre-trained language models encode broad knowledge, but they are trained only to predict the next token, so they tend to continue a prompt rather than carry out the request it contains. Instruction tuning bridges this gap.
| Approach | Input Format | Output | Limitation |
|---|---|---|---|
| Pre-training | Raw text | Continuation | No task following |
| Fine-tuning | Task-specific format | Task output | One model per task |
| Instruction tuning | Natural language instruction | Instruction-following output | Generalization bounded by instruction data coverage |
| RLHF | Instruction + human feedback | Aligned output | Requires preference data and a reward model |
Instruction Tuning Pipeline
FLAN (Fine-tuned LAnguage Net)
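FLAN fine-tunes a pretrained language model on a large mixture of NLP datasets, each rephrased with several natural-language instruction templates, and shows that the resulting model generalizes zero-shot to held-out task types, an ability that emerges only at sufficient model scale.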
FLAN Data Construction
```python
import random

from datasets import Dataset

random.seed(0)  # make template sampling reproducible

def format_flan_example(dataset_name, task_type, input_text, output_text):
    """Wrap a raw (input, output) pair in a randomly chosen natural-language instruction template."""
    # Templates contain only the {input} slot: the answer is the training target,
    # so it must not appear in the prompt the model sees.
    templates = {
        "SINGLE_LABEL_TASK": [
            "Classify the following: {input}\nAnswer:",
            "What is the label for: {input}?\nAnswer:",
        ],
        "GENERATION_TASK": [
            "Given the following input, generate a response.\nInput: {input}\nResponse:",
        ],
        "SUMMARIZATION_TASK": [
            "Summarize the following: {input}\nSummary:",
        ],
    }
    # Unknown task types fall back to the classification templates
    template = random.choice(templates.get(task_type, templates["SINGLE_LABEL_TASK"]))
    return {
        "input": template.format(input=input_text),
        "output": output_text,
        "dataset": dataset_name,
        "task_type": task_type,
    }

examples = [
    format_flan_example("sst2", "SINGLE_LABEL_TASK", "This movie is great!", "Positive"),
    format_flan_example("ag_news", "SINGLE_LABEL_TASK", "Stock market hits new high.", "Business"),
    format_flan_example("xsum", "SUMMARIZATION_TASK", "Long article about climate policy...", "New climate legislation passed."),
]
dataset = Dataset.from_list(examples)
print(dataset[0])
```
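FLAN draws on many datasets of very different sizes. One common way to combine them, described for T5 and reused by FLAN, is examples-proportional mixing with a per-dataset cap, so that a few huge datasets do not crowd out the rest. The sketch below is a minimal illustration; `mixing_rates`, `sample_mixture`, the dataset names, sizes, and the cap value are hypothetical, not FLAN's code or settings.

```python
import random

def mixing_rates(dataset_sizes, cap):
    """Examples-proportional mixing: weight each dataset by min(size, cap), then normalize."""
    capped = {name: min(size, cap) for name, size in dataset_sizes.items()}
    total = sum(capped.values())
    return {name: weight / total for name, weight in capped.items()}

def sample_mixture(datasets, cap, k, seed=0):
    """Draw k training examples from a dict of example lists using capped mixing rates."""
    rng = random.Random(seed)
    rates = mixing_rates({name: len(rows) for name, rows in datasets.items()}, cap)
    names = list(rates)
    chosen = rng.choices(names, weights=[rates[n] for n in names], k=k)
    return [rng.choice(datasets[name]) for name in chosen]

# Hypothetical sizes: without the cap, "large_qa" would take about 88% of the mixture
print(mixing_rates({"large_qa": 500_000, "mid_cls": 67_000, "tiny_task": 1_000}, cap=3_000))
```

With the cap, the two large datasets are weighted equally and the small task still receives a non-trivial share.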
FLAN-T5 Training
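```python
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)  # requires the sentencepiece package
model = T5ForConditionalGeneration.from_pretrained(model_name)

def preprocess_flan(examples):
    """Tokenize instruction prompts and targets for seq2seq training."""
    inputs = ["instruction: " + inp for inp in examples["input"]]
    targets = examples["output"]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")
    # Replace label padding with -100, the index ignored by the cross-entropy loss
    labels["input_ids"] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in label]
        for label in labels["input_ids"]
    ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Reuses `dataset` from the previous snippet; a real run needs a far larger task mixture
split = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split["train"].map(preprocess_flan, batched=True, remove_columns=dataset.column_names)
eval_dataset = split["test"].map(preprocess_flan, batched=True, remove_columns=dataset.column_names)

training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=1000,
    eval_strategy="steps",  # named `evaluation_strategy` in older transformers releases
    eval_steps=500,
    predict_with_generate=True,
    fp16=True,  # T5 checkpoints can overflow in fp16; switch to bf16=True if the loss becomes NaN
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,  # newer releases name this argument `processing_class`
)
trainer.train()
```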
InstructGPT (SFT + RLHF)
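InstructGPT popularized a two-stage recipe for instruction following: supervised fine-tuning (SFT) on human-written demonstrations, followed by reinforcement learning from human feedback (RLHF) against a reward model trained on human preference comparisons.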
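InstructGPT Training Objective
Stage 1 - SFT Loss:
$$\mathcal{L}_{\text{SFT}}(\phi) = -\mathbb{E}_{(x, y) \sim \mathcal{D}_{\text{SFT}}} \left[ \sum_{t=1}^{|y|} \log \pi_\phi(y_t \mid x, y_{<t}) \right]$$
Stage 2 - RLHF Objective:
$$\max_{\phi} \; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\phi(\cdot \mid x)} \left[ r_\theta(x, y) - \beta \log \frac{\pi_\phi(y \mid x)}{\pi^{\text{SFT}}(y \mid x)} \right]$$
where $\pi_\phi$ is the policy being trained (initialized from the SFT model), $r_\theta$ is the learned reward model, $\pi^{\text{SFT}}$ is the frozen SFT model, and $\beta$ controls the strength of the KL penalty.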
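Plugging hand-picked numbers into the two objectives makes their arithmetic concrete. The sketch below assumes per-token log-probabilities are already computed, sums them into sequence log-probabilities, and uses the log-ratio of a single sampled response as the KL estimate; `sft_loss`, `kl_penalized_reward`, and the probabilities are hypothetical.

```python
import math

def sft_loss(response_token_logprobs):
    """Stage 1 for one example: negative log-likelihood of the demonstration's response tokens."""
    return -sum(response_token_logprobs)

def kl_penalized_reward(reward, policy_token_logprobs, sft_token_logprobs, beta=0.1):
    """Stage 2 integrand for one sampled response: r(x, y) - beta * log(pi(y|x) / pi_SFT(y|x))."""
    log_ratio = sum(policy_token_logprobs) - sum(sft_token_logprobs)
    return reward - beta * log_ratio

# Hypothetical per-token probabilities for a 3-token response; only the first token differs
policy = [math.log(0.5), math.log(0.4), math.log(0.9)]
sft = [math.log(0.25), math.log(0.4), math.log(0.9)]

print(round(sft_loss(policy), 4))                       # → 1.7148  (= -log 0.18)
print(round(kl_penalized_reward(1.0, policy, sft), 4))  # → 0.9307  (= 1 - 0.1 * log 2)
```

The policy doubled the probability of the first token relative to the SFT model, so the reward of 1.0 is reduced by $\beta \log 2$.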
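```python
import torch
import torch.nn.functional as F

class InstructGPTTrainer:
    """Simplified SFT + RLHF losses. Assumes causal-LM policy/reference models and a
    sequence-classification reward model that share one right-padded tokenizer."""

    def __init__(self, sft_model, ref_model, reward_model, tokenizer, beta=0.1):
        self.sft_model = sft_model        # policy being optimized, initialized from the SFT checkpoint
        self.ref_model = ref_model        # frozen copy of the SFT model (pi^SFT in the KL term)
        self.reward_model = reward_model  # learned reward model r_theta
        self.tokenizer = tokenizer
        self.beta = beta                  # KL penalty coefficient

    def compute_sft_loss(self, batch):
        # Token-level cross-entropy; label positions set to -100 (e.g. prompt tokens) are ignored
        outputs = self.sft_model(input_ids=batch["input_ids"], labels=batch["labels"])
        return outputs.loss

    def compute_reward(self, prompts, responses):
        inputs = self.tokenizer(
            [f"{p}\n{r}" for p, r in zip(prompts, responses)],
            padding=True, truncation=True, return_tensors="pt",
        ).to(self.reward_model.device)
        with torch.no_grad():
            rewards = self.reward_model(**inputs).logits.squeeze(-1)  # (batch,)
        return rewards

    def _get_logprobs(self, model, prompts, responses):
        """Sum of log-probabilities of the response tokens given each prompt, shape (batch,)."""
        enc = self.tokenizer(
            [f"{p}\n{r}" for p, r in zip(prompts, responses)],
            padding=True, truncation=True, return_tensors="pt",
        ).to(model.device)
        # Assumes tokenizing "prompt\n" yields a prefix of tokenizing "prompt\nresponse"
        prompt_lens = [len(self.tokenizer(f"{p}\n")["input_ids"]) for p in prompts]
        logits = model(**enc).logits[:, :-1, :]  # position t predicts token t + 1
        targets = enc["input_ids"][:, 1:]
        token_logprobs = F.log_softmax(logits, dim=-1).gather(2, targets.unsqueeze(-1)).squeeze(-1)
        mask = enc["attention_mask"][:, 1:].clone()  # zero on padding
        for i, n in enumerate(prompt_lens):
            mask[i, : n - 1] = 0  # drop predictions of prompt tokens
        return (token_logprobs * mask).sum(dim=-1)

    def compute_rlhf_loss(self, batch):
        # Responses should be sampled from the current policy for this estimate to be valid
        prompts, responses = batch["prompts"], batch["responses"]
        policy_logprobs = self._get_logprobs(self.sft_model, prompts, responses)
        with torch.no_grad():
            ref_logprobs = self._get_logprobs(self.ref_model, prompts, responses)
        kl = policy_logprobs.detach() - ref_logprobs  # sampled estimate of log(pi / pi_SFT)
        rewards = self.compute_reward(prompts, responses)
        # REINFORCE-style surrogate for the Stage 2 objective; InstructGPT itself uses PPO
        advantage = rewards - self.beta * kl
        return -(advantage * policy_logprobs).mean()
```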
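The trainer above takes the reward model as given. In InstructGPT that model is itself trained on human comparisons: for a prompt $x$ with a preferred response $y_w$ and a rejected response $y_l$, the loss is $-\log \sigma(r_\theta(x, y_w) - r_\theta(x, y_l))$. The sketch below evaluates that loss on plain scalar rewards; `softplus` and `pairwise_reward_loss` are hypothetical helpers, and a real implementation would compute the rewards with the model and backpropagate through them.

```python
import math

def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return z + math.log1p(math.exp(-z)) if z > 0 else math.log1p(math.exp(z))

def pairwise_reward_loss(chosen_rewards, rejected_rewards):
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over comparison pairs.

    Uses the identity -log(sigmoid(m)) = log(1 + exp(-m)) = softplus(-m).
    """
    losses = [softplus(-(rc - rr)) for rc, rr in zip(chosen_rewards, rejected_rewards)]
    return sum(losses) / len(losses)

print(round(pairwise_reward_loss([0.0], [0.0]), 4))  # → 0.6931  (equal rewards: log 2)
print(round(pairwise_reward_loss([2.0], [0.0]), 4))  # → 0.1269  (correct ranking, margin 2)
```

The loss depends only on reward differences, so the reward model learns a ranking rather than an absolute scale.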
Multi-Task Instruction Formatting
| Format | Template | Use Case |
|---|---|---|
| Alpaca | `### Instruction:\n{instruction}\n\n### Response:\n{output}` | General |
| Vicuna | `USER: {input}\nASSISTANT: {output}` | Chat |
| ChatML | `<\|im_start\|>user\n{input}<\|im_end\|>\n<\|im_start\|>assistant\n{output}<\|im_end\|>` | Multi-turn |
| ShareGPT | `Human: {input}\nAssistant: {output}` | Dialogue |
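When applying these formats, it helps to keep the rendered prompt and the target separate so the training loss can be restricted to the response. The sketch below assumes `{input}` as the single user slot for every format (the Alpaca row calls it `{instruction}`) and closes the ChatML assistant turn with `<|im_end|>`; `TEMPLATES` and `split_prompt_target` are hypothetical names.

```python
# Single-turn templates; each contains exactly one {input} and one {output} slot
TEMPLATES = {
    "alpaca": "### Instruction:\n{input}\n\n### Response:\n{output}",
    "vicuna": "USER: {input}\nASSISTANT: {output}",
    "chatml": "<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n{output}<|im_end|>",
    "sharegpt": "Human: {input}\nAssistant: {output}",
}

def split_prompt_target(fmt, user_input, output):
    """Render one example and split it into (prompt, target) at the {output} slot.

    Keeping the split lets training compute the loss on target tokens only.
    """
    prefix, suffix = TEMPLATES[fmt].split("{output}")
    # Only the prefix contains a placeholder; the suffix is literal text such as <|im_end|>
    return prefix.format(input=user_input), output + suffix

for fmt in TEMPLATES:
    print(repr(split_prompt_target(fmt, "Name a prime number.", "7")))
```

Whatever format is chosen, using the same one for every training example and at inference time is what the consistency advice below refers to.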
Scaling Laws for Instruction Tuning
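Scaling Behavior
The loss of an instruction-tuned model is commonly modeled as a power law in model size and data size:
$$\mathcal{L}(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$$
where $N$ is model parameter count, $D$ is dataset size, and $E$, $A$, $B$, $\alpha$, $\beta$ are fitted constants. Benefits plateau beyond approximately 100K instruction examples.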
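Assuming the additive form $\mathcal{L}(N, D) = E + A/N^{\alpha} + B/D^{\beta}$, each 10× increase in $D$ at fixed $N$ buys a loss reduction that is $10^{-\beta}$ times the previous one, which is one way to read the diminishing returns in the table below. The constants in this sketch are placeholders chosen for illustration, not values fitted to any model, and `predicted_loss` and `gain_from_10x_data` are hypothetical helpers.

```python
def predicted_loss(n_params, n_examples, E=1.0, A=100.0, B=100.0, alpha=0.3, beta=0.3):
    """L(N, D) = E + A / N**alpha + B / D**beta with placeholder (not fitted) constants."""
    return E + A / n_params**alpha + B / n_examples**beta

def gain_from_10x_data(n_params, n_examples, **constants):
    """Loss reduction from multiplying the dataset size by 10 at a fixed model size."""
    before = predicted_loss(n_params, n_examples, **constants)
    after = predicted_loss(n_params, 10 * n_examples, **constants)
    return before - after

# Each successive 10x of data yields roughly half (10**-0.3) of the previous gain
for d in (10**4, 10**5, 10**6):
    print(d, round(gain_from_10x_data(70e9, d), 4))
```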
| Model Size | Training Tokens | Benchmark Score | Key Insight |
|---|---|---|---|
| 8B | 1M | 62.3 | Basic instruction following |
| 70B | 1M | 74.8 | Strong generalization |
| 70B | 10M | 78.2 | Diminishing returns |
| 70B | 100M | 79.1 | Plateau reached |
Evaluation Frameworks
| Benchmark | Tasks | Metric | Focus |
|---|---|---|---|
| FLAN | 62 NLP tasks | Accuracy | Cross-task generalization |
| AlpacaEval | Open-ended | Win rate vs GPT-4 | Instruction following |
| MT-Bench | Multi-turn | GPT-4 rating | Dialogue quality |
| IFEval | Verifiable instructions (e.g., word limits, required keywords) | Prompt-level and instruction-level accuracy | Format adherence |
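IFEval-style evaluation attaches instructions that a program can check, such as a word limit or a required keyword, and scores a response both per prompt (were all instructions followed?) and per instruction. The checkers below are simplified, hypothetical stand-ins written for illustration, not the benchmark's released implementation.

```python
import re

def max_words(limit):
    """Constraint: the response contains at most `limit` whitespace-separated words."""
    return lambda r: len(r.split()) <= limit

def includes_keywords(keywords):
    """Constraint: every keyword appears as a whole word (case-insensitive)."""
    return lambda r: all(re.search(rf"\b{re.escape(k)}\b", r, re.IGNORECASE) for k in keywords)

def bullet_count(n):
    """Constraint: exactly n lines start with a markdown bullet ('* ' or '- ')."""
    return lambda r: sum(line.lstrip().startswith(("* ", "- ")) for line in r.splitlines()) == n

def score(response, checks):
    """Return (prompt-level pass, instruction-level accuracy) for one response."""
    results = [check(response) for check in checks]
    return all(results), sum(results) / len(results)

response = "- Paris is the capital.\n- Lyon is a large city."
print(score(response, [max_words(20), includes_keywords(["Paris"]), bullet_count(3)]))
# → (False, 0.6666666666666666)
```

Because the checks are deterministic, no judge model is needed, which is what distinguishes this kind of benchmark from AlpacaEval and MT-Bench.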
Best Practices
- Data quality over quantity - a small set of carefully curated examples can outperform a much larger noisy collection (LIMA, for example, fine-tuned on about 1,000 curated examples)
- Diverse task coverage - Include classification, generation, reasoning, and extraction
- Consistent formatting - Standardize instruction templates across all examples
- Appropriate model size - Larger models benefit more from instruction tuning
- Conservative learning rate - 2e-5 to 5e-5 works well for full fine-tuning of most base models; parameter-efficient methods such as LoRA typically use higher rates
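The data-quality advice can start with cheap filters before any model-based scoring. The sketch below applies two of them: exact deduplication after normalizing case and whitespace, and a minimum response length. The threshold and the names `normalize` and `clean_instruction_data` are hypothetical; real pipelines usually add near-duplicate detection and content-based quality checks.

```python
import hashlib

def normalize(text):
    """Lowercase and collapse whitespace so trivial variants compare equal."""
    return " ".join(text.lower().split())

def clean_instruction_data(records, min_output_words=3):
    """Drop responses shorter than min_output_words and normalized exact duplicates."""
    seen = set()
    kept = []
    for rec in records:
        if len(rec["output"].split()) < min_output_words:
            continue  # near-empty responses teach the model little
        key = hashlib.sha256(
            (normalize(rec["instruction"]) + "\x00" + normalize(rec["output"])).encode("utf-8")
        ).hexdigest()
        if key in seen:
            continue  # duplicate after normalization
        seen.add(key)
        kept.append(rec)
    return kept

records = [
    {"instruction": "Say hello", "output": "Hello there, friend."},
    {"instruction": "say  HELLO", "output": "hello there,   friend."},
    {"instruction": "Define entropy", "output": "Uncertainty."},
]
print(len(clean_instruction_data(records)))  # → 1
```

Normalizing instruction and output separately (joined by a NUL separator) avoids treating two different splits of the same text as duplicates.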
Key Takeaways
- FLAN demonstrates that multi-task instruction tuning enables emergent cross-task generalization
- InstructGPT established the SFT then RLHF paradigm used by most modern assistants
- Data quality matters more than data volume for instruction tuning
- Format consistency across instruction templates improves model learning
- Evaluation requires both automatic metrics and human judgment for reliable assessment