1. Training Loop Anatomy
Every deep learning model, regardless of architecture, follows the same fundamental loop: forward pass → compute loss → backward pass → update parameters.

    For each epoch:
        For each batch:
            1. Forward pass:   ŷ = f(x; θ)
            2. Compute loss:   L = Loss(y, ŷ)
            3. Zero gradients: ∇θ ← 0
            4. Backward pass:  ∇θ = ∂L/∂θ (autograd)
            5. Update:         θ ← θ − η · ∇θ

The PyTorch Implementation

    import torch
    import torch.nn as nn

    def train_one_epoch(model, dataloader, criterion, optimizer, device):
        model.train()
        total_loss = 0.0
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            # 1. Forward pass
            predictions = model(batch_x)
            # 2. Compute loss
            loss = criterion(predictions, batch_y)
            # 3. Zero gradients
            optimizer.zero_grad()
            # 4. Backward pass
            loss.backward()
            # 5. Update parameters
            optimizer.step()
            total_loss += loss.item()
        # Average loss per batch
        return total_loss / len(dataloader)

Key subtleties:
- `model.train()` enables dropout and batch normalization training mode
- `optimizer.zero_grad()` must be called before `.backward()` because gradients accumulate by default
- `.item()` extracts the scalar loss value (detaches from the computation graph)
2. Loss Functions
Loss functions quantify the mismatch between predictions and targets. The choice of loss function encodes what "good" means for your task.
2.1 Mean Squared Error (MSE)
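For regression tasks:
$$L_{\text{MSE}} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2$$
Gradient (per sample):
$$\frac{\partial L_{\text{MSE}}}{\partial \hat{y}_i} = \frac{2}{N} (\hat{y}_i - y_i)$$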
Properties:
- Convex for linear models → no spurious local minima; every local minimum is global
- Penalizes large errors quadratically → sensitive to outliers
- Equivalent to maximizing Gaussian log-likelihood with fixed variance
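As a minimal, framework-free sketch of the loss and its per-sample gradient, assuming the usual mean-over-$N$ definition of MSE (the helper names `mse` and `mse_grad` are illustrative, not PyTorch APIs):

```python
def mse(y_true, y_pred):
    """Mean squared error over N samples."""
    n = len(y_true)
    return sum((y - p) ** 2 for y, p in zip(y_true, y_pred)) / n


def mse_grad(y_true, y_pred):
    """Gradient of the MSE with respect to each prediction: 2/N * (y_hat - y)."""
    n = len(y_true)
    return [2.0 * (p - y) / n for y, p in zip(y_true, y_pred)]


y_true = [1.0, 2.0, 3.0]
y_pred = [1.0, 2.0, 6.0]  # two perfect predictions and one error of 3
loss = mse(y_true, y_pred)
grads = mse_grad(y_true, y_pred)
```

The single outlier contributes the entire loss and the only non-zero gradient, which is the outlier sensitivity listed above.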
2.2 Cross-Entropy Loss
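For multi-class classification with $C$ classes:
$$L_{\text{CE}} = -\sum_{c=1}^{C} y_c \log p_c$$
where $y$ is one-hot encoded and $p = \text{softmax}(z)$:
$$p_c = \frac{e^{z_c}}{\sum_{j=1}^{C} e^{z_j}}$$
The combined CrossEntropyLoss in PyTorch applies log-softmax + NLL loss in a numerically stable way. For a true class index $k$:
$$L_{\text{CE}} = -z_k + \log \sum_{j=1}^{C} e^{z_j}$$
Numerical stability: The log-sum-exp trick computes $\log \sum_j e^{z_j} = m + \log \sum_j e^{z_j - m}$ where $m = \max_j z_j$.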
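The log-sum-exp trick can be sketched in plain Python. This is an illustrative helper rather than PyTorch's implementation; `log_sum_exp` and `cross_entropy` are hypothetical names, and `target` is assumed to be the index of the true class:

```python
import math


def log_sum_exp(logits):
    """Stable log(sum(exp(z))): subtract the max so every exponent is <= 0."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -z[target] + log-sum-exp(z)."""
    return log_sum_exp(logits) - logits[target]


# math.exp(1000.0) would overflow, but the shifted version is safe.
large_logits = [1000.0, 1001.0, 1002.0]
loss = cross_entropy(large_logits, target=2)
```

Because softmax is invariant to adding a constant to every logit, the result matches the loss for the logits `[0.0, 1.0, 2.0]`.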
2.3 Focal Loss
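Addressing class imbalance (e.g., object detection where 99% of anchors are background):
$$\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
where:
- $p_t$ = model's estimated probability for the correct class
- $\gamma$ = focusing parameter (typically 2)
- $\alpha_t$ = class balancing weight
When $\gamma = 0$, focal loss reduces to standard cross-entropy (weighted by $\alpha_t$). When $\gamma = 2$, well-classified examples ($p_t = 0.9$) have their loss reduced by a factor of $(1 - 0.9)^2 = 0.01$, i.e. $100\times$.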
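The down-weighting of easy examples can be checked numerically. The sketch below assumes the standard focal loss form $-\alpha_t (1 - p_t)^{\gamma} \log p_t$; `focal_loss` is an illustrative helper for a single example, not a library function:

```python
import math


def focal_loss(p_t, gamma=2.0, alpha_t=1.0):
    """Focal loss for one example, given the probability of the correct class."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# Easy example: the model already assigns 0.9 to the correct class.
ce_easy = focal_loss(0.9, gamma=0.0)  # gamma = 0 is plain cross-entropy
fl_easy = focal_loss(0.9, gamma=2.0)
easy_reduction = ce_easy / fl_easy    # roughly 100

# Hard example: only 0.1 on the correct class, so the loss is barely reduced.
hard_reduction = focal_loss(0.1, gamma=0.0) / focal_loss(0.1, gamma=2.0)
```

Easy examples are suppressed by about two orders of magnitude while hard examples keep most of their loss, so the gradient signal concentrates on the hard cases.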
2.4 Contrastive Loss
For learning embeddings where similar items are close and dissimilar items are far apart:
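$$L = y \, d^2 + (1 - y) \max(0, m - d)^2$$
where $d$ is the Euclidean distance between embeddings, $y = 1$ for similar pairs ($y = 0$ otherwise), and $m$ is the margin.
Triplet Loss (used in FaceNet):
$$L = \max\left(0, \|f(a) - f(p)\|_2^2 - \|f(a) - f(n)\|_2^2 + \alpha\right)$$
where $a$ = anchor, $p$ = positive (same class), $n$ = negative (different class), $\alpha$ = margin.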
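A small sketch of both losses on raw embedding vectors. It assumes the common formulation with `similar=True` for matching pairs and squared Euclidean distances in the triplet loss; the function names and default margins are illustrative choices, not values from a specific library:

```python
import math


def euclidean(u, v):
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def contrastive_loss(u, v, similar, margin=1.0):
    """Pull similar pairs together; push dissimilar pairs at least `margin` apart."""
    d = euclidean(u, v)
    if similar:
        return d ** 2
    return max(0.0, margin - d) ** 2


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Require the negative to be farther from the anchor than the positive, by `margin`."""
    d_pos = euclidean(anchor, positive) ** 2
    d_neg = euclidean(anchor, negative) ** 2
    return max(0.0, d_pos - d_neg + margin)
```

A dissimilar pair that is already farther apart than the margin contributes zero loss, as does a triplet whose negative is far enough from the anchor.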
Loss Function Selection Guide
| Task | Loss Function | Why |
|---|---|---|
| Regression | MSE, MAE, Huber | MSE for clean data, Huber for outliers |
| Binary classification | BCE, Focal | Focal for imbalanced data |
| Multi-class classification | CrossEntropy, Focal | Focal for long-tailed distributions |
| Metric learning | Contrastive, Triplet | Learn embedding space structure |
| Segmentation | Dice loss, CE+Dice | Handle severe foreground/background imbalance |
| GANs | Adversarial loss | Minimax game between generator and discriminator |
3. Optimizers
3.1 SGD (Stochastic Gradient Descent)
The simplest update rule:
$$\theta_{t+1} = \theta_t - \eta \, g_t, \qquad g_t = \nabla_\theta L(\theta_t)$$
Problem: Oscillates along high-curvature directions, converges slowly along flat directions.
3.2 SGD with Momentum
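Adds a velocity term that accumulates past gradients:
$$v_t = \beta v_{t-1} + g_t, \qquad \theta_{t+1} = \theta_t - \eta \, v_t$$
where $g_t = \nabla_\theta L(\theta_t)$. Commonly $\beta = 0.9$. Momentum accelerates convergence in consistent gradient directions and dampens oscillations.
Physical analogy: A ball rolling downhill accumulates velocity. $\beta$ controls friction: a lower $\beta$ means more friction.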
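The speed-up along flat directions can be seen on a one-dimensional quadratic. The sketch assumes the loss $\frac{1}{2} c\,\theta^2$ with a small curvature $c$ (a flat direction) and the velocity form $v \leftarrow \beta v + g$; `run_sgd` is an illustrative helper, and `beta=0.0` recovers plain SGD:

```python
def run_sgd(curvature, theta=1.0, lr=0.1, beta=0.0, steps=100):
    """Minimize 0.5 * curvature * theta**2 with SGD plus optional momentum."""
    v = 0.0
    for _ in range(steps):
        g = curvature * theta   # gradient of the quadratic
        v = beta * v + g        # velocity accumulates past gradients
        theta -= lr * v
    return theta


# A flat direction: curvature 0.01, so each plain SGD step barely moves.
plain = run_sgd(curvature=0.01, beta=0.0)
with_momentum = run_sgd(curvature=0.01, beta=0.9)
```

After the same 100 steps, the momentum run is much closer to the minimum at 0, consistent with an effective step size of roughly $\eta / (1 - \beta)$ when gradients keep pointing the same way.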
3.3 RMSProp
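Adapts the learning rate per parameter based on the magnitude of recent gradients:
$$s_t = \rho \, s_{t-1} + (1 - \rho) \, g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{s_t} + \epsilon} \, g_t$$
Parameters with large gradients get a smaller effective learning rate; parameters with small gradients get a larger one. Defaults (PyTorch): $\rho = 0.99$, $\epsilon = 10^{-8}$.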
3.4 Adam (Adaptive Moment Estimation)
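Combines momentum (first moment) and RMSProp (second moment):
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) \, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) \, g_t^2$$
Bias correction (critical in early steps):
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
Update:
$$\theta_{t+1} = \theta_t - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Defaults: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.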
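A single Adam step for one scalar parameter can be sketched as follows, assuming the standard formulation with zero-initialized moments and a step counter `t` starting at 1 (`adam_step` is an illustrative helper, not the `torch.optim.Adam` API):

```python
import math


def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; returns (theta, m, v)."""
    m = beta1 * m + (1 - beta1) * g        # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g    # second moment (RMSProp-style)
    m_hat = m / (1 - beta1 ** t)           # bias correction for zero init
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# First step from theta = 0 with a large and a small gradient.
theta_large, _, _ = adam_step(0.0, 5.0, 0.0, 0.0, 1)
theta_small, _, _ = adam_step(0.0, 1e-3, 0.0, 0.0, 1)
```

With bias correction, the first step has magnitude close to `lr` in both cases, regardless of the gradient scale. Without it, the zero-initialized moments would distort the early updates.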
3.5 AdamW (Adam with Decoupled Weight Decay)
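In Adam, L2 regularization ($\frac{\lambda}{2} \|\theta\|_2^2$) is absorbed into the adaptive learning rate, so the effective weight decay differs from parameter to parameter. AdamW decouples weight decay from the gradient-based update:
$$\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)$$
where $\hat{m}_t$ and $\hat{v}_t$ are Adam's bias-corrected moment estimates.
This makes weight decay consistent across parameters regardless of gradient magnitude. AdamW is the de facto default optimizer for training transformers.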
Optimizer Selection Decision Tree

    Is your model a transformer or uses batch norm?
    ├─ Yes → AdamW (lr=3e-4, weight_decay=0.01)
    └─ No
       └─ Computer Vision (CNN)?
          ├─ Yes → SGD+Momentum (lr=0.1, momentum=0.9) with cosine schedule
          └─ No
             ├─ Reinforcement Learning? → Adam (lr=3e-4)
             └─ General deep learning? → Start with Adam, try SGD if there is a generalization gap

4. Learning Rate Schedules
The learning rate is often the most important hyperparameter. A fixed learning rate is rarely optimal: you want large steps early (fast convergence) and small steps later (fine-tuning).
4.1 Step Decay
Reduce the learning rate by a factor $\gamma$ every $s$ epochs:
$$\eta_t = \eta_0 \cdot \gamma^{\lfloor t / s \rfloor}$$
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# lr is multiplied by 0.1 at epoch 30 and again at epoch 60 (1% of the initial lr)
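The same schedule can be written as a closed-form function of the epoch. This is a sketch of the step-decay rule only, not of `StepLR` internals; `step_decay_lr` is an illustrative name and `base_lr=0.1` is an assumed starting value:

```python
def step_decay_lr(epoch, base_lr=0.1, step_size=30, gamma=0.1):
    """Learning rate after multiplying by `gamma` once every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Learning rate at the start of training and after each drop.
schedule = [step_decay_lr(epoch) for epoch in (0, 29, 30, 59, 60)]
```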
4.2 Cosine Annealing
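Smoothly anneal from $\eta_{\max}$ to $\eta_{\min}$ over $T$ steps following a cosine curve:
$$\eta_t = \eta_{\min} + \frac{1}{2} (\eta_{\max} - \eta_{\min}) \left( 1 + \cos\left( \frac{\pi t}{T} \right) \right)$$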
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
Cosine annealing is the default schedule for most modern training pipelines. It provides a gentle decay at both ends and faster decay in the middle.
4.3 Warmup + Cosine
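Linearly increase the learning rate from 0 to $\eta_{\max}$ over $T_w$ warmup steps, then cosine anneal over the remaining $T - T_w$ steps:
$$\eta_t = \begin{cases} \eta_{\max} \cdot \dfrac{t}{T_w} & t < T_w \\ \eta_{\min} + \dfrac{1}{2} (\eta_{\max} - \eta_{\min}) \left( 1 + \cos\left( \pi \, \dfrac{t - T_w}{T - T_w} \right) \right) & t \ge T_w \end{cases}$$
Warmup is essential for transformers: training is unstable in early steps, when parameters are random and adaptive optimizers have unreliable second-moment estimates.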
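A minimal sketch of the warmup + cosine schedule as a plain function of the step. The defaults (`max_lr=3e-4`, 1000 warmup steps, 100000 total steps) mirror the recipe later in this document; `warmup_cosine_lr` and `min_lr` are illustrative names:

```python
import math


def warmup_cosine_lr(step, max_lr=3e-4, min_lr=0.0,
                     warmup_steps=1000, total_steps=100000):
    """Linear warmup from 0 to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Sample the schedule at the start, mid-warmup, peak, and end of training.
samples = [warmup_cosine_lr(s) for s in (0, 500, 1000, 100000)]
```

The learning rate starts at zero, reaches its peak exactly when warmup ends, and decays to `min_lr` at the final step.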
4.4 OneCycle Policy
Cycles the learning rate from low → high → low within a single cycle, with momentum going in reverse (high → low → high). The schedule is a function of pct, the fraction of training completed, which goes from 0 to 1 over the total training steps. Smith (2018) showed this can converge in fewer epochs than standard schedules.

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.01, total_steps=total_training_steps
    )

5. Regularization
Overfitting occurs when the model memorizes training data rather than learning generalizable patterns. Regularization techniques combat this.
5.1 Dropout
During training, each neuron is independently set to zero with probability $p$:
$$\tilde{h}_i = \frac{m_i}{1 - p} \, h_i, \qquad m_i \sim \text{Bernoulli}(1 - p)$$
The $\frac{1}{1-p}$ scaling (inverted dropout) keeps the expected activation during training equal to its value at test time, where no dropout is applied.
Intuition: Dropout forces the network to learn redundant representations, since no single neuron can be relied upon. It can be interpreted as training an ensemble of $2^n$ sub-networks (where $n$ is the number of neurons).
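The effect of the inverted-dropout scaling on the expected activation can be checked empirically. This sketch uses only the standard library; `inverted_dropout` is an illustrative helper, and the seeded `random.Random` is there only to make the run repeatable:

```python
import random


def inverted_dropout(activations, p, rng):
    """Zero each activation with probability p; scale survivors by 1 / (1 - p)."""
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


rng = random.Random(0)
activations = [1.0] * 100_000
dropped = inverted_dropout(activations, p=0.5, rng=rng)
mean_after = sum(dropped) / len(dropped)  # stays close to the original mean of 1.0
```

Each surviving activation is doubled and roughly half are zeroed, so the mean is preserved and no rescaling is needed at inference.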
5.2 Batch Normalization
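Normalizes activations across the batch dimension for each feature:
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma \hat{x}_i + \beta$$
where $\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i$ and $\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2$ over the mini-batch of size $m$.
$\gamma$ and $\beta$ are learnable parameters that allow the network to undo the normalization if needed.
During inference: Use running averages of $\mu_B$ and $\sigma_B^2$ accumulated during training (via exponential moving average).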
Benefits:
- Allows higher learning rates
- Reduces sensitivity to initialization
- Provides mild regularization (batch statistics add noise)
Limitation: Requires a batch dimension → problematic for small batches, sequence models, or distributed training.
5.3 Layer Normalization
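Normalizes across the feature dimension for each sample (independent of batch size):
$$\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
where $\mu$ and $\sigma^2$ are the mean and variance over the features of a single sample.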
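The difference between the two normalizations is only the axis over which statistics are computed. The sketch below works on a batch stored as a list of rows (one row per sample) and omits the learnable scale and shift for brevity; all function names are illustrative:

```python
import math


def normalize(values, eps=1e-5):
    """Shift to zero mean and scale to (approximately) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def batch_norm(batch, eps=1e-5):
    """Normalize each feature (column) across the samples in the batch."""
    columns = [normalize(list(col), eps) for col in zip(*batch)]
    return [list(row) for row in zip(*columns)]


def layer_norm(batch, eps=1e-5):
    """Normalize each sample (row) across its own features."""
    return [normalize(row, eps) for row in batch]


batch = [[1.0, 10.0, 100.0],
         [3.0, 30.0, 300.0]]
bn = batch_norm(batch)
ln = layer_norm(batch)
single = batch_norm([[1.0, 2.0, 3.0]])  # batch of one: every feature collapses to 0
```

With a single sample the batch statistics are degenerate, which illustrates why batch normalization struggles with very small batches while layer normalization is unaffected.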
5.4 Weight Decay
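Adds an L2 penalty to the loss:
$$L_{\text{total}} = L + \frac{\lambda}{2} \|\theta\|_2^2$$
This pushes weights toward zero, preventing any single weight from growing too large. Typical values: $\lambda = 10^{-4}$ (CNNs) to $10^{-2}$ (transformers).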
With AdamW, weight decay is applied directly to parameters without going through the adaptive learning rate, making it more effective than L2 regularization with Adam.
6. Gradient Clipping
Exploding gradients cause numerical instability. Gradient clipping bounds the gradient norm.
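Norm Clipping (recommended)
$$g \leftarrow g \cdot \min\left( 1, \frac{c}{\|g\|_2} \right)$$
where $c$ is the maximum allowed norm. This preserves the gradient direction while limiting magnitude.
Value Clipping
$$g_i \leftarrow \max(-c, \min(c, g_i))$$
Clips each gradient component independently. This changes the gradient direction and is less preferred.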
# Norm clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Value clipping
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
When to use gradient clipping:
- Training RNNs/LSTMs (almost always needed)
- Training transformers (especially with large learning rates)
- Large batch training where gradient norms can spike
- Any situation with loss divergence or NaN losses
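The difference between the two clipping modes can be seen on a two-component gradient. This is a plain-Python sketch of the idea, not the implementation behind `clip_grad_norm_` or `clip_grad_value_`; the function names are illustrative:

```python
import math


def clip_by_norm(grad, max_norm):
    """Rescale the whole vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def clip_by_value(grad, clip_value):
    """Clamp each component independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grad]


grad = [3.0, 4.0]                               # L2 norm is 5
by_norm = clip_by_norm(grad, max_norm=1.0)      # same 3:4 direction, norm 1
by_value = clip_by_value(grad, clip_value=0.5)  # both components clamped to 0.5
```

Norm clipping keeps the 3:4 ratio between components, while value clipping flattens both to the same value and so points in a different direction.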
7. Mixed Precision Training
Uses 16-bit floating point (FP16) for most computations while keeping a 32-bit (FP32) master copy of weights.
Why Mixed Precision?
| | FP32 | FP16 | Benefit |
|---|---|---|---|
| Memory per value | 4 bytes | 2 bytes | 2× less memory |
| Compute (A100) | 19.5 TFLOPS | 312 TFLOPS | ~16× (with Tensor Cores) |
| Bandwidth | 2 TB/s | 2 TB/s | Same (but less data) |
The Problem: Loss Scaling
FP16 has a much smaller range (maximum value $\approx 65{,}504$) and precision ($\approx 3$ decimal digits) than FP32. Small gradients can underflow to zero. Solution: loss scaling, which multiplies the loss by a large factor (e.g., 1024), computes gradients in this scaled space, then unscales them before the update.

    scaler = torch.cuda.amp.GradScaler()

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        predictions = model(batch_x)
        loss = criterion(predictions, batch_y)
    scaler.scale(loss).backward()   # backward pass on the scaled loss
    scaler.unscale_(optimizer)      # unscale gradients in place before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)          # skips the step if gradients contain inf/nan
    scaler.update()                 # adjust scale factor

Dynamic loss scaling: The GradScaler starts with a large scale factor and halves it whenever inf or nan gradients are detected, then increases it slowly when training is stable.
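The underflow problem and the effect of scaling can be reproduced with the standard library alone, since Python's `struct` module supports the IEEE 754 half-precision format (`"e"`). The helper `to_fp16` and the gradient value `1e-8` are illustrative assumptions; the scale of 1024 is the example factor used above:

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8      # a small gradient, below the smallest positive FP16 value
scale = 1024.0   # loss scale factor

unscaled = to_fp16(grad)          # underflows to zero in FP16
scaled = to_fp16(grad * scale)    # large enough to survive in FP16
recovered = scaled / scale        # unscale in higher precision
```

The unscaled gradient is lost entirely, while the scaled one survives the FP16 round trip and is recovered to within about 1% after unscaling.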
BFloat16 Alternative
BFloat16 uses 8 exponent bits (same range as FP32) and 7 mantissa bits. No loss scaling needed, but slightly less precise than FP16. Preferred on Ampere+ GPUs.
8. Distributed Training
Data Parallelism
The most common strategy: replicate the model on $N$ GPUs, split the batch across them, and average gradients:
$$g = \frac{1}{N} \sum_{k=1}^{N} g_k$$
# PyTorch DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
Gradient synchronization: All-reduce communicates gradients across GPUs, using the NCCL (NVIDIA GPUs) or Gloo (CPU) backend. DDP overlaps computation and communication: it starts reducing gradients for layer $l+1$ while the backward pass is still computing layer $l$.
Model Parallelism
When the model is too large to fit on one GPU:
- Pipeline parallelism: Split model layers across GPUs, micro-batch the pipeline
- Tensor parallelism: Split individual operations (e.g., attention heads) across GPUs
- ZeRO (DeepSpeed): Shard optimizer states, gradients, and parameters across GPUs
Training at Scale
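Total batch size = num_GPUs × per_gpu_batch_size × gradient_accumulation_steps
Example: 8 GPUs × 32 samples × 4 accumulations = 1024 effective batch size
Large batch training requires adjusting the learning rate (linear scaling rule) and using warmup:
$$\eta = \eta_{\text{base}} \cdot \frac{B}{B_{\text{base}}}$$
where $B$ is the effective batch size and $\eta_{\text{base}}$ is the learning rate tuned at batch size $B_{\text{base}}$.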
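The batch-size arithmetic and the linear scaling rule can be written as two small helpers. The function names are illustrative, and the base setting (learning rate 0.1 at batch size 256) is an assumption taken from the CNN defaults listed later in this document:

```python
def effective_batch_size(num_gpus, per_gpu_batch_size, grad_accum_steps=1):
    """Samples contributing to each optimizer step."""
    return num_gpus * per_gpu_batch_size * grad_accum_steps


def linearly_scaled_lr(base_lr, base_batch_size, batch_size):
    """Linear scaling rule: learning rate grows in proportion to batch size."""
    return base_lr * batch_size / base_batch_size


batch = effective_batch_size(num_gpus=8, per_gpu_batch_size=32, grad_accum_steps=4)
lr = linearly_scaled_lr(base_lr=0.1, base_batch_size=256, batch_size=batch)
```

Here the effective batch size is 1024, four times the base, so the rule suggests a learning rate four times larger. This is a starting point to be combined with warmup, not a guarantee of stable training.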
Putting It All Together: A Modern Training Recipe

    import math
    import torch

    device = torch.device("cuda")  # assumes a CUDA-capable GPU is available

    # 1. Model
    model = YourModel().to(device)

    # 2. Optimizer (AdamW for transformers, SGD for CNNs)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

    # 3. Schedule (Warmup + Cosine)
    warmup_steps = 1000
    total_steps = 100000

    def lr_lambda(step):
        # Linear warmup from 0 to the base lr, then cosine decay to 0
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # 4. Loss
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

    # 5. Mixed precision
    scaler = torch.cuda.amp.GradScaler()

    # 6. Training loop
    # train_loader is assumed to be an iterator that keeps yielding (x, y) batches
    model.train()
    for step in range(total_steps):
        batch_x, batch_y = next(train_loader)
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        with torch.cuda.amp.autocast():
            loss = criterion(model(batch_x), batch_y)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()

Hyperparameter Defaults (IIT/MIT Research Standards)
| Hyperparameter | Transformer | CNN (ResNet) |
|---|---|---|
| Optimizer | AdamW | SGD + Momentum |
| Learning rate | 3e-4 | 0.1 |
| Weight decay | 0.01 | 1e-4 |
| Batch size | 256-2048 | 256 |
| Warmup steps | 2000-4000 | - |
| Schedule | Cosine | Cosine |
| Gradient clip | 1.0 | None |
| Dropout | 0.1 | 0.2 |
| Label smoothing | 0.1 | - |
Summary
| Concept | Key Takeaway |
|---|---|
| Training loop | Forward → loss → zero_grad → backward → step |
| MSE | Regression; penalizes large errors quadratically |
| Cross-entropy | Classification; combined with log-softmax |
| Focal loss | Handles class imbalance via $(1 - p_t)^{\gamma}$ weighting |
| Contrastive/triplet loss | Learn embedding spaces |
| SGD + Momentum | Best for CNNs; fast convergence with proper schedule |
| Adam/AdamW | Best for transformers; adaptive per-parameter lr |
| Cosine annealing | Smooth decay; default schedule in modern training |
| Warmup | Essential for transformers; stabilizes early training |
| Dropout | Ensembles sub-networks; scale by $\frac{1}{1-p}$ during training (inverted dropout), no scaling at test time |
| BatchNorm | Normalize across batch; use in CNNs |
| LayerNorm | Normalize across features; use in transformers |
| Gradient clipping | Clip norm to prevent exploding gradients |
| Mixed precision | FP16/BF16 + loss scaling for 2-4× speedup |
| Distributed training | DDP for data parallelism; ZeRO for model parallelism |