LSTM Networks — Gates, Cell State and Bidirectional Architectures

LSTM Networks — Long Short-Term Memory for Sequential Data

LSTM networks solve the vanishing gradient problem of vanilla RNNs by using gating mechanisms to control information flow across time steps. The cell state acts as an information highway.

  • Three Gates Control Flow — Forget, input, and output gates decide what to discard, store, and output
  • Cell State Prevents Vanishing — Linear addition preserves gradient magnitude across hundreds of time steps
  • Bidirectional Uses Context — Processing in both directions captures past and future information

"LSTMs gave RNNs a long-term memory that doesn't fade."

See our RNN Deep Dive tutorial for the fundamentals of vanilla RNNs and why LSTMs were invented.


Why LSTMs?

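The Problem LSTMs Solve

Vanilla RNNs suffer from vanishing gradients, limiting their ability to learn long-range dependencies. Backpropagation through time multiplies one Jacobian per step, so the gradient tends to shrink exponentially with sequence length. As a rough guide, for sequences longer than ~20 steps the gradient becomes too small to update early weights effectively.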

LSTMs solve this by introducing a cell state — an information highway that allows gradients to flow through time with minimal degradation. The gates control what information is added to, kept in, or removed from this highway.
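To make the vanishing effect concrete, here is a minimal sketch using a hypothetical one-unit RNN, $h_t = \tanh(w h_{t-1} + x)$. The weight `w = 0.9` and the constant input `x = 0.1` are illustrative assumptions, not values from this lesson. By the chain rule, the gradient of the final state with respect to the initial state is the product of the per-step factors $w (1 - h_t^2)$.

```python
import math

def rnn_gradient(num_steps, w=0.9, x=0.1, h0=0.0):
    """Return d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x)."""
    h, grad = h0, 1.0
    for _ in range(num_steps):
        h = math.tanh(w * h + x)
        # Chain rule: each step multiplies the gradient by w * tanh'(z),
        # where tanh'(z) = 1 - tanh(z)^2 = 1 - h^2.
        grad *= w * (1.0 - h * h)
    return grad

for steps in (5, 20, 50):
    print(f"{steps:>3} steps: gradient = {rnn_gradient(steps):.3e}")
```

Every factor is at most 0.9 in magnitude here, so the gradient after 50 steps is bounded by $0.9^{50} \approx 0.005$, and in practice it is much smaller.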


LSTM Cell Architecture

LSTM Cell

An LSTM cell has four components:

  1. Forget Gate: Decides what to throw away from the cell state
  2. Input Gate: Decides what new information to store
  3. Cell Candidate: Creates candidate values to add
  4. Output Gate: Decides what to output based on cell state

The cell state $\mathbf{c}_t$ acts as a conveyor belt, carrying information across time steps through element-wise scaling and addition only, which mitigates vanishing gradients.


The Three Gates

Forget Gate

The forget gate decides what to discard from the cell state:

$$\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$$

Output: a vector in $(0, 1)^d$, where $d$ is the hidden dimension. Values near 0 mean "forget this"; values near 1 mean "keep this."

Here,

  • $\mathbf{f}_t$ = Forget gate activation (0 to 1 per dimension)
  • $\sigma$ = Sigmoid function (outputs 0-1)
  • $[\mathbf{h}_{t-1}, \mathbf{x}_t]$ = Concatenation of previous hidden state and current input
  • $\mathbf{W}_f, \mathbf{b}_f$ = Forget gate weights and bias
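As a small worked example of the forget gate formula, the sketch below evaluates $\mathbf{f}_t$ by hand for a hidden size of 2 and an input size of 1. All numbers (`h_prev`, `x_t`, `W_f`, `b_f`) are made up for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical values: hidden size d = 2, input size 1.
h_prev = [0.5, -0.3]     # h_{t-1}
x_t = [1.0]              # x_t
concat = h_prev + x_t    # [h_{t-1}, x_t], length d + 1 = 3

# W_f has shape (d, d + input_size) = (2, 3); b_f has length d.
W_f = [[2.0, 0.0, 1.0],
       [0.0, 1.0, -3.0]]
b_f = [0.0, 0.0]

# f_t = sigmoid(W_f [h_{t-1}, x_t] + b_f), one row of W_f per hidden unit.
f_t = [sigmoid(sum(w * v for w, v in zip(row, concat)) + b)
       for row, b in zip(W_f, b_f)]
print([round(f, 3) for f in f_t])
```

The pre-activations are 2.0 and -3.3, so the first cell-state dimension is mostly kept (gate about 0.88) while the second is mostly forgotten (gate about 0.04).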

Input Gate

The input gate decides what new information to store:

$$\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$$
$$\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$$

$\mathbf{i}_t$ controls which values to update; $\tilde{\mathbf{c}}_t$ creates candidate values to add.

Here,

  • $\mathbf{i}_t$ = Input gate activation (what to update)
  • $\tilde{\mathbf{c}}_t$ = Candidate cell state values
  • $\tanh$ = Tanh function (outputs -1 to 1)

Cell State Update

The cell state is updated by forgetting old information and adding new information:

$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$

The key insight: the previous cell state enters this update only through element-wise scaling and addition, with no weight-matrix multiplication and no squashing nonlinearity on the path from $\mathbf{c}_{t-1}$ to $\mathbf{c}_t$. As long as the forget gate stays close to 1, gradients pass through the cell state largely unattenuated.
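The update rule can be checked by hand on a tiny example. The gate activations and states below are invented to show the three typical behaviours of a cell-state dimension: keep, overwrite, and blend.

```python
# Hypothetical gate activations and states for a 3-dimensional cell.
f_t     = [1.0, 0.0, 0.5]      # forget gate: keep, erase, halve
i_t     = [0.0, 1.0, 0.5]      # input gate: ignore, write, half-write
c_prev  = [2.0, 2.0, 2.0]      # c_{t-1}
c_tilde = [-1.0, -1.0, -1.0]   # candidate values

# c_t = f_t * c_{t-1} + i_t * c_tilde, computed element by element.
c_t = [f * c + i * g for f, c, i, g in zip(f_t, c_prev, i_t, c_tilde)]
print(c_t)  # [2.0, -1.0, 0.5]
```

The first dimension carries its old value forward untouched, the second is replaced by the candidate, and the third is a mix of both.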

Output Gate

The output gate decides what to output from the cell state:

$$\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$$
$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$

The hidden state is a filtered version of the cell state.

Here,

  • $\mathbf{o}_t$ = Output gate activation (what to output)
  • $\mathbf{h}_t$ = Final hidden state
  • $\mathbf{c}_t$ = Updated cell state
  • $\odot$ = Element-wise multiplication

Complete LSTM Equations

Complete LSTM Forward Pass

At each time step $t$:

  1. $\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$ (forget gate)
  2. $\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$ (input gate)
  3. $\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$ (cell candidate)
  4. $\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$ (cell state update)
  5. $\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$ (output gate)
  6. $\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$ (hidden state)

[Figure: LSTM Cell Architecture. The cell state (c_t) runs along the top as the information highway. The forget gate, input gate, cell candidate, and output gate each read [h, x]; the cell takes x_t as input and produces the hidden state h_t as its filtered output.]

How this diagram works: This diagram shows the internal architecture of a single LSTM cell. The cell state (blue bar at top) acts as an information highway, flowing across time steps with only element-wise interactions (multiplication and addition). The three gates, forget (red), input (green), and output (purple), are sigmoid layers that output values between 0 and 1, acting as soft switches. The forget gate controls what to discard from the cell state, the input gate controls what new information to store, and the output gate controls what to reveal. The critical insight is that the cell state update is additive: $\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$. Because gradients flow along this additive path rather than through repeated weight-matrix multiplications, they can be preserved across many time steps, which addresses the vanishing gradient problem that plagues vanilla RNNs.
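The six equations of the forward pass can be sketched as a single function. This is a minimal illustration that follows the lesson's notation (one weight matrix per gate applied to the concatenation $[\mathbf{h}_{t-1}, \mathbf{x}_t]$); the function name `lstm_step`, the `params` dictionary, and the random weights are assumptions made for the example. It is not how `nn.LSTM` lays out its parameters internally.

```python
import torch

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step, written to mirror the six forward-pass equations.

    params is a hypothetical dict holding W_f, W_i, W_c, W_o of shape
    (hidden, hidden + input) and b_f, b_i, b_c, b_o of shape (hidden,).
    """
    hx = torch.cat([h_prev, x_t], dim=-1)                         # [h_{t-1}, x_t]
    f_t = torch.sigmoid(hx @ params["W_f"].T + params["b_f"])     # 1. forget gate
    i_t = torch.sigmoid(hx @ params["W_i"].T + params["b_i"])     # 2. input gate
    c_tilde = torch.tanh(hx @ params["W_c"].T + params["b_c"])    # 3. cell candidate
    c_t = f_t * c_prev + i_t * c_tilde                            # 4. cell state update
    o_t = torch.sigmoid(hx @ params["W_o"].T + params["b_o"])     # 5. output gate
    h_t = o_t * torch.tanh(c_t)                                   # 6. hidden state
    return h_t, c_t

torch.manual_seed(0)
batch, input_size, hidden = 4, 3, 5
params = {}
for gate in "fico":
    params[f"W_{gate}"] = torch.randn(hidden, hidden + input_size) * 0.1
    params[f"b_{gate}"] = torch.zeros(hidden)

# Run the step over a short random sequence, starting from zero states.
h = torch.zeros(batch, hidden)
c = torch.zeros(batch, hidden)
for x_t in torch.randn(6, batch, input_size):
    h, c = lstm_step(x_t, h, c, params)

print(h.shape, c.shape)
```

Because $\mathbf{h}_t$ is a sigmoid output times a tanh output, every entry of `h` stays strictly inside (-1, 1), while `c` is not bounded in the same way.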


Why LSTMs Solve Vanishing Gradients

LSTM Gradient Flow

The cell state gradient through time satisfies:

$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \text{diag}(\mathbf{f}_t) + \text{additional terms}$$

The additional terms arise because the gates and the candidate at step $t$ depend on $\mathbf{h}_{t-1}$, which is itself a function of $\mathbf{c}_{t-1}$. When the forget gate $\mathbf{f}_t \approx 1$, the direct term is approximately the identity matrix, preserving gradient magnitude across time steps. The forget gate acts as a gradient highway: when it learns to keep values (output near 1), gradients flow along the cell state nearly unattenuated.
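A quick calculation shows how strongly the direct term depends on the forget gate. The sketch below makes the simplifying assumption that the forget gate takes the same scalar value at every step and ignores the indirect terms through the gates, so the gradient along the cell-state path is just that value raised to the number of steps. The helper name `direct_path_gradient` is hypothetical.

```python
def direct_path_gradient(forget_value, num_steps):
    """Gradient of c_T w.r.t. c_0 along the direct cell-state path,
    assuming a constant forget gate value at every step."""
    return forget_value ** num_steps

for f in (0.5, 0.9, 0.99):
    print(f"f = {f}: after 100 steps -> {direct_path_gradient(f, 100):.3e}")
```

With a gate value of 0.99 about 37% of the gradient survives 100 steps, whereas a value of 0.5 leaves essentially nothing. The cell state only helps when the network learns to hold the forget gate near 1.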

Forget Gate Bias Initialization

Initialize the forget gate bias to 1.0 (or larger, e.g., 2.0) instead of 0. When the gate's weighted input is near zero, as it typically is at initialization, a bias of 1 gives $\sigma(1) \approx 0.73$, encouraging the network to initially keep most information. With a bias of 0, $\sigma(0) = 0.5$ halves the cell state at every step, which causes aggressive forgetting early in training.
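The effect of the bias can be tabulated directly. This sketch assumes the weighted input to the forget gate is roughly zero at initialization, so the gate value is set by the bias alone; the 10-step horizon is an arbitrary choice for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Assumption: the weighted input is ~0 at initialization, so gate = sigmoid(bias).
for bias in (0.0, 1.0, 2.0):
    gate = sigmoid(bias)
    kept_after_10 = gate ** 10  # fraction of a stored value left after 10 steps
    print(f"bias = {bias}: gate = {gate:.3f}, kept after 10 steps = {kept_after_10:.4f}")
```

Moving the bias from 0 to 2 raises the fraction retained after 10 steps from about 0.1% to roughly 28%, which is why larger initial biases are sometimes preferred for long sequences.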


Bidirectional LSTM

A bidirectional LSTM processes the sequence in both directions:

$$\overrightarrow{\mathbf{h}_t} = \overrightarrow{\text{LSTM}}(\mathbf{x}_t, \overrightarrow{\mathbf{h}_{t-1}})$$
$$\overleftarrow{\mathbf{h}_t} = \overleftarrow{\text{LSTM}}(\mathbf{x}_t, \overleftarrow{\mathbf{h}_{t+1}})$$
$$\mathbf{h}_t = [\overrightarrow{\mathbf{h}_t}; \overleftarrow{\mathbf{h}_t}] \in \mathbb{R}^{2d}$$

Each output is the concatenation of forward and backward hidden states. This allows the model to use both past and future context, critical for tasks like NER, POS tagging, and machine translation.

Here,

  • $\overrightarrow{\mathbf{h}_t}$ = Forward hidden state
  • $\overleftarrow{\mathbf{h}_t}$ = Backward hidden state
  • $d$ = Hidden dimension (output is $2d$)
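The concatenation can be inspected directly in PyTorch. The sketch below uses a single-layer bidirectional `nn.LSTM` with illustrative sizes and checks where the final state of each direction sits inside the output tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 8  # hidden dimension per direction (illustrative)
bilstm = nn.LSTM(input_size=4, hidden_size=d, batch_first=True, bidirectional=True)

x = torch.randn(2, 10, 4)            # (batch, seq_len, features)
output, (h_n, c_n) = bilstm(x)

print(output.shape)  # (batch, seq_len, 2 * d)
print(h_n.shape)     # (num_directions, batch, d)

# The forward half of the output at the last step is the forward final state;
# the backward half at the first step is the backward final state.
forward_matches = torch.allclose(output[:, -1, :d], h_n[0])
backward_matches = torch.allclose(output[:, 0, d:], h_n[1])
print(forward_matches, backward_matches)
```

One consequence: the backward half of `output[:, -1, :]` has seen only the last token. A sequence-level summary from a bidirectional LSTM is usually built from `h_n` (both directions) rather than from the last time step of `output` alone.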

When to Use Bidirectional

  • Use bidirectional when you have the full sequence available (NER, sentiment analysis, machine translation encoder)
  • Use unidirectional when you must predict online (time series forecasting, streaming data, decoder)
  • Bidirectional doubles parameters and computation

Stacked LSTM

Stack multiple LSTM layers by feeding the output of one layer as input to the next:

$$\mathbf{h}_t^{(l)} = \text{LSTM}^{(l)}(\mathbf{h}_t^{(l-1)}, \mathbf{h}_{t-1}^{(l)})$$

where $\mathbf{h}_t^{(0)} = \mathbf{x}_t$. Deeper LSTMs can learn more abstract representations, but are harder to train and prone to overfitting.
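Stacking can be written out by hand or delegated to the `num_layers` argument of `nn.LSTM`. The sketch below shows both with illustrative sizes; the two versions have separately initialized weights, so only their shapes are compared, not their values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 10, 4)  # (batch, seq_len, features)

# Manual stacking: layer 2 consumes the full output sequence of layer 1.
layer1 = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
layer2 = nn.LSTM(input_size=8, hidden_size=8, batch_first=True)
out1, _ = layer1(x)        # h_t^(1) for every t
out2, _ = layer2(out1)     # h_t^(2) for every t

# Built-in stacking with the same structure.
stacked = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
out, (h_all, c_all) = stacked(x)

print(out2.shape, out.shape)  # both (batch, seq_len, hidden)
print(h_all.shape)            # (num_layers, batch, hidden)

# `out` holds the top layer only, so its last step equals the top layer's final state.
top_matches = torch.allclose(out[:, -1], h_all[-1])
print(top_matches)
```

Note that `out` exposes only the top layer's hidden states, while `h_all` holds the final hidden state of every layer.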


PyTorch Implementation

Example: Complete LSTM in PyTorch

import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim,
                 output_dim, num_layers=2, bidirectional=True, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0
        )
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(lstm_output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

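        # Initialize forget gate bias to 1.
        # PyTorch packs gate parameters in the order (input, forget, cell, output),
        # so the forget gate occupies the second quarter of each bias vector.
        # Note: nn.LSTM keeps two biases per layer (bias_ih and bias_hh) that are
        # summed, so filling both with 1.0 yields an effective forget bias of 2.0.
        with torch.no_grad():
            for name, param in self.lstm.named_parameters():
                if 'bias' in name:
                    n = param.size(0)
                    param[n // 4:n // 2].fill_(1.0)  # forget gate slice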

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.dropout(self.embedding(x))
        # embedded: (batch, seq_len, embed_dim)

        output, (hidden, cell) = self.lstm(embedded)
        # output: (batch, seq_len, hidden_dim * num_directions)
        # hidden: (num_layers * num_directions, batch, hidden_dim)

        # Use last hidden state from the last layer
        if self.lstm.bidirectional:
            hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            hidden = hidden[-1]

        # hidden: (batch, hidden_dim * num_directions)
        return self.fc(self.dropout(hidden))

# Test
model = LSTMClassifier(
    vocab_size=10000,
    embed_dim=128,
    hidden_dim=256,
    output_dim=5,
    num_layers=2,
    bidirectional=True
)

x = torch.randint(0, 10000, (32, 100))  # batch=32, seq_len=100
out = model(x)
print(f"Output shape: {out.shape}")  # [32, 5]
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Example: LSTM for Time Series Forecasting

import torch
import torch.nn as nn

class LSTMForecaster(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=2, forecast_horizon=24):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1
        )
        self.fc = nn.Linear(hidden_size, forecast_horizon)

    def forward(self, x):
        # x: (batch, lookback, features)
        lstm_out, (h_n, c_n) = self.lstm(x)
        # Use last time step's output
        last_hidden = lstm_out[:, -1, :]
        return self.fc(last_hidden)

# Test with synthetic time series
model = LSTMForecaster(input_size=3, hidden_size=64, forecast_horizon=24)
x = torch.randn(32, 168, 3)  # 168 hours of history, 3 features
pred = model(x)
print(f"Prediction shape: {pred.shape}")  # [32, 24] (next 24 hours)

Example: Bidirectional LSTM for NER

import torch
import torch.nn as nn

class BiLSTM_NER(nn.Module):
    def __init__(self, vocab_size, tagset_size, embed_dim=128,
                 hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim // 2,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        return tag_space  # (batch, seq_len, num_tags)

# Test for NER (token-level classification)
model = BiLSTM_NER(vocab_size=5000, tagset_size=9)  # 9 NER tags
tokens = torch.randint(0, 5000, (16, 50))  # batch=16, seq_len=50
logits = model(tokens)
print(f"NER output shape: {logits.shape}")  # [16, 50, 9]

# Cross-entropy loss for token classification
tags = torch.randint(0, 9, (16, 50))
loss = nn.CrossEntropyLoss()(logits.view(-1, 9), tags.view(-1))
print(f"NER loss: {loss.item():.4f}")

LSTM vs. GRU Comparison

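| Feature | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| Cell state | Yes | No |
| Parameters | More (about $4 \times d^2$ recurrent weights) | Fewer (about $3 \times d^2$ recurrent weights) |
| Speed | Slower | Faster |
| Performance | Similar | Similar |
| Memory | Better for very long sequences | Competitive |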
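The 4-to-3 parameter ratio in the table can be confirmed by counting the parameters of PyTorch's single-layer `nn.LSTM` and `nn.GRU`. The sizes `m` and `d` below are illustrative, and `count_params` is a helper defined only for this example.

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

m, d = 128, 256  # illustrative input and hidden sizes
lstm_params = count_params(nn.LSTM(input_size=m, hidden_size=d))
gru_params = count_params(nn.GRU(input_size=m, hidden_size=d))

# Per gate block, PyTorch stores input weights (d*m), recurrent weights (d*d),
# and two bias vectors (2*d). An LSTM has 4 such blocks, a GRU has 3.
per_block = d * m + d * d + 2 * d
print(lstm_params, 4 * per_block)
print(gru_params, 3 * per_block)
```

The full count therefore includes input weights and biases on top of the $d^2$ recurrent weights per block, but the ratio between the two architectures stays at 4 to 3.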

When to Use Which

  • LSTM: Long sequences, when you need fine-grained control over memory
  • GRU: Shorter sequences, faster training, fewer parameters
  • Transformers: Now dominant for most sequence tasks, but LSTMs still useful for streaming/online settings

Summary

Summary: LSTM Networks

  • LSTM solves vanishing gradients via gating mechanisms and cell state
  • Forget gate: Controls what to discard from memory ($\mathbf{f}_t = \sigma(\cdot)$)
  • Input gate: Controls what to add to memory ($\mathbf{i}_t = \sigma(\cdot)$)
  • Cell state update: Linear operation preserves gradient flow ($\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$)
  • Output gate: Controls what to output ($\mathbf{o}_t = \sigma(\cdot)$)
  • Bidirectional: Processes sequences in both directions, uses past and future context
  • Stacked: Multiple layers for abstract representations
  • Forget gate bias: Initialize to 1.0 for better gradient flow
  • LSTMs still useful: Streaming data, edge devices, when Transformers are too large

Practice Exercises

  1. Mathematical: Derive the gradient $\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}}$ for the LSTM cell state update. Why does this prevent vanishing gradients?

  2. Coding: Implement an LSTM cell from scratch using only torch.tensor operations (no nn.LSTM). Verify that your implementation matches PyTorch's output.

  3. Experiment: Compare LSTM vs. GRU vs. Transformer on a long-range dependency task (e.g., copying the first element of a 500-element sequence). Which handles long sequences best?

  4. Application: Build a bidirectional LSTM for sentiment analysis on IMDB reviews. Compare with a Transformer-based model. What are the tradeoffs?

  5. Research: Read the original LSTM paper (Hochreiter and Schmidhuber, 1997). What was the original motivation? How has the architecture evolved?


What to Learn Next

-> GRU Networks Gated Recurrent Unit — a simpler alternative to LSTM with similar performance.

-> RNN Deep Dive Vanilla RNNs, BPTT, and why vanishing gradients motivated the invention of LSTM and GRU.

-> Sequence to Sequence Encoder-decoder architectures for translation, summarization, and text generation.

-> Attention Mechanisms Self-attention, multi-head attention, and how Transformers replaced RNNs for sequence modeling.

-> Vision Transformers ViT, DeiT, Swin — how Transformers are challenging CNN dominance in computer vision.

-> DL Systems Design Designing production ML systems — data pipelines, training infrastructure, and serving at scale.
