Relation Extraction
Relation extraction identifies semantic relationships between entities in text. It's crucial for building knowledge graphs and understanding document content.
Relation Types
| Relation | Example | Triple (subject, relation, object) |
|---|---|---|
| Works for | "Elon Musk works for Tesla" | (Elon Musk, WorksFor, Tesla) |
| Located in | "Paris is in France" | (Paris, LocatedIn, France) |
| Founded | "Bezos founded Amazon" | (Bezos, Founded, Amazon) |
| Parent of | "Luke is the son of Vader" | (Vader, ParentOf, Luke) |
| Born in | "Einstein was born in Germany" | (Einstein, BornIn, Germany) |
BERT-based Relation Extraction
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
class BertForRelationExtraction(nn.Module):
    """Relation classifier over a pair of entity mentions, built on a BERT encoder.

    Each entity is represented by the encoder's hidden state at one token
    position per example; the pair feature ``[e1, e2, e1 * e2, e1 - e2]``
    (size ``4 * hidden_size``) is fed to a two-layer MLP that produces
    ``num_relations`` logits.
    """

    def __init__(self, model_name, num_relations):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        # Entity-marker embeddings. NOTE(review): forward() never uses these;
        # they are kept so existing state_dicts/checkpoints still load.
        self.entity_start = nn.Embedding(1, hidden_size)
        self.entity_end = nn.Embedding(1, hidden_size)
        # Relation classifier over the concatenated 4 * hidden_size pair feature.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_relations),
        )

    def forward(self, input_ids, attention_mask, entity1_pos, entity2_pos):
        """Return relation logits for each example in the batch.

        Args:
            input_ids, attention_mask: encoder inputs, passed straight to
                ``self.bert``.
            entity1_pos, entity2_pos: one token index per example (tensor or
                sequence of ints) selecting each entity's representation.

        Returns:
            Logits of shape ``(batch, num_relations)`` when the positions are 1-D.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # (batch, seq_len, hidden)

        e1_repr = self.extract_entity(sequence_output, entity1_pos)
        e2_repr = self.extract_entity(sequence_output, entity2_pos)

        # Both entities plus their element-wise product and difference.
        combined = torch.cat(
            [e1_repr, e2_repr, e1_repr * e2_repr, e1_repr - e2_repr],
            dim=-1,
        )
        return self.classifier(combined)

    def extract_entity(self, sequence_output, positions):
        """Gather ``sequence_output[b, positions[b]]`` for every batch row ``b``.

        ``positions`` may be a tensor on any device or a Python sequence of
        ints; it is converted to a long tensor on ``sequence_output``'s device.
        """
        device = sequence_output.device
        positions = torch.as_tensor(positions, dtype=torch.long, device=device)
        # Build the batch index on the data's device too, so advanced indexing
        # never mixes index tensors from different devices.
        batch_index = torch.arange(sequence_output.size(0), device=device)
        return sequence_output[batch_index, positions]
# Usage
model = BertForRelationExtraction("bert-base-uncased", num_relations=20)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# NOTE: the classifier head is freshly initialised, so predictions are only
# meaningful after fine-tuning on labelled relation data.
model.eval()  # disable dropout so inference is deterministic

text = "Elon Musk founded SpaceX in 2002"
inputs = tokenizer(text, return_tensors="pt")
# Derive each entity's first sub-token index from its character offset instead
# of hard-coding it: [CLS] occupies index 0 and WordPiece may split words.
entity1_pos = torch.tensor([inputs.char_to_token(text.index("Elon"))])
entity2_pos = torch.tensor([inputs.char_to_token(text.index("SpaceX"))])

with torch.no_grad():
    logits = model(inputs["input_ids"], inputs["attention_mask"],
                   entity1_pos, entity2_pos)
pred_relation = torch.argmax(logits, dim=-1)
Open Information Extraction
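Open IE extracts relation triples without a predefined schema. The example below uses REBEL, a sequence-to-sequence model that generates (subject, relation, object) triples directly from raw text. Note that REBEL's relation labels come from the Wikidata relation types it was trained on, so it approximates schema-free Open IE rather than fully implementing it.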
from transformers import pipeline
class OpenIEExtractor:
    """Extract (subject, relation, object) triples with the REBEL seq2seq model.

    REBEL reads the raw sentence and generates a linearized sequence of the
    form ``<triplet> head <subj> tail <obj> relation``; ``parse_output`` turns
    that sequence back into dicts.
    """

    def __init__(self):
        self.pipe = pipeline(
            "text2text-generation",
            model="Babelscape/rebel-large",
        )

    def extract_relations(self, text):
        """Return the triples REBEL extracts from ``text``.

        Returns:
            List of ``{"subject", "relation", "object"}`` dicts.
        """
        # The triple markers are special tokens, which are dropped from the
        # pipeline's "generated_text"; request the generated token ids and
        # decode them with the tokenizer so the markers are preserved.
        output = self.pipe(
            text,
            num_beams=3,
            max_length=256,
            return_tensors=True,
            return_text=False,
        )
        decoded = self.pipe.tokenizer.batch_decode(
            [output[0]["generated_token_ids"]]
        )[0]
        return self.parse_output(decoded)

    def parse_output(self, text):
        """Parse REBEL's linearized output into triple dicts.

        The expected format is ``<triplet> head <subj> tail <obj> relation``;
        one head may be followed by further ``<subj> tail <obj> relation``
        groups. Sequence-control tokens (``<s>``, ``</s>``, ``<pad>``) are
        ignored.

        Returns:
            List of ``{"subject", "relation", "object"}`` dicts in output
            order; groups missing a head, tail or relation are dropped.
        """
        for control_token in ("<s>", "</s>", "<pad>"):
            text = text.replace(control_token, " ")

        relations = []
        head, tail, relation = [], [], []
        current = None  # the word list currently being filled

        def emit():
            # Record the group collected so far, if it is complete.
            if head and tail and relation:
                relations.append({
                    "subject": " ".join(head),
                    "relation": " ".join(relation),
                    "object": " ".join(tail),
                })

        for token in text.split():
            if token == "<triplet>":
                emit()
                head, tail, relation = [], [], []
                current = head
            elif token == "<subj>":
                # A new (tail, relation) group for the same head.
                emit()
                tail, relation = [], []
                current = tail
            elif token == "<obj>":
                current = relation
            elif current is not None:
                current.append(token)
        emit()
        return relations
# Usage
extractor = OpenIEExtractor()
relations = extractor.extract_relations(
    "Apple was founded by Steve Jobs in California in 1976."
)
# Print each extracted triple as "(subject, relation, object)".
for r in relations:
    print(f"({r['subject']}, {r['relation']}, {r['object']})")
Relation Classification Datasets
| Dataset | Relations | Domain | Size |
|---|---|---|---|
| TACRED | 42 | News/Wikipedia | 106,264 |
| SemEval-2010 Task 8 | 9 (+ Other) | Web text | 10,717 |
| FewRel | 100 (80 public) | Wikipedia | 70,000 (56,000 public) |
| DocRED | 96 | Wikipedia documents | 5,053 documents (56,354 relational facts) |
| NYT-Freebase | 24 | News | 45,907 |
Evaluation Metrics
| Metric | Description | Range |
|---|---|---|
| Micro F1 | Per-triple evaluation | 0-100 |
| Macro F1 | Per-relation evaluation | 0-100 |
| Precision | Correct relations / predicted | 0-100 |
| Recall | Correct relations / actual | 0-100 |
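Micro F1 score: $F_1^{\text{micro}} = \dfrac{2PR}{P + R}$, where $P$ and $R$ are precision and recall computed over all predicted and gold triples pooled across relation types. Macro F1 instead averages the per-relation F1 scores.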
Distant Supervision
def distant_supervision(kb_facts, corpus, whole_word=True):
    """Generate (noisy) training data from a knowledge base.

    Every sentence that mentions both the subject and the object of a KB fact
    is labelled with that fact's relation. The labels are therefore noisy: a
    sentence may mention both entities without expressing the relation.

    Args:
        kb_facts: iterable of ``(subject, relation, object)`` tuples.
        corpus: iterable of sentences (str). It is materialized once, so a
            generator is supported.
        whole_word: when True (default), an entity only matches if it is not
            embedded in a longer word (so "Apple" does not match "Applebee's").
            Set to False for plain, case-sensitive substring matching.

    Returns:
        List of dicts with keys ``text``, ``subject``, ``object`` and
        ``relation``, ordered by fact and then by sentence order in ``corpus``.
    """
    import re

    # The corpus is scanned once per fact; materialize it so a generator is
    # not exhausted after the first fact.
    sentences = list(corpus)

    def mention_matcher(entity):
        # Return a predicate telling whether a sentence mentions `entity`.
        if not whole_word:
            return lambda sentence: entity in sentence
        # Look-arounds rather than \b so names that begin or end with non-word
        # characters (e.g. "C++") still match.
        pattern = re.compile(rf"(?<!\w){re.escape(entity)}(?!\w)")
        return lambda sentence: pattern.search(sentence) is not None

    training_data = []
    for subject, relation, obj in kb_facts:
        mentions_subject = mention_matcher(subject)
        mentions_object = mention_matcher(obj)
        for sentence in sentences:
            if mentions_subject(sentence) and mentions_object(sentence):
                training_data.append({
                    "text": sentence,
                    "subject": subject,
                    "object": obj,
                    "relation": relation,
                })
    return training_data
# Example: a tiny knowledge base and corpus.
# Only sentences naming both entities of a fact become training examples; the
# iPhone sentence names Apple but not Steve Jobs, so it yields nothing.
kb = [
    ("Apple", "FoundedBy", "Steve Jobs"),
    ("Tesla", "CEO", "Elon Musk"),
]
corpus = [
    "Steve Jobs co-founded Apple in 1976.",
    "Apple released the iPhone in 2007.",
    "Elon Musk is the CEO of Tesla.",
]
training_data = distant_supervision(kb, corpus)
# Prints "Apple --FoundedBy--> Steve Jobs" and "Tesla --CEO--> Elon Musk".
for d in training_data:
    print("{} --{}--> {}".format(d["subject"], d["relation"], d["object"]))
Few-Shot Relation Extraction
class FewShotRelationExtractor:
    """Few-shot relation classifier based on prototype matching.

    Each support sentence is embedded with a sentence-transformers model. A
    relation's prototype is the mean of its support embeddings, and a query is
    assigned the relation whose prototype has the highest cosine similarity.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name)
        # relation -> list of raw support embeddings
        self.prototypes = {}
        # relation -> mean support embedding; filled by compute_prototypes()
        self.relation_prototypes = {}
        # True when support examples were added after the last compute_prototypes()
        self._prototypes_stale = False

    def add_support_example(self, text, relation):
        """Embed ``text`` and store it as a support example for ``relation``."""
        embedding = self.model.encode(text)
        self.prototypes.setdefault(relation, []).append(embedding)
        self._prototypes_stale = True

    def compute_prototypes(self):
        """Compute the mean support embedding (prototype) of each relation."""
        import numpy as np

        self.relation_prototypes = {
            relation: np.mean(embeddings, axis=0)
            for relation, embeddings in self.prototypes.items()
        }
        self._prototypes_stale = False

    @staticmethod
    def _cosine_similarity(a, b):
        """Return the cosine similarity of two 1-D vectors (0.0 if either is zero)."""
        import numpy as np

        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def predict(self, text):
        """Predict the relation of ``text`` by nearest prototype.

        Prototypes are (re)computed automatically if they have never been
        computed or if support examples were added since the last call to
        ``compute_prototypes()``.

        Returns:
            ``{"relation": name, "score": cosine_similarity}``, or
            ``{"relation": None, "score": -1}`` when there are no support
            examples.
        """
        if self._prototypes_stale or not self.relation_prototypes:
            self.compute_prototypes()
        query_embedding = self.model.encode(text)
        best_relation = None
        best_score = -1
        for relation, prototype in self.relation_prototypes.items():
            score = self._cosine_similarity(query_embedding, prototype)
            # The `is None` check lets a similarity of exactly -1 still pick
            # a relation instead of returning None.
            if best_relation is None or score > best_score:
                best_relation, best_score = relation, score
        return {"relation": best_relation, "score": best_score}
# Usage
extractor = FewShotRelationExtractor()
# Add support examples for two relations so prediction is a real choice between
# prototypes (with a single relation, every query would trivially match it).
extractor.add_support_example("Apple was founded by Steve Jobs", "FoundedBy")
extractor.add_support_example("Steve Jobs created Apple", "FoundedBy")
extractor.add_support_example("Jeff Bezos founded Amazon", "FoundedBy")
extractor.add_support_example("Tim Cook is the CEO of Apple", "CEO")
extractor.add_support_example("Satya Nadella leads Microsoft as chief executive", "CEO")
extractor.compute_prototypes()
# Predict
result = extractor.predict("Bill Gates co-founded Microsoft")
print(result)  # e.g. {'relation': 'FoundedBy', 'score': 0.87}; exact score depends on the model
Relation Extraction Pipeline
Relation extraction is typically formulated as a multi-class classification problem. However, distant supervision and open IE approaches allow for more scalable relation extraction without manual annotation.