Sequence-to-Sequence Models
Sequence-to-sequence (seq2seq) models map input sequences to output sequences of potentially different lengths. They're foundational for machine translation, text summarization, dialogue systems, and code generation.
How Seq2Seq Works
- Encoder reads input token-by-token and produces hidden states
- Context vector (final encoder hidden state) summarizes the input
- Decoder generates output tokens using the context vector
- Each decoder step uses previous output as input for next step
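Encoder Hidden State
$$h_t = f(x_t, h_{t-1}), \qquad c = h_T$$
where $f$ is the encoder RNN cell (an LSTM in the code below), $x_t$ is the $t$-th input token embedding and $T$ is the input length.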
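Decoder Generation
$$s_t = g(y_{t-1}, s_{t-1}), \quad s_0 = c, \qquad P(y_t \mid y_{<t}, x) = \operatorname{softmax}(W_o s_t)$$
where $g$ is the decoder RNN cell and $y_{t-1}$ is the previously emitted (or, under teacher forcing, ground-truth) token.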
Teacher Forcing
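During training, the decoder is fed the actual previous token (ground truth) instead of its own prediction. In the implementation below this happens with probability `teacher_forcing_ratio` at each time step; otherwise the decoder's own argmax prediction is fed back.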
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over the target length.

    Args:
        encoder: Callable mapping ``src`` to an initial ``(hidden, cell)`` pair.
        decoder: Callable mapping ``(token, hidden, cell)`` to
            ``(logits, hidden, cell)``; it must expose ``output_dim`` (the
            target vocabulary size), which sizes the output buffer.
        device: Device on which the output buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Run the model over a batch and return per-step vocabulary logits.

        Args:
            src: Source batch; its first dimension is taken as the batch size.
            trg: Target token ids indexed as ``trg[:, t]``; ``trg[:, 0]`` is the
                first decoder input (presumably the ``<sos>`` token).
            teacher_forcing_ratio: Probability, drawn once per time step for the
                whole batch, of feeding the ground-truth token ``trg[:, t]``
                instead of the decoder's own argmax prediction.

        Returns:
            Tensor of shape ``(batch, trg_len, decoder.output_dim)``. Position 0
            is left as zeros because no prediction is made for the start token.
        """
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        # Allocate directly on the target device rather than on the CPU plus a copy.
        outputs = torch.zeros(
            batch_size, trg_len, self.decoder.output_dim, device=self.device
        )

        hidden, cell = self.encoder(src)
        dec_input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(dec_input, hidden, cell)
            outputs[:, t] = output
            # One uniform draw per step; .item() turns the comparison into a plain bool.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            dec_input = trg[:, t] if teacher_force else output.argmax(1)
        return outputs
Beam Search Decoding
Beam search explores multiple candidate sequences to find a high-probability output.
def beam_search(model, src, beam_width=5, max_len=50, sos_idx=1, eos_idx=2):
    """Decode one source sequence with length-normalised beam search.

    Args:
        model: Object exposing ``encoder(src) -> (hidden, cell)`` and
            ``decoder(token, hidden, cell) -> (logits, hidden, cell)``. The
            logits may be ``(1, vocab)`` or ``(1, steps, vocab)``; in the latter
            case only the last step is used.
        src: Source tensor for a single example (each decoder call feeds a
            batch of one token).
        beam_width: Number of partial hypotheses kept after every step.
        max_len: Maximum number of decoding steps.
        sos_idx: Token id that seeds every hypothesis.
        eos_idx: Token id that finishes a hypothesis.

    Returns:
        ``(log_prob, tokens)`` for the hypothesis with the highest
        length-normalised score ``log_prob / len(tokens)``; ``tokens`` starts
        with ``sos_idx``.

    Note:
        Runs under ``torch.no_grad()``. If the model uses dropout, consider
        calling ``model.eval()`` beforehand.
    """
    with torch.no_grad():
        hidden, cell = model.encoder(src)
        # Each beam: (cumulative log-probability, token list, hidden, cell).
        beams = [(0.0, [sos_idx], hidden, cell)]
        completed = []
        for _ in range(max_len):
            all_candidates = []
            for score, seq, h, c in beams:
                if seq[-1] == eos_idx:
                    # Finished hypotheses stop expanding but remain eligible.
                    completed.append((score, seq))
                    continue
                token_in = torch.tensor([seq[-1]], device=src.device)
                output, new_h, new_c = model.decoder(token_in, h, c)
                if output.dim() == 3:
                    # (1, steps, vocab) -> logits of the most recent step.
                    output = output[:, -1]
                # Normalise over the vocabulary axis of the (1, vocab) logits.
                log_probs = torch.log_softmax(output, dim=-1)
                k = min(beam_width, log_probs.shape[-1])
                top = torch.topk(log_probs, k)
                for value, token in zip(top.values[0].tolist(), top.indices[0].tolist()):
                    all_candidates.append((score + value, seq + [token], new_h, new_c))
            # Keep the beam_width best candidates.
            beams = sorted(all_candidates, key=lambda x: x[0], reverse=True)[:beam_width]
            if not beams:
                # Every hypothesis has emitted eos_idx; nothing left to expand.
                break
        completed.extend((score, seq) for score, seq, _, _ in beams)
    return max(completed, key=lambda x: x[0] / len(x[1]))
Encoder-Decoder Implementation
class Encoder(nn.Module):
    """Bidirectional LSTM encoder that summarises a source batch as one state pair.

    The final forward and backward states of the top LSTM layer are
    concatenated and projected back to ``hidden_dim`` through a tanh,
    separately for the hidden state and the cell state. The per-token LSTM
    outputs are discarded.

    Args:
        input_dim: Source vocabulary size.
        embed_dim: Token embedding size.
        hidden_dim: LSTM hidden size per direction, and the size of the
            returned states.
        n_layers: Number of stacked bidirectional LSTM layers.
        dropout: Dropout probability, applied to the token embeddings only.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        both_directions = 2 * hidden_dim
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.LSTM(
            embed_dim,
            hidden_dim,
            n_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.fc_hidden = nn.Linear(both_directions, hidden_dim)
        self.fc_cell = nn.Linear(both_directions, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src`` and return the merged ``(hidden, cell)`` pair.

        Args:
            src: Token ids; presumably ``(batch, src_len)`` given
                ``batch_first=True`` — confirm against callers.

        Returns:
            Two tensors of shape ``(batch, hidden_dim)``, with no leading
            layer dimension.
        """
        _, (h_n, c_n) = self.rnn(self.dropout(self.embedding(src)))

        def merge(state, projection):
            # state[-2] / state[-1]: top layer's forward / backward final states.
            top_layer = torch.cat((state[-2], state[-1]), dim=1)
            return torch.tanh(projection(top_layer))

        return merge(h_n, self.fc_hidden), merge(c_n, self.fc_cell)
class Decoder(nn.Module):
    """Single-step LSTM decoder mapping one token per example to vocabulary logits.

    Args:
        output_dim: Target vocabulary size; also stored as ``self.output_dim``.
        embed_dim: Token embedding size.
        hidden_dim: LSTM hidden size.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability, applied to the token embeddings.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        # Seq2Seq.forward reads decoder.output_dim to size its output buffer.
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        """Advance the decoder by one time step.

        Args:
            input: Token ids of shape ``(batch,)``.
            hidden: LSTM hidden state, either ``(n_layers, batch, hidden_dim)``
                or ``(batch, hidden_dim)`` (the shape ``Encoder`` returns). A 2-D
                state is copied to every layer.
            cell: LSTM cell state, same accepted shapes as ``hidden``.

        Returns:
            ``(prediction, hidden, cell)``. ``prediction`` has shape
            ``(batch, output_dim)``; the states are
            ``(n_layers, batch, hidden_dim)``.
        """
        # Treat the single token as a length-1 sequence: (batch,) -> (batch, 1).
        input = input.unsqueeze(1)
        embedded = self.dropout(self.embedding(input))
        # A batched (3-D) LSTM input requires 3-D states; lift 2-D ones per layer.
        num_layers = self.rnn.num_layers
        if hidden.dim() == 2:
            hidden = hidden.unsqueeze(0).repeat(num_layers, 1, 1)
        if cell.dim() == 2:
            cell = cell.unsqueeze(0).repeat(num_layers, 1, 1)
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        prediction = self.fc_out(output.squeeze(1))
        return prediction, hidden, cell
Seq2Seq Applications
| Application | Input | Output | Key Challenge |
|---|---|---|---|
| Machine Translation | English text | French text | Word order differences |
| Text Summarization | Long document | Short summary | Maintaining key information |
| Dialogue Systems | User query | System response | Coherence, context |
| Code Generation | Natural language | Source code | Syntax correctness |
| Image Captioning | Image features | Description text | Multimodal alignment |
| Speech Recognition | Audio features | Transcript | Real-time processing |
Attention Mechanism
Compressing the entire input into a single fixed-size context vector creates a bottleneck that hurts seq2seq models, especially on long inputs. Attention lets the decoder focus on different parts of the input at each step.
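Attention Score
$$e_{tj} = v^\top \tanh\big(W_a [s_{t-1}; h_j]\big), \qquad \alpha_{tj} = \frac{\exp(e_{tj})}{\sum_{k=1}^{T} \exp(e_{tk})}$$
where $s_{t-1}$ is the previous decoder state and $h_j$ is the encoder output at source position $j$.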
Context Vector with Attention
$$c_t = \sum_{j=1}^{T} \alpha_{tj} h_j$$
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder outputs.

    Args:
        enc_dim: Feature size of each encoder output vector.
        dec_dim: Size of the decoder hidden state; also the width of the
            intermediate energy layer.
    """

    def __init__(self, enc_dim, dec_dim):
        super().__init__()
        self.attn = nn.Linear(enc_dim + dec_dim, dec_dim)
        self.v = nn.Linear(dec_dim, 1)

    def forward(self, hidden, encoder_outputs, mask=None):
        """Return attention weights over source positions.

        Args:
            hidden: Decoder state of shape ``(batch, dec_dim)``.
            encoder_outputs: Tensor of shape ``(batch, src_len, enc_dim)``.
            mask: Optional ``(batch, src_len)`` tensor. Positions where it is
                zero/False (e.g. padding) receive zero weight. ``None`` (the
                default) attends to every position.

        Returns:
            Tensor of shape ``(batch, src_len)`` whose rows sum to 1.
        """
        src_len = encoder_outputs.shape[1]
        # Broadcast the decoder state to every source position; expand() is a view, not a copy.
        hidden = hidden.unsqueeze(1).expand(-1, src_len, -1)
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))
        scores = self.v(energy).squeeze(2)
        if mask is not None:
            # Use the most negative finite value instead of -inf so a fully masked
            # row yields uniform weights rather than NaNs.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=1)
Evaluation Metrics
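BLEU Score
$$\mathrm{BLEU} = \mathrm{BP} \cdot \exp\Big(\sum_{n=1}^{N} w_n \log p_n\Big), \qquad \mathrm{BP} = \min\big(1, e^{1 - r/c}\big)$$
where $p_n$ is the modified $n$-gram precision, $w_n$ are the weights (typically $N = 4$, $w_n = 1/4$), $r$ is the reference length and $c$ is the candidate length.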
| Metric | Range | Best For |
|---|---|---|
| BLEU | 0-1 | Machine translation |
| ROUGE | 0-1 | Summarization |
| METEOR | 0-1 | Translation quality |
| CIDEr | 0-10 (CIDEr-D) | Image captioning |