
Optimizers: SGD, Adam, AdamW & Learning Rate Schedules (Asked at OpenAI & DeepMind)

🎯 The Interview Question

"Compare SGD with momentum, Adam, and AdamW optimizers. What are the mathematical formulations of each? Why is AdamW preferred over Adam for training Transformers? Explain learning rate schedules and their importance. When would you choose SGD over Adam?"

This question tests a fundamental understanding of how deep learning models are trained, which is essential for interviews at OpenAI and DeepMind.


📚 Detailed Answer

SGD with Momentum

Standard SGD:

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla \mathcal{L}(\mathbf{w}_t)$$

With Momentum:

$$\mathbf{v}_t = \beta \mathbf{v}_{t-1} + \nabla \mathcal{L}(\mathbf{w}_t)$$
$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \mathbf{v}_t$$

where $\beta$ is typically 0.9.

Effect: Momentum accumulates past gradients, providing:

  • Faster convergence in consistent gradient directions
  • Dampening of oscillations
  • Ability to escape shallow local minima
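The two update rules can be compared on a toy problem. The sketch below is a minimal pure-Python illustration, assuming a 1-D quadratic loss L(w) = 0.5 * a * w^2 (so the gradient is a * w); the function names and constants are illustrative choices, not part of any library.

```python
def grad(w: float, a: float = 1.0) -> float:
    """Gradient of the toy loss L(w) = 0.5 * a * w**2."""
    return a * w


def sgd(w: float, lr: float, steps: int) -> float:
    """Plain SGD: w <- w - lr * grad(w)."""
    for _ in range(steps):
        w = w - lr * grad(w)
    return w


def sgd_momentum(w: float, lr: float, beta: float, steps: int) -> float:
    """SGD with momentum: v <- beta * v + grad(w); w <- w - lr * v."""
    v = 0.0  # velocity buffer, starts at zero
    for _ in range(steps):
        v = beta * v + grad(w)
        w = w - lr * v
    return w


# Same small learning rate and step budget for both, starting at w = 1.0.
w_sgd = sgd(1.0, lr=0.01, steps=100)
w_mom = sgd_momentum(1.0, lr=0.01, beta=0.9, steps=100)
print(f"SGD:      {w_sgd:.4f}")
print(f"Momentum: {w_mom:.4f}")
```

When successive gradients point the same way, the velocity builds up toward g / (1 - beta), roughly 10x the raw gradient for beta = 0.9, so on this low-curvature problem the momentum run ends much closer to the minimum at w = 0 than plain SGD does.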

💡 SGD with momentum is often preferred for computer vision tasks such as training ResNets, partly because it is thought to find flatter minima that generalize better. The learning rate schedule is critical: use cosine annealing or step decay.

Adam (Adaptive Moment Estimation)

Adam combines momentum with adaptive learning rates:

First moment (mean):

$$\mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1-\beta_1)\nabla \mathcal{L}(\mathbf{w}_t)$$

Second moment (uncentered variance):

$$\mathbf{v}_t = \beta_2 \mathbf{v}_{t-1} + (1-\beta_2)\left(\nabla \mathcal{L}(\mathbf{w}_t)\right)^2$$

Bias correction:

$$\hat{\mathbf{m}}_t = \frac{\mathbf{m}_t}{1-\beta_1^t}, \quad \hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1-\beta_2^t}$$

Update:

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon}$$

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$
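A minimal scalar sketch of the Adam step above, written in plain Python for illustration only (real implementations operate on tensors, and `adam_step` is a hypothetical helper name):

```python
import math


def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g        # first moment: EMA of gradients
    v = beta2 * v + (1 - beta2) * g * g    # second moment: EMA of squared gradients
    m_hat = m / (1 - beta1 ** t)           # bias correction, since m and v start at 0
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


# First step from w = 0 with gradients of very different scale.
w_big, _, _ = adam_step(0.0, g=1000.0, m=0.0, v=0.0, t=1)
w_small, _, _ = adam_step(0.0, g=0.001, m=0.0, v=0.0, t=1)
print(w_big, w_small)  # both close to -lr = -0.001
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) equals g / |g| up to eps, so the parameter moves by about lr whether the gradient is 1000 or 0.001. This per-parameter normalization is what makes Adam's step size largely independent of gradient scale.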

Advantages:

  • Adaptive learning rates per parameter
  • Fast convergence on sparse gradients
  • Often works well with the default $\beta_1$, $\beta_2$, $\epsilon$ (the learning rate usually still needs tuning)

AdamW: Decoupled Weight Decay

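In standard Adam, weight decay is implemented as L2 regularization: the term $\lambda \mathbf{w}_t$ is added to the gradient before the moment estimates are computed:

$$\mathbf{g}_t = \nabla \mathcal{L}(\mathbf{w}_t) + \lambda \mathbf{w}_t, \qquad \text{Adam L2: } \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon}$$

where $\hat{\mathbf{m}}_t$ and $\hat{\mathbf{v}}_t$ are now computed from $\mathbf{g}_t$ instead of the raw loss gradient.

This couples weight decay with the adaptive learning rate: the decay term is divided by $\sqrt{\hat{\mathbf{v}}_t}$, so weights with large historical gradients are regularized less than weights with small ones, and the effective decay no longer matches $\lambda$.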

AdamW fixes this:

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon} - \eta \lambda \mathbf{w}_t$$

The weight decay term is applied directly to the weights, not through the gradient.
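The difference between the two decay schemes shows up even in a single step. This sketch (plain Python, scalar parameter, hypothetical helper names) applies one update to a weight that receives no loss gradient, so only weight decay acts on it. It assumes the decay is scaled by the learning rate, as in the AdamW equation above.

```python
import math


def adam_l2_step(w, g_loss, m, v, t, lr, wd, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam + L2: the decay term wd * w is added to the gradient, so it enters m and v."""
    g = g_loss + wd * w
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat, v_hat = m / (1 - beta1 ** t), v / (1 - beta2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


def adamw_step(w, g_loss, m, v, t, lr, wd, beta1=0.9, beta2=0.999, eps=1e-8):
    """AdamW: moments see only the loss gradient; decay is applied to w directly."""
    m = beta1 * m + (1 - beta1) * g_loss
    v = beta2 * v + (1 - beta2) * g_loss * g_loss
    m_hat, v_hat = m / (1 - beta1 ** t), v / (1 - beta2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps) - lr * wd * w, m, v


# A weight with zero loss gradient: only weight decay moves it.
w_l2, _, _ = adam_l2_step(1.0, g_loss=0.0, m=0.0, v=0.0, t=1, lr=0.1, wd=0.01)
w_adamw, _, _ = adamw_step(1.0, g_loss=0.0, m=0.0, v=0.0, t=1, lr=0.1, wd=0.01)
print(f"Adam + L2: {w_l2:.4f}")    # about 0.9000
print(f"AdamW:     {w_adamw:.4f}")  # 0.9990
```

With L2 regularization the first step shrinks the weight by roughly lr = 0.1 regardless of lambda, because m_hat / sqrt(v_hat) is about 1 and normalizes the decay term away. AdamW shrinks it by exactly lr * wd * w = 0.001, as the formula prescribes.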

Why AdamW is better:

  • Proper decoupling of weight decay
  • Better generalization than Adam with L2 regularization in many settings
  • Standard for Transformer training

Comparison Table

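| Optimizer | Pros | Cons | Best For |
|---|---|---|---|
| SGD (+ momentum) | Simple, good generalization, low optimizer memory | Slow convergence, sensitive to LR | Computer vision |
| Adam | Fast convergence, adaptive per-parameter LR | Can generalize worse; L2 decay is coupled with the adaptive scaling | Sparse gradients, NLP |
| AdamW | Decoupled weight decay, good generalization | Two moment buffers per parameter (same as Adam) | Transformers |
| LAMB | Stable large-batch training | More complex; mainly useful at very large batch sizes | Distributed training |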

Learning Rate Schedules

Step Decay

$$\eta_t = \eta_0 \cdot \gamma^{\lfloor t/s \rfloor}$$

Decays the LR by a factor $\gamma$ every $s$ steps (or epochs, if $t$ counts epochs). A common setting is $\gamma = 0.1$ with $s = 30$ epochs.
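A direct transcription of the step-decay formula (function and parameter names are illustrative), using the common gamma = 0.1, s = 30 epoch setting:

```python
def step_decay(lr0: float, gamma: float, step_size: int, t: int) -> float:
    """eta_t = lr0 * gamma ** floor(t / step_size)."""
    return lr0 * gamma ** (t // step_size)


# Start at 0.1 and decay 10x every 30 epochs.
for epoch in (0, 29, 30, 60, 89):
    print(epoch, step_decay(0.1, 0.1, 30, epoch))
```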

Cosine Annealing

$$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$

Smooth decay from $\eta_{max}$ to $\eta_{min}$ over $T$ steps. A widely used default in modern training recipes.
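A minimal sketch of the cosine formula. Clamping t at T, so the rate stays at eta_min once the schedule ends, is an assumption of this sketch rather than part of the formula:

```python
import math


def cosine_annealing(t: int, total_steps: int, lr_max: float, lr_min: float = 0.0) -> float:
    """eta_t = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T))."""
    t = min(t, total_steps)  # assumption: hold at lr_min after the schedule ends
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / total_steps))


# Starts at lr_max, passes the midpoint at T/2, and reaches lr_min at T.
for t in (0, 250, 500, 750, 1000):
    print(t, cosine_annealing(t, total_steps=1000, lr_max=1e-3))
```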

Warmup + Cosine Annealing

$$\eta_t = \begin{cases} \eta_{max} \cdot \dfrac{t}{T_{warmup}} & t \leq T_{warmup} \\[4pt] \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\pi \dfrac{t - T_{warmup}}{T - T_{warmup}}\right)\right) & t > T_{warmup} \end{cases}$$

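Linear warmup for the first $T_{warmup}$ steps, then cosine decay over the remaining $T - T_{warmup}$ steps. Standard practice for Transformers; a common explanation is that early in training Adam's second-moment estimates are still noisy, so large initial updates can destabilize training.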
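The piecewise schedule can be sketched as one function (names are illustrative; it assumes warmup_steps > 0 and total_steps > warmup_steps):

```python
import math


def warmup_cosine(t: int, warmup_steps: int, total_steps: int,
                  lr_max: float, lr_min: float = 0.0) -> float:
    """Linear warmup to lr_max, then cosine decay to lr_min at total_steps."""
    if t <= warmup_steps:
        return lr_max * t / warmup_steps  # linear ramp from 0 to lr_max
    decay_steps = total_steps - warmup_steps
    # Fraction of the decay phase completed, clamped to 1 after total_steps.
    progress = min(t - warmup_steps, decay_steps) / decay_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


for t in (0, 50, 100, 550, 1000):
    print(t, warmup_cosine(t, warmup_steps=100, total_steps=1000, lr_max=3e-4))
```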

Cyclical Learning Rates

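Oscillate between $\eta_{min}$ and $\eta_{max}$ with period $T$ (a triangular cycle):

$$\eta_t = \eta_{min} + (\eta_{max} - \eta_{min})\left|\frac{2\,(t \bmod T)}{T} - 1\right|$$

The periodic LR increases are argued to help the optimizer escape saddle points and sharp local minima.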
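A minimal sketch of the triangular cycle (the function name is illustrative). Taking t modulo the period is what makes the schedule repeat:

```python
def triangular_lr(t: int, period: int, lr_min: float, lr_max: float) -> float:
    """eta_t = lr_min + (lr_max - lr_min) * |2 * (t mod T) / T - 1|."""
    phase = (t % period) / period  # position within the current cycle, in [0, 1)
    return lr_min + (lr_max - lr_min) * abs(2 * phase - 1)


# lr_max at the start of each cycle, lr_min halfway through.
for t in (0, 25, 50, 75, 100):
    print(t, triangular_lr(t, period=100, lr_min=1e-4, lr_max=1e-2))
```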

Advanced Optimizers

LAMB (Layer-wise Adaptive Moments)

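Designed for large-batch training. For each layer, LAMB rescales the AdamW direction by a layer-wise trust ratio:

$$\mathbf{r}_t = \frac{\hat{\mathbf{m}}_t}{\sqrt{\hat{\mathbf{v}}_t} + \epsilon}, \qquad \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \frac{\phi(\|\mathbf{w}_t\|)}{\|\mathbf{r}_t + \lambda \mathbf{w}_t\|}\left(\mathbf{r}_t + \lambda \mathbf{w}_t\right)$$

where $\phi$ is a scaling function (often the identity, sometimes clipped) and the norms are taken over each layer's parameters.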

Enables batch sizes up to 32K for BERT training.
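A minimal pure-Python sketch of the per-layer LAMB update, assuming phi is the identity and that the bias-corrected moments are computed elsewhere exactly as in Adam. The helper names, the eps default, and the fallback trust ratio of 1.0 when a norm is zero are assumptions of this sketch:

```python
import math


def l2_norm(x):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(xi * xi for xi in x))


def lamb_layer_update(w, m_hat, v_hat, lr, wd=0.0, eps=1e-6):
    """One LAMB step for a single layer (equal-length lists), with phi = identity."""
    # Adam direction r_t plus decoupled weight decay: u = r_t + wd * w_t
    u = [mh / (math.sqrt(vh) + eps) + wd * wi for mh, vh, wi in zip(m_hat, v_hat, w)]
    w_norm, u_norm = l2_norm(w), l2_norm(u)
    # Layer-wise trust ratio ||w|| / ||u|| (fallback of 1.0 is an assumption).
    trust = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return [wi - lr * trust * ui for wi, ui in zip(w, u)]


w = [3.0, 4.0]  # a layer with ||w|| = 5
new_w = lamb_layer_update(w, m_hat=[0.1, -0.2], v_hat=[0.01, 0.04], lr=0.01)
step = [a - b for a, b in zip(w, new_w)]
print(l2_norm(step))  # lr * ||w|| = 0.05, up to rounding
```

Because the direction is normalized and then rescaled by the layer's weight norm, every layer moves by lr * ||w|| per step regardless of its gradient scale, which is what keeps very large learning rates (and hence very large batches) stable.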

Lion (Google Brain)

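Uses only the sign of an interpolation between the momentum and the current gradient:

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \left(\text{sign}\left(\beta_1 \mathbf{m}_t + (1-\beta_1)\nabla \mathcal{L}(\mathbf{w}_t)\right) + \lambda \mathbf{w}_t\right)$$
$$\mathbf{m}_{t+1} = \beta_2 \mathbf{m}_t + (1-\beta_2)\nabla \mathcal{L}(\mathbf{w}_t)$$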

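Memory efficient: it keeps a single momentum buffer instead of Adam's two, which makes it attractive for large models. Because each sign update has magnitude $\eta$ per coordinate, Lion typically needs a smaller learning rate than AdamW.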
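A minimal scalar sketch of one Lion step (plain Python; helper names are hypothetical, and beta1 = 0.9, beta2 = 0.99 are used here as illustrative defaults):

```python
def sign(x: float) -> float:
    """Return 1.0, -1.0, or 0.0 according to the sign of x."""
    return float((x > 0) - (x < 0))


def lion_step(w, g, m, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    """One Lion update for a scalar parameter; returns (new_w, new_m)."""
    c = beta1 * m + (1 - beta1) * g   # interpolate momentum and current gradient
    w = w - lr * (sign(c) + wd * w)   # sign update plus decoupled weight decay
    m = beta2 * m + (1 - beta2) * g   # momentum is tracked with beta2
    return w, m


# With wd = 0, each coordinate moves by exactly lr whenever c != 0,
# whatever the gradient's magnitude.
w_big, _ = lion_step(0.0, g=1000.0, m=0.0)
w_small, _ = lion_step(0.0, g=0.001, m=0.0)
print(w_big, w_small)
```

Note that a large opposing momentum can override the current gradient's sign: with m = -5 and g = 1, the interpolation c is negative, so the weight moves up even though the gradient says to move down.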

Practical Guidelines

Follow-Up Questions

Q: Why does Adam sometimes generalize worse than SGD? A: One common explanation is that adaptive methods can converge to sharp minima, which have low training loss but generalize poorly. SGD with momentum tends to find flatter minima.

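Q: What is the relationship between learning rate and batch size? A: Linear scaling rule: when the batch size increases by a factor of $k$, increase the LR by $k$. This works well for SGD in practice, typically combined with warmup. Adam tends to be less sensitive to batch size, and square-root scaling (multiplying the LR by $\sqrt{k}$) is often used for it instead.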
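The scaling rule is simple arithmetic. This sketch uses a hypothetical helper that applies it, plus a square-root variant included as an alternative sometimes used for adaptive optimizers. Both are heuristics, and the example values are illustrative:

```python
import math


def scale_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "linear") -> float:
    """Scale a learning rate tuned at base_batch to new_batch."""
    k = new_batch / base_batch
    if rule == "linear":   # linear scaling rule: LR grows by the same factor k
        return base_lr * k
    if rule == "sqrt":     # square-root alternative: LR grows by sqrt(k)
        return base_lr * math.sqrt(k)
    raise ValueError(f"unknown rule: {rule!r}")


# LR 0.1 tuned at batch 256, moving to batch 1024 (k = 4).
print(scale_lr(0.1, 256, 1024))          # linear: 0.4
print(scale_lr(0.1, 256, 1024, "sqrt"))  # sqrt: 0.2
```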

Q: How do you choose the number of warmup steps relative to total training steps? A: Warmup is typically 5-10% of total steps, with more warmup for larger models and batch sizes. A common starting point is 2000-4000 steps for most tasks.
