Recurrent Neural Networks (RNNs)
RNNs are neural networks designed for sequential data. Unlike feedforward networks, RNNs maintain a hidden state that captures information from previous time steps, making them naturally suited for NLP tasks like language modeling, text generation, and machine translation.
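RNN Hidden State
At each time step $t$, the hidden state combines the current input $x_t$ with the previous hidden state $h_{t-1}$:
$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$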
RNN Output
The output at each step is a linear projection of the hidden state:
$$y_t = W_{hy} h_t + b_y$$
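This minimal sketch unrolls the recurrence by hand and compares the result with PyTorch's `nn.RNN`. According to the PyTorch documentation, `nn.RNN` with its default `tanh` nonlinearity computes $h_t = \tanh(x_t W_{ih}^\top + b_{ih} + h_{t-1} W_{hh}^\top + b_{hh})$. This is the hidden-state update above, with the input weights named `W_ih` and the bias split into two terms. The layer sizes and random input are arbitrary choices for illustration. The output projection $y_t$ is left out because, in the examples below, it is a separate `nn.Linear` layer (`self.fc`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes
batch, seq_len, input_dim, hidden_dim = 2, 5, 3, 4
rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
x = torch.randn(batch, seq_len, input_dim)


def manual_rnn(x, rnn):
    """Unroll h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh) step by step."""
    W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
    b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0
    h = torch.zeros(x.size(0), W_hh.size(0))  # h_0 = 0, PyTorch's default
    outputs = []
    for t in range(x.size(1)):
        h = torch.tanh(x[:, t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        outputs.append(h)
    # (batch, seq_len, hidden_dim) and (batch, hidden_dim)
    return torch.stack(outputs, dim=1), h


with torch.no_grad():
    expected_out, expected_h = rnn(x)        # expected_h: (1, batch, hidden_dim)
    manual_out, manual_h = manual_rnn(x, rnn)

print(torch.allclose(manual_out, expected_out, atol=1e-5))  # True
print(torch.allclose(manual_h, expected_h[0], atol=1e-5))   # True
```

The hand-written loop is for illustration only. `nn.RNN` performs the same computation in optimized kernels.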
Vanishing and Exploding Gradients
The fundamental challenge with vanilla RNNs is that gradients can vanish or explode during backpropagation through time (BPTT).
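Gradient Through Time
Backpropagating from step $t$ to an earlier step $k$ multiplies $t-k$ Jacobians, and each of them contains $W_{hh}$:
$$\frac{\partial h_t}{\partial h_k} = \prod_{i=k+1}^{t} \frac{\partial h_i}{\partial h_{i-1}} = \prod_{i=k+1}^{t} \operatorname{diag}\left(1 - h_i^{2}\right) W_{hh}$$
Here $h_i^2$ is taken elementwise, and $1 - h_i^2$ is the derivative of $\tanh$. As $t-k$ grows, this product shrinks or grows geometrically, depending on the magnitude of $W_{hh}$'s eigenvalues.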
| Problem | Cause | Effect | Solution |
|---|---|---|---|
| Vanishing gradient | Largest eigenvalue magnitude of W_hh below 1 (tanh' ≤ 1 only shrinks it further) | Cannot learn long-range dependencies | LSTMs, GRUs |
| Exploding gradient | Largest eigenvalue magnitude of W_hh above 1 | Unstable training (loss spikes, NaNs) | Gradient clipping |
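The eigenvalue explanation in the table can be checked numerically. The sketch below makes two simplifying assumptions so that the answer is exact:

- It uses a linear recurrence $h_t = W h_{t-1}$, dropping $\tanh$ and the input.
- It sets $W = sQ$, where $Q$ is a random orthogonal matrix, so every eigenvalue of $W$ has magnitude $s$.

Backpropagating through $T$ steps then scales the gradient by exactly $s^T$.

```python
import torch

torch.manual_seed(0)


def grad_norm_through_time(scale, steps=50, n=16):
    """Norm of d(sum h_T)/d h_0 for the linear recurrence h_t = W h_{t-1}.

    W = scale * Q with Q orthogonal, so every eigenvalue of W has magnitude
    `scale` and the gradient norm equals scale**steps * sqrt(n).
    """
    Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
    W = scale * Q
    h0 = torch.randn(n, dtype=torch.float64, requires_grad=True)
    h = h0
    for _ in range(steps):
        h = W @ h  # one time step of the recurrence
    h.sum().backward()
    return h0.grad.norm().item()


for s in (0.5, 1.0, 1.5):
    print(f"scale={s}: grad norm = {grad_norm_through_time(s):.3e}")
```

With 50 steps and n = 16, the printed norms equal $4 s^{50}$:

- about 3.6e-15 for s = 0.5,
- exactly 4 for s = 1.0,
- about 2.6e9 for s = 1.5.

The nonlinear case is messier because each step also multiplies by the tanh derivative. Since that derivative never exceeds 1, it can only add to the shrinkage.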
Basic RNN Implementation
```python
import torch
import torch.nn as nn


class SimpleRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text, hidden=None):
        embedded = self.embedding(text)  # (batch, seq_len, embed_dim)
        # output: hidden state at every step; hidden: final step only
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc(output)  # (batch, seq_len, output_dim)
        return prediction, hidden


# Initialize model
model = SimpleRNN(vocab_size=10000, embed_dim=128,
                  hidden_dim=256, output_dim=10000)

# Forward pass
batch = torch.randint(0, 10000, (32, 50))  # (batch, seq_len)
output, hidden = model(batch)
print(output.shape)  # torch.Size([32, 50, 10000])
print(hidden.shape)  # torch.Size([1, 32, 256]) = (num_layers, batch, hidden_dim)
```
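Because `forward` accepts an initial `hidden` and returns the final one, a long sequence can be processed in chunks. Each chunk's last hidden state is passed into the next call. The minimal sketch below checks that chunked processing reproduces a single full pass. So that it runs on its own, it uses a bare `nn.RNN` with arbitrary sizes in place of `SimpleRNN`'s recurrent layer. The `hidden.detach()` comment refers to truncated BPTT, a common training practice that is not part of the original example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A bare nn.RNN stands in for SimpleRNN's recurrent layer (sizes are arbitrary)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 30, 8)  # (batch, seq_len, input_size)

with torch.no_grad():
    # One pass over the full sequence
    full_out, full_h = rnn(x)

    # The same sequence in chunks of 10 steps, carrying the hidden state forward
    hidden = None
    chunk_outs = []
    for chunk in x.split(10, dim=1):
        out, hidden = rnn(chunk, hidden)
        # In truncated BPTT training you would write hidden = hidden.detach()
        # here so that backpropagation stops at the chunk boundary.
        chunk_outs.append(out)
    chunked_out = torch.cat(chunk_outs, dim=1)

print(torch.allclose(full_out, chunked_out, atol=1e-5))  # True
print(torch.allclose(full_h, hidden, atol=1e-5))         # True
```

Stateful processing of streams or very long documents is one practical reason for a `forward` signature that takes `hidden`.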
RNN for Text Classification
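```python
import torch.nn as nn


class RNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, text):
        embedded = self.embedding(text)  # (batch, seq_len, embed_dim)
        output, hidden = self.rnn(embedded)  # hidden: (1, batch, hidden_dim)
        # Use the last hidden state as a summary of the whole sequence.
        # squeeze(0) drops the layer dimension, which assumes a single-layer,
        # unidirectional RNN.
        hidden = hidden.squeeze(0)  # (batch, hidden_dim)
        return self.fc(hidden)  # (batch, num_classes) logits
```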
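Classifying from the last hidden state has one caveat. Batches are usually right-padded to a common length, and `nn.RNN` keeps updating the hidden state over the padding. The final `hidden` then no longer summarizes the real last token. PyTorch's `pack_padded_sequence` makes the RNN stop at each sequence's true length.

The sketch below assumes a padding index of 0 and uses toy sizes and token ids. For the shorter sequence, it compares three final hidden states: packed, naive padded, and unpadded.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

PAD = 0  # assumed padding index
embedding = nn.Embedding(100, 8, padding_idx=PAD)
rnn = nn.RNN(8, 16, batch_first=True)

# Two right-padded sequences with true lengths 5 and 3 (toy token ids)
text = torch.tensor([[5, 6, 7, 8, 9],
                     [3, 4, 2, PAD, PAD]])
lengths = torch.tensor([5, 3])  # must be on the CPU

with torch.no_grad():
    embedded = embedding(text)
    packed = pack_padded_sequence(embedded, lengths, batch_first=True,
                                  enforce_sorted=False)
    _, h_packed = rnn(packed)            # (1, batch, 16), original batch order
    _, h_padded = rnn(embedded)          # naive: also runs over the padding
    _, h_alone = rnn(embedded[1:, :3])   # second sequence without padding

print(torch.allclose(h_packed[0, 1], h_alone[0, 0], atol=1e-5))  # True
print(torch.allclose(h_padded[0, 1], h_alone[0, 0], atol=1e-5))  # False
```

Adapting `RNNClassifier` would mean two changes:

- give `forward` a hypothetical `lengths` argument;
- pass the packed sequence to `self.rnn`.

With `enforce_sorted=False`, the returned hidden state is already in the original batch order, so `hidden.squeeze(0)` still works.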
RNN for Language Modeling
```python
import torch
import torch.nn as nn


class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, text, hidden=None):
        embedded = self.embedding(text)
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc(output)  # (batch, seq_len, vocab_size) logits
        return prediction, hidden

    def generate(self, start_token, max_len=50):
        """Greedy decoding: always pick the most likely next token."""
        self.eval()
        device = next(self.parameters()).device  # works on CPU or GPU
        tokens = [start_token]
        hidden = None
        with torch.no_grad():
            for _ in range(max_len):
                # Feed only the newest token; the hidden state carries the history
                x = torch.tensor([[tokens[-1]]], device=device)  # (1, 1)
                output, hidden = self(x, hidden)
                prob = torch.softmax(output[:, -1], dim=-1)
                next_token = torch.argmax(prob, dim=-1).item()
                tokens.append(next_token)
        return tokens
```
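Training `RNNLM` uses next-token prediction. The input is a sequence without its last token, and the target is the same sequence shifted left by one. The sketch below runs a few steps of that loop on one random batch. It applies gradient clipping with `torch.nn.utils.clip_grad_norm_`, the remedy for exploding gradients listed in the table above.

To stay self-contained, it rebuilds the same embedding, RNN, and linear layout as a small `TinyLM` class (a hypothetical name). The sizes, Adam learning rate, and `max_norm=1.0` are illustrative assumptions, not tuned values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embed_dim, hidden_dim = 50, 16, 32  # illustrative sizes


class TinyLM(nn.Module):
    """Hypothetical stand-in with the same layout as RNNLM."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, text, hidden=None):
        output, hidden = self.rnn(self.embedding(text), hidden)
        return self.fc(output), hidden


model = TinyLM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

batch = torch.randint(0, vocab_size, (8, 21))  # (batch, seq_len + 1)
inputs, targets = batch[:, :-1], batch[:, 1:]  # predict token t+1 from tokens <= t


def train_step():
    model.train()
    optimizer.zero_grad()
    logits, _ = model(inputs)  # (batch, seq_len, vocab_size)
    # CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten time
    loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))
    loss.backward()
    # Rescale all gradients so their combined L2 norm is at most 1.0
    total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item(), total_norm.item()


losses = [train_step()[0] for _ in range(50)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

On a single fixed batch, the loss should fall from roughly ln(50) ≈ 3.9 because the model memorizes the batch. Real training would iterate over many batches. `clip_grad_norm_` returns the norm measured before clipping, which is worth logging when diagnosing exploding gradients.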
RNN Variants
| Variant | Description | Advantage |
|---|---|---|
| Vanilla RNN | Basic recurrent structure | Simple, fast |
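| LSTM | Long Short-Term Memory: input, forget, and output gates control a separate cell state | Learns long-range dependencies far better than a vanilla RNN |
| GRU | Gated Recurrent Unit: update and reset gates, no separate cell state | Fewer parameters than an LSTM, often comparable accuracy |
| Bidirectional | One RNN reads left to right, another right to left; their outputs are concatenated | Each position sees both past and future context |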
| Deep RNN | Multiple stacked layers | More capacity |
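In PyTorch, each variant in the table corresponds to a constructor choice:

- `nn.LSTM` and `nn.GRU` for the gated cells;
- `bidirectional=True` for reading in both directions;
- `num_layers` for stacking.

The sketch below uses arbitrary sizes matching the earlier examples (embedding 128, hidden 256). It prints each variant's parameter count and output shapes.

```python
import torch
import torch.nn as nn

embed_dim, hidden_dim = 128, 256
x = torch.randn(32, 50, embed_dim)  # (batch, seq_len, embed_dim)

variants = {
    "Vanilla RNN": nn.RNN(embed_dim, hidden_dim, batch_first=True),
    "LSTM": nn.LSTM(embed_dim, hidden_dim, batch_first=True),
    "GRU": nn.GRU(embed_dim, hidden_dim, batch_first=True),
    "Bidirectional": nn.RNN(embed_dim, hidden_dim, batch_first=True,
                            bidirectional=True),
    "Deep RNN": nn.RNN(embed_dim, hidden_dim, batch_first=True, num_layers=2),
}

for name, layer in variants.items():
    n_params = sum(p.numel() for p in layer.parameters())
    output, hidden = layer(x)
    if isinstance(hidden, tuple):  # LSTM returns (h_n, c_n)
        hidden = hidden[0]
    print(f"{name:13s} params={n_params:7d} "
          f"output={tuple(output.shape)} h_n={tuple(hidden.shape)}")
```

The expected parameter counts are:

- vanilla RNN: 98,816;
- GRU: 296,448 (three weight sets);
- LSTM: 395,264 (four weight sets);
- two-layer stack: 230,400.

The GRU's smaller count is the concrete sense in which it is simpler than the LSTM.

The bidirectional output's last dimension is 512 because the two directions are concatenated, and its `h_n` has shape (2, 32, 256). Because the backward direction reads future tokens, a bidirectional RNN suits classification or tagging but not left-to-right generation.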