Few-Shot and In-Context Learning
Few-shot learning enables NLP models to generalize from very limited labeled examples, reducing the need for large annotated datasets.
Learning Paradigm Comparison
In-Context Learning
Definition: In-Context Learning Formulation
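Given a prompt containing $k$ demonstration examples $\{(x_i, y_i)\}_{i=1}^{k}$ and a query $x_q$, a model with frozen parameters $\theta$ predicts:
$$\hat{y}_q = \arg\max_{y} P_\theta\left(y \mid x_1, y_1, \ldots, x_k, y_k, x_q\right)$$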
No gradient updates are performed; the model learns purely from the context.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
class InContextLearner:
    """Few-shot learner that conditions a causal LM on demonstrations in its prompt.

    No gradient updates are performed: every prediction comes from
    ``model.generate`` run on a prompt that contains the demonstrations.
    """

    def __init__(self, model_name="gpt2-medium"):
        """Load the tokenizer and causal language model named ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # GPT-2 style tokenizers have no pad token; reuse EOS so a pad id exists.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def create_prompt(self, task_description, examples, query, max_examples=5):
        """Build an in-context learning prompt.

        Args:
            task_description: Text placed on the first ``Task:`` line.
            examples: Sequence of ``(input_text, output_text)`` demonstrations;
                only the first ``max_examples`` are used.
            query: Input whose output the model should complete.
            max_examples: Maximum number of demonstrations to include.

        Returns:
            The prompt string. It ends with ``"Output:"`` so that the model
            continues with the answer.
        """
        prompt_parts = [f"Task: {task_description}\n"]
        prompt_parts.extend(
            f"Input: {input_text}\nOutput: {output_text}\n"
            for input_text, output_text in examples[:max_examples]
        )
        prompt_parts.append(f"Input: {query}\nOutput:")
        return "\n".join(prompt_parts)

    def predict(self, prompt, max_new_tokens=50, temperature=0.7, do_sample=True, max_length=1024):
        """Generate a continuation of ``prompt`` and return its first line.

        Args:
            prompt: Full prompt text.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature; used only when ``do_sample`` is True.
            do_sample: If True, use top-k/top-p sampling; otherwise use greedy decoding.
            max_length: Total token budget (prompt + generated tokens). The
                default of 1024 matches GPT-2's context window.

        Returns:
            The first line of the generated text, with surrounding whitespace removed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        # Keep the *end* of an over-long prompt, because the query sits at the
        # end and right-side truncation would drop it. Also reserve room for
        # the generated tokens so prompt + output fit within max_length.
        input_budget = max(1, max_length - max_new_tokens)
        inputs = {name: tensor[:, -input_budget:] for name, tensor in inputs.items()}

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if do_sample:
            # Sampling-only knobs; passing them to greedy decoding only triggers warnings.
            generation_kwargs.update(temperature=temperature, top_k=50, top_p=0.95)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        prompt_length = inputs["input_ids"].shape[1]
        generated = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return generated.strip().split("\n")[0]

    @staticmethod
    def _match_class(prediction, class_names):
        """Return the class name that appears earliest in ``prediction``, or None.

        Matching is case-insensitive. When two names match at the same
        position, the longer one wins (e.g. "very positive" over "positive").
        """
        text = prediction.lower()
        best_name, best_key = None, None
        for name in class_names:
            position = text.find(name.lower())
            if position < 0:
                continue
            key = (position, -len(name))
            if best_key is None or key < best_key:
                best_name, best_key = name, key
        return best_name

    def few_shot_classify(self, task, examples, query, class_names):
        """Classify ``query`` into one of ``class_names`` using labeled demonstrations.

        Args:
            task: Task description (e.g. "sentiment analysis") written at the
                top of the prompt.
            examples: Iterable of ``(text, label)`` demonstrations.
            query: Text to classify.
            class_names: Candidate labels; the first one is the fallback.

        Returns:
            The class name that appears earliest in the model's output, or
            ``class_names[0]`` if the output mentions none of them.

        Raises:
            ValueError: If ``class_names`` is empty.
        """
        if not class_names:
            raise ValueError("class_names must contain at least one label")

        parts = []
        if task:
            parts.append(f"Task: {task}\n\n")
        parts.append(f"Classify the following text into one of these categories: {', '.join(class_names)}\n\n")
        parts.extend(f"Text: {text}\nLabel: {label}\n\n" for text, label in examples)
        parts.append(f"Text: {query}\nLabel:")
        prompt = "".join(parts)

        # Greedy decoding so that the same input always yields the same label.
        prediction = self.predict(prompt, max_new_tokens=10, do_sample=False)
        matched = self._match_class(prediction, class_names)
        return class_names[0] if matched is None else matched
# Usage: three-shot sentiment classification via in-context learning.
# Demonstrations: one (text, label) pair per sentiment class.
examples = [
    ("This movie is fantastic!", "positive"),
    ("Terrible waste of time.", "negative"),
    ("Average film, nothing special.", "neutral"),
]
sentiment_labels = ["positive", "negative", "neutral"]
query = "What an incredible experience!"

learner = InContextLearner()
result = learner.few_shot_classify(
    "sentiment analysis",
    examples,
    query,
    sentiment_labels,
)
print(f"Prediction: {result}")
Prompt Engineering Strategies
| Strategy | Description | Example |
|---|---|---|
| Zero-shot | Direct instruction | "Classify: [text]" |
| One-shot | Single example | "Example: [input] -> [output]" |
| Few-shot | Multiple examples | "Examples: ..." |
| Chain-of-thought | Step-by-step reasoning | "Let's think step by step..." |
| Self-consistency | Multiple samples + voting | "Sample 3 times, take majority" |
Prompt Template Library
class PromptTemplates:
    """Registry of prompt templates, keyed by task name and prompting strategy.

    ``TEMPLATES[task][strategy]`` is a ``str.format`` template. Every template
    uses ``{text}``; the few-shot templates also use ``{examples}``, which
    callers must supply as an already-formatted string.
    """

    TEMPLATES = {
        "sentiment": {
            "zero_shot": "Classify the sentiment of the following text as positive, negative, or neutral:\n\n{text}\n\nSentiment:",
            "few_shot": "Here are examples of sentiment classification:\n\n{examples}\n\nNow classify:\nText: {text}\nSentiment:",
            "cot": "Let's analyze the sentiment step by step.\n\nText: {text}\n\nStep 1: Identify key words and phrases.\nStep 2: Determine if they are positive or negative.\nStep 3: Consider the overall tone.\n\nSentiment:"
        },
        "ner": {
            "zero_shot": "Extract all named entities (person, organization, location) from the following text:\n\n{text}\n\nEntities:",
            "few_shot": "Here are examples of entity extraction:\n\n{examples}\n\nNow extract entities from:\nText: {text}\nEntities:"
        },
        "summarization": {
            "zero_shot": "Provide a concise summary of the following text:\n\n{text}\n\nSummary:",
            "few_shot": "Here are examples of good summaries:\n\n{examples}\n\nNow summarize:\nText: {text}\nSummary:"
        }
    }

    @classmethod
    def format(cls, task, strategy, **kwargs):
        """Fill the ``task``/``strategy`` template with ``kwargs``.

        Args:
            task: Key into ``TEMPLATES`` (e.g. "sentiment").
            strategy: Strategy key for that task (e.g. "zero_shot").
            **kwargs: Placeholder values, e.g. ``text=...`` and ``examples=...``.

        Returns:
            The formatted prompt string.

        Raises:
            KeyError: If the task or strategy is unknown, or if a placeholder
                the template needs is missing. The message lists the valid
                options.
        """
        try:
            strategies = cls.TEMPLATES[task]
        except KeyError:
            raise KeyError(
                f"Unknown task {task!r}; available tasks: {sorted(cls.TEMPLATES)}"
            ) from None
        try:
            template = strategies[strategy]
        except KeyError:
            raise KeyError(
                f"Unknown strategy {strategy!r} for task {task!r}; "
                f"available strategies: {sorted(strategies)}"
            ) from None
        try:
            return template.format(**kwargs)
        except KeyError as err:
            raise KeyError(
                f"Template {task!r}/{strategy!r} requires placeholder {err.args[0]!r}; "
                f"got {sorted(kwargs)}"
            ) from None
# Chain-of-thought prompting example
class ChainOfThoughtPrompter:
    """Elicit step-by-step reasoning from a causal LM by ending the prompt with "Step 1:"."""

    def __init__(self, model, tokenizer):
        """Store a model that supports ``generate`` and its matching tokenizer."""
        self.model = model
        self.tokenizer = tokenizer

    @staticmethod
    def _build_prompt(question, context=None):
        """Assemble the chain-of-thought prompt; ``context`` is included only if truthy."""
        prompt = f"Question: {question}\n"
        if context:
            prompt += f"Context: {context}\n"
        prompt += "\nLet's solve this step by step:\n"
        prompt += "Step 1:"
        return prompt

    def reason(self, question, context=None, max_new_tokens=200, temperature=0.3, do_sample=True):
        """Generate step-by-step reasoning for ``question``.

        Args:
            question: The question to solve.
            context: Optional supporting text placed after the question.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature; used only when ``do_sample`` is True.
            do_sample: If True, sample with ``temperature``; otherwise use greedy decoding.

        Returns:
            The generated continuation after the "Step 1:" cue, with special
            tokens removed.
        """
        prompt = self._build_prompt(question, context)
        inputs = self.tokenizer(prompt, return_tensors="pt")

        generation_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
        if do_sample:
            # generate() ignores temperature unless sampling is enabled.
            generation_kwargs["temperature"] = temperature
        pad_token_id = self.tokenizer.pad_token_id
        generation_kwargs["pad_token_id"] = (
            pad_token_id if pad_token_id is not None else self.tokenizer.eos_token_id
        )

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        prompt_length = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Self-consistency for robust predictions
class SelfConsistency:
    """Majority vote over several sampled generations of the same prompt."""

    def __init__(self, model, tokenizer, n_samples=5):
        """Store the model and tokenizer and the number of samples to draw.

        Raises:
            ValueError: If ``n_samples`` is less than 1 (there would be nothing to vote on).
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        self.model = model
        self.tokenizer = tokenizer
        self.n_samples = n_samples

    @staticmethod
    def _first_line(text):
        """Normalize a generation to its first non-empty line, with no surrounding whitespace."""
        return text.strip().split("\n")[0].strip()

    @staticmethod
    def _aggregate(predictions):
        """Majority-vote ``predictions`` (a non-empty list of strings).

        On ties, the candidate seen first wins (``Counter`` preserves insertion order).
        """
        from collections import Counter

        vote_counts = Counter(predictions)
        best_prediction, best_count = vote_counts.most_common(1)[0]
        return {
            "prediction": best_prediction,
            "confidence": best_count / len(predictions),
            "all_predictions": predictions,
        }

    def predict_with_confidence(self, prompt):
        """Sample ``n_samples`` completions of ``prompt`` and return the majority answer.

        Returns:
            A dict with ``prediction`` (the most common first line),
            ``confidence`` (its share of the votes) and ``all_predictions``.
        """
        # The prompt is identical for every sample, so tokenize it once.
        inputs = self.tokenizer(prompt, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[1]
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        predictions = []
        with torch.no_grad():
            for _ in range(self.n_samples):
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=50,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=pad_token_id,
                )
                # Skip special tokens so an EOS marker does not split otherwise identical votes.
                text = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
                predictions.append(self._first_line(text))

        return self._aggregate(predictions)
Few-Shot Learning with Prototypes
Definition: Prototypical Networks
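Compute each class prototype as the mean of that class's support embeddings:
$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)$$
Classify a query by its nearest prototype, converting negative distances into probabilities with a softmax:
$$p(y = k \mid x) = \frac{\exp\left(-d(f_\phi(x), c_k)/\tau\right)}{\sum_{k'} \exp\left(-d(f_\phi(x), c_{k'})/\tau\right)}, \qquad \hat{y} = \arg\min_k d(f_\phi(x), c_k)$$
Here $f_\phi$ is the encoder, $S_k$ is the set of support examples for class $k$, $d$ is the Euclidean distance, and $\tau$ is a temperature.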
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
class PrototypicalNLPClassifier:
    """Few-shot text classifier based on prototypical networks over a frozen encoder.

    Each class is represented by the mean embedding (its "prototype") of its
    support texts. A query goes to the class whose prototype is nearest in
    Euclidean distance.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Load the encoder and tokenizer named ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Maps label -> prototype tensor; replaced on each compute_prototypes() call.
        self.prototypes = {}

    @staticmethod
    def _mean_pool(token_embeddings, attention_mask):
        """Average token embeddings over non-padding positions only.

        Args:
            token_embeddings: ``(batch, seq, dim)`` hidden states.
            attention_mask: ``(batch, seq)`` mask; 1 marks a real token, 0 padding.

        Returns:
            A ``(batch, dim)`` tensor of pooled embeddings.
        """
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def encode(self, texts):
        """Encode a list of texts into a ``(len(texts), dim)`` tensor of embeddings."""
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # A plain .mean(dim=1) would also average the padding tokens, which
        # distorts the embeddings of shorter texts in a padded batch.
        return self._mean_pool(outputs.last_hidden_state, inputs["attention_mask"])

    def compute_prototypes(self, support_set):
        """Compute one prototype per label from ``support_set`` (label -> list of texts).

        Raises:
            ValueError: If any label has no support texts.
        """
        prototypes = {}
        for label, texts in support_set.items():
            if not texts:
                raise ValueError(f"Support set for label {label!r} is empty")
            prototypes[label] = self.encode(list(texts)).mean(dim=0)
        # Replace rather than update, so labels from a previous support set do not linger.
        self.prototypes = prototypes

    def _classify_embedding(self, query_embedding, temperature=1.0):
        """Score an already-encoded query against the stored prototypes."""
        if not self.prototypes:
            raise ValueError("No prototypes available; call compute_prototypes() first")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        labels = list(self.prototypes)
        prototype_matrix = torch.stack([self.prototypes[label] for label in labels])
        distances = torch.norm(prototype_matrix - query_embedding, dim=-1)
        # softmax(-d/T) equals exp(-d/T) / sum(exp(-d/T)) but stays numerically
        # stable: the naive form underflows to 0/0 = NaN when all distances are large.
        probs = torch.softmax(-distances / temperature, dim=0)
        best = int(torch.argmax(probs))  # first maximum wins, matching max() over dict order
        return {
            "prediction": labels[best],
            "confidence": probs[best].item(),
            "probabilities": {label: p.item() for label, p in zip(labels, probs)},
        }

    def classify(self, query, temperature=1.0):
        """Classify ``query`` by its nearest prototype.

        Args:
            query: Text to classify.
            temperature: Softmax temperature applied to the negative distances.

        Returns:
            A dict with ``prediction`` (label), ``confidence`` (float) and
            ``probabilities`` (label -> float).

        Raises:
            ValueError: If no prototypes have been computed, or ``temperature <= 0``.
        """
        query_embedding = self.encode([query])[0]
        return self._classify_embedding(query_embedding, temperature)

    def few_shot_predict(self, support_set, query, k=5, temperature=1.0):
        """Run k-shot classification: build prototypes from ``support_set``, then classify ``query``.

        Uses the first ``k`` texts of each label (deterministic, not random sampling).
        """
        truncated_support = {label: texts[:k] for label, texts in support_set.items()}
        self.compute_prototypes(truncated_support)
        return self.classify(query, temperature=temperature)
# Usage: 3-shot prototypical classification of a product review.
# Three labeled support texts per class.
support_set = {
    "positive": ["This product is amazing!", "Best purchase I ever made", "Exceeded my expectations"],
    "negative": ["Terrible quality", "Complete waste of money", "Very disappointed"],
}
classifier = PrototypicalNLPClassifier()
result = classifier.few_shot_predict(support_set, "Pretty good overall")

predicted_label = result["prediction"]
predicted_confidence = result["confidence"]
print(f"Prediction: {predicted_label} ({predicted_confidence:.2%})")
Few-Shot Best Practices
| Strategy | Recommended Examples | Key Consideration |
|---|---|---|
| Diverse examples | 3-5 per class | Cover edge cases |
| Representative examples | Match distribution | Reduce bias |
| Balanced classes | Equal per class | Prevent majority bias |
| Clear formatting | Consistent structure | Reduce ambiguity |
| Task description | Always include | Set expectations |
Key Takeaways
- In-context learning enables LLMs to learn tasks without gradient updates
- Chain-of-thought prompting improves reasoning on complex tasks
- Self-consistency increases prediction robustness through sampling
- Prototypical networks provide interpretable few-shot classification
- Prompt engineering is critical for maximizing few-shot performance