LLM Fine-Tuning at Scale: Distributed Training, DeepSpeed, and FSDP
Fine-tuning large language models requires distributed training strategies to handle models that exceed single-GPU memory. Production systems must balance throughput, memory efficiency, and training speed.
Distributed Training Architecture
Distributed Training Strategies
1. DeepSpeed ZeRO Configuration
import json
from dataclasses import dataclass
from typing import Optional
@dataclass
class DeepSpeedConfig:
    """Container for the sections of a DeepSpeed JSON configuration file.

    The four dict fields are written to the JSON file verbatim; the scalar
    fields carry the defaults shown below.  Use :meth:`zeRO_stage_2` or
    :meth:`zeRO_stage_3` for ready-made presets and :meth:`to_json` to write
    the configuration to disk.
    """

    # Contents of the "zero_optimization" section (stage, offload, buckets).
    zero_optimization: dict
    # Contents of the "fp16" section.
    fp16: dict
    # Contents of the "optimizer" section: {"type": ..., "params": {...}}.
    optimizer: dict
    # Contents of the "scheduler" section: {"type": ..., "params": {...}}.
    scheduler: dict
    gradient_clipping: float = 1.0
    train_batch_size: int = 32
    train_micro_batch_size_per_gpu: int = 4
    steps_per_print: int = 100

    @staticmethod
    def _adamw_with_warmup(learning_rate: float) -> tuple:
        """Return the ``(optimizer, scheduler)`` sections shared by the presets.

        The optimizer is AdamW at ``learning_rate``; the scheduler is WarmupLR
        ramping from 0 up to the same ``learning_rate``.
        """
        optimizer = {"type": "AdamW", "params": {"lr": learning_rate}}
        scheduler = {
            "type": "WarmupLR",
            "params": {"warmup_min_lr": 0, "warmup_max_lr": learning_rate},
        }
        return optimizer, scheduler

    @classmethod
    def zeRO_stage_2(cls, learning_rate: float = 5e-5) -> "DeepSpeedConfig":
        """Build a ZeRO stage-2 preset with optimizer state offloaded to CPU.

        Args:
            learning_rate: Used as the AdamW ``lr`` and as the WarmupLR
                ``warmup_max_lr``.

        Returns:
            A new config; every other field keeps its dataclass default.
        """
        optimizer, scheduler = cls._adamw_with_warmup(learning_rate)
        return cls(
            zero_optimization={
                "stage": 2,
                "offload_optimizer": {"device": "cpu", "pin_memory": True},
                "allgather_partitions": True,
                "allgather_bucket_size": 5e8,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "contiguous_gradients": True,
            },
            fp16={"enabled": True, "loss_scale": 0, "loss_scale_window": 1000},
            optimizer=optimizer,
            scheduler=scheduler,
        )

    @classmethod
    def zeRO_stage_3(cls, learning_rate: float = 5e-5) -> "DeepSpeedConfig":
        """Build a ZeRO stage-3 preset with optimizer and parameter CPU offload.

        Args:
            learning_rate: Used as the AdamW ``lr`` and as the WarmupLR
                ``warmup_max_lr``.

        Returns:
            A new config; every other field keeps its dataclass default.
        """
        optimizer, scheduler = cls._adamw_with_warmup(learning_rate)
        return cls(
            zero_optimization={
                "stage": 3,
                "offload_optimizer": {"device": "cpu", "pin_memory": True},
                "offload_param": {"device": "cpu", "pin_memory": True},
                "overlap_comm": True,
                "stage3_max_live_parameters": 1e9,
                "stage3_max_reuse_distance": 1e9,
                "stage3_prefetch_bucket_size": 5e7,
                "stage3_param_persistence_threshold": 1e6,
                "reduce_bucket_size": 5e8,
                "contiguous_gradients": True,
            },
            # Unlike the stage-2 preset, no "loss_scale_window" key is emitted.
            fp16={"enabled": True, "loss_scale": 0},
            optimizer=optimizer,
            scheduler=scheduler,
        )

    def to_dict(self) -> dict:
        """Return all fields as a plain dict, in field-declaration order.

        The nested dicts are copies, so mutating the result does not affect
        this instance.
        """
        # Local import: the file's import block only brings in ``dataclass``.
        from dataclasses import asdict

        # asdict() walks the declared fields, so a field added to the class
        # later is serialized automatically instead of being forgotten in a
        # hand-maintained key list.
        return asdict(self)

    def to_json(self, filepath: str):
        """Write the configuration to ``filepath`` as 2-space-indented JSON.

        Args:
            filepath: Destination path; an existing file is overwritten.

        Raises:
            OSError: If the file cannot be opened for writing.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)
2. FSDP Wrapper
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
class FSDPTrainer:
    """Wrap a model in FullyShardedDataParallel and estimate per-GPU memory.

    Constructing the trainer immediately wraps ``model`` (see
    :meth:`setup_fsdp`); ``world_size`` is stored and used as the default
    GPU count for :meth:`compute_memory_per_gpu`.
    """

    # Bytes-per-parameter assumptions behind compute_memory_per_gpu().
    _PARAM_BYTES = 2
    _GRAD_BYTES = 2
    _OPTIMIZER_BYTES = 8
    # The "*_gb" results are divided by 1024**3, i.e. they are binary gigabytes.
    _BYTES_PER_GB = 1024 ** 3

    def __init__(self, model: nn.Module, world_size: int = 4):
        """Store the arguments and wrap ``model`` with FSDP right away.

        Args:
            model: Module to shard; replaced on ``self.model`` by its FSDP
                wrapper.
            world_size: Number of GPUs assumed by the memory estimate.
        """
        self.model = model
        self.world_size = world_size
        self.setup_fsdp()

    def setup_fsdp(self):
        """Replace ``self.model`` with an FSDP-wrapped version of itself.

        Uses FULL_SHARD sharding, bfloat16 parameters and buffers with
        float32 gradient reduction, the current CUDA device, and
        ``use_orig_params=True``.

        NOTE(review): presumably requires an initialized torch.distributed
        process group and an available CUDA device -- confirm against the
        launcher that constructs this class.
        """
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        )
        self.model = FSDP(
            self.model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mp_policy,
            device_id=torch.cuda.current_device(),
            use_orig_params=True,
        )

    def compute_memory_per_gpu(
        self, model_params: int, world_size: Optional[int] = None
    ) -> dict:
        """Estimate the per-GPU memory for parameters, gradients and optimizer.

        The estimate assumes 2 bytes/param for parameters, 2 bytes/param for
        gradients and 8 bytes/param for optimizer state, each divided evenly
        across ``world_size`` GPUs.  Activations and other buffers are not
        included.

        Args:
            model_params: Total number of model parameters (>= 0).
            world_size: Number of GPUs to shard across; defaults to the
                ``world_size`` given to the constructor.

        Returns:
            Dict with keys ``params_gb``, ``grads_gb``, ``optimizer_gb`` and
            ``total_gb``; values are bytes divided by ``1024**3``.

        Raises:
            ValueError: If ``world_size`` is not positive or ``model_params``
                is negative.
        """
        if world_size is None:
            world_size = self.world_size
        # Fail with a clear message instead of a bare ZeroDivisionError or a
        # silently negative memory figure.
        if world_size <= 0:
            raise ValueError(f"world_size must be positive, got {world_size}")
        if model_params < 0:
            raise ValueError(
                f"model_params must be non-negative, got {model_params}"
            )

        param_memory = model_params * self._PARAM_BYTES / world_size
        grad_memory = model_params * self._GRAD_BYTES / world_size
        optimizer_memory = model_params * self._OPTIMIZER_BYTES / world_size
        total = param_memory + grad_memory + optimizer_memory
        return {
            "params_gb": param_memory / self._BYTES_PER_GB,
            "grads_gb": grad_memory / self._BYTES_PER_GB,
            "optimizer_gb": optimizer_memory / self._BYTES_PER_GB,
            "total_gb": total / self._BYTES_PER_GB,
        }
Key Formulas
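Memory Requirement per GPU
$$M_{\text{GPU}} = \frac{M_{\text{params}} + M_{\text{grads}} + M_{\text{opt}}}{N}$$
Here,
- $M_{\text{params}}$ = Model parameter memory
- $M_{\text{grads}}$ = Gradient memory
- $M_{\text{opt}}$ = Optimizer state memory (Adam: 8 bytes/param)
- $N$ = Number of GPUs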
Communication Cost (AllReduce)
$$T_{\text{comm}} = \frac{2(N-1)}{N} \cdot \frac{M}{B}$$
Here,
- $N$ = Number of GPUs
- $M$ = Gradient message size
- $B$ = Inter-GPU bandwidth
ZeRO Stage Comparison
| Stage | Optimizer States | Gradients | Parameters | Memory Savings |
|---|---|---|---|---|
| Stage 0 | Replicated | Replicated | Replicated | 1x |
| Stage 1 | Partitioned | Replicated | Replicated | ~4x |
| Stage 2 | Partitioned | Partitioned | Replicated | ~8x |
| Stage 3 | Partitioned | Partitioned | Partitioned | ~Nx (N = number of GPUs) |
Best Practices
- Use mixed precision (bf16/fp16) to reduce memory and increase throughput
- Start with ZeRO Stage 2 for most fine-tuning tasks
- Gradient checkpointing trades compute for memory at ~30% speed loss
- Monitor GPU utilization with nvidia-smi during training
- Save checkpoints frequently since distributed training can fail at any rank