LSTM and GRU
Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRU) mitigate the vanishing gradient problem of plain RNNs by introducing gating mechanisms that regulate how information flows through the network from one time step to the next.
LSTM Gates
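Forget Gate
$$f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$$
Input Gate
$$i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$$
Candidate Values
$$\tilde{C}_t = \tanh(W_C [h_{t-1}, x_t] + b_C)$$
Cell State Update
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$
Output Gate
$$o_t = \sigma(W_o [h_{t-1}, x_t] + b_o), \qquad h_t = o_t \odot \tanh(C_t)$$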
LSTM Implementation
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
    """Bidirectional, multi-layer LSTM classifier over token-id sequences.

    Token ids are embedded, passed through a bidirectional LSTM, and the
    final forward and backward hidden states of the top layer are
    concatenated and projected to class logits.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_dim: Size of each embedding vector.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        dropout: Dropout probability used on the embeddings, between
            stacked LSTM layers, and on the pooled hidden state.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int,
                 num_layers: int, num_classes: int,
                 dropout: float = 0.5) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # nn.LSTM only applies dropout *between* stacked layers. With a
        # single layer it has no effect and PyTorch emits a warning, so
        # pass 0.0 to the recurrent layer in that case.
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        # Both directions are concatenated, hence hidden_dim * 2 inputs.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of token-id sequences.

        Args:
            text: Integer tensor of token ids, laid out batch-first
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor with one row of ``num_classes`` logits per sequence.
        """
        embedded = self.dropout(self.embedding(text))
        _, (hidden, _) = self.lstm(embedded)
        # For a bidirectional LSTM the last two entries of h_n are the
        # top layer's forward and backward final hidden states.
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        hidden = self.dropout(hidden)
        return self.fc(hidden)
# Build a two-layer bidirectional classifier over a 25,000-token vocabulary
model = LSTMClassifier(
    vocab_size=25000,
    embed_dim=300,
    hidden_dim=256,
    num_layers=2,
    num_classes=2,
    dropout=0.5,
)

# Push one random batch through it: 32 sequences of 100 token ids each
batch = torch.randint(low=0, high=25000, size=(32, 100))
output = model(batch)
print(output.shape)  # torch.Size([32, 2])
GRU Architecture
The GRU simplifies the LSTM by combining the forget and input gates into a single update gate and by merging the cell state into the hidden state.
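GRU Update Gate
$$z_t = \sigma(W_z [h_{t-1}, x_t] + b_z)$$
GRU Reset Gate
$$r_t = \sigma(W_r [h_{t-1}, x_t] + b_r)$$
GRU Candidate
$$\tilde{h}_t = \tanh(W_h [r_t \odot h_{t-1}, x_t] + b_h)$$
GRU Output
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$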
GRU Implementation
class GRUClassifier(nn.Module):
    """Bidirectional, multi-layer GRU classifier over token-id sequences.

    Token ids are embedded, passed through a bidirectional GRU, and the
    final forward and backward hidden states of the top layer are
    concatenated and projected to class logits.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_dim: Size of each embedding vector.
        hidden_dim: GRU hidden size per direction.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of output logits.
        dropout: Dropout probability used on the embeddings, between
            stacked GRU layers, and on the pooled hidden state.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int,
                 num_layers: int, num_classes: int,
                 dropout: float = 0.5) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # nn.GRU only applies dropout *between* stacked layers. With a
        # single layer it has no effect and PyTorch emits a warning, so
        # pass 0.0 to the recurrent layer in that case.
        self.gru = nn.GRU(
            embed_dim, hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        # Both directions are concatenated, hence hidden_dim * 2 inputs.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of token-id sequences.

        Args:
            text: Integer tensor of token ids, laid out batch-first
                (the GRU is built with ``batch_first=True``).

        Returns:
            Tensor with one row of ``num_classes`` logits per sequence.
        """
        embedded = self.dropout(self.embedding(text))
        _, hidden = self.gru(embedded)
        # For a bidirectional GRU the last two entries of h_n are the
        # top layer's forward and backward final hidden states.
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.fc(self.dropout(hidden))
LSTM vs GRU Comparison
| Aspect | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| Parameters | More | Fewer |
| Training speed | Slower | Faster |
| Performance | Slightly better | Comparable |
| Memory usage | Higher | Lower |
| Cell state | Separate | Combined with hidden |
Bidirectional RNNs
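A bidirectional RNN processes the sequence in both directions, so the representation at each position captures context from both the preceding and the following tokens.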
# A bidirectional layer reads the sequence twice:
#   left-to-right:  "The cat sat" -> h_forward
#   right-to-left:  "sat cat The" -> h_backward
# and concatenates the two results: [h_forward; h_backward]
bilstm = nn.LSTM(
    input_size=128,
    hidden_size=256,
    batch_first=True,
    bidirectional=True,
)