LLM Fine-Tuning at Scale: Distributed Training, DeepSpeed, and FSDP

Fine-tuning large language models requires distributed training once the model, its gradients, and its optimizer states no longer fit in a single GPU's memory. Production systems must balance memory efficiency, throughput, and communication overhead.

Distributed Training Architecture

Distributed Training Strategies

1. DeepSpeed ZeRO Configuration

import json
from dataclasses import asdict, dataclass

@dataclass
class DeepSpeedConfig:
    # Field names mirror top-level DeepSpeed config keys, so asdict() yields the JSON config
    zero_optimization: dict
    fp16: dict
    optimizer: dict
    scheduler: dict
    gradient_clipping: float = 1.0
    # DeepSpeed requires train_batch_size ==
    # train_micro_batch_size_per_gpu * gradient_accumulation_steps * number of GPUs
    train_batch_size: int = 32
    train_micro_batch_size_per_gpu: int = 4
    steps_per_print: int = 100

    @classmethod
    def zero_stage_2(cls, learning_rate: float = 5e-5) -> "DeepSpeedConfig":
        # Stage 2: partition optimizer states and gradients; offload optimizer states to CPU
        return cls(
            zero_optimization={
                "stage": 2,
                "offload_optimizer": {"device": "cpu", "pin_memory": True},
                "allgather_partitions": True,
                "allgather_bucket_size": 5e8,
                "overlap_comm": True,  # overlap gradient reduction with the backward pass
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "contiguous_gradients": True
            },
            # loss_scale=0 selects dynamic loss scaling
            fp16={"enabled": True, "loss_scale": 0, "loss_scale_window": 1000},
            optimizer={"type": "AdamW", "params": {"lr": learning_rate}},
            scheduler={"type": "WarmupLR", "params": {"warmup_min_lr": 0, "warmup_max_lr": learning_rate}}
        )

    @classmethod
    def zero_stage_3(cls, learning_rate: float = 5e-5) -> "DeepSpeedConfig":
        # Stage 3: also partition parameters; offload optimizer states and parameters to CPU
        return cls(
            zero_optimization={
                "stage": 3,
                "offload_optimizer": {"device": "cpu", "pin_memory": True},
                "offload_param": {"device": "cpu", "pin_memory": True},
                "overlap_comm": True,
                "stage3_max_live_parameters": 1e9,
                "stage3_max_reuse_distance": 1e9,
                "stage3_prefetch_bucket_size": 5e7,
                "stage3_param_persistence_threshold": 1e6,  # small params stay unpartitioned
                "reduce_bucket_size": 5e8,
                "contiguous_gradients": True
            },
            fp16={"enabled": True, "loss_scale": 0},
            optimizer={"type": "AdamW", "params": {"lr": learning_rate}},
            scheduler={"type": "WarmupLR", "params": {"warmup_min_lr": 0, "warmup_max_lr": learning_rate}}
        )

    def to_json(self, filepath: str) -> None:
        # Write the JSON file that DeepSpeed reads as its config
        with open(filepath, "w") as f:
            json.dump(asdict(self), f, indent=2)
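DeepSpeed validates at startup that `train_batch_size` equals `train_micro_batch_size_per_gpu` times the gradient accumulation steps times the number of GPUs, so the defaults above only work for some cluster sizes. The helper below is a hypothetical pre-flight check (not part of DeepSpeed's API), sketched under the assumption of pure data parallelism, so that every GPU is a data-parallel rank.

```python
def gradient_accumulation_steps(train_batch_size: int, micro_batch_per_gpu: int, num_gpus: int) -> int:
    """Return the accumulation steps implied by DeepSpeed's batch-size rule.

    Hypothetical helper: raises ValueError when train_batch_size is not an
    exact multiple of micro_batch_per_gpu * num_gpus.
    """
    samples_per_micro_step = micro_batch_per_gpu * num_gpus  # across all GPUs
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_batch_per_gpu * num_gpus = {samples_per_micro_step}"
        )
    return train_batch_size // samples_per_micro_step


# Defaults from DeepSpeedConfig: train_batch_size=32, train_micro_batch_size_per_gpu=4
for gpus in (1, 2, 4, 8):
    print(gpus, "GPUs ->", gradient_accumulation_steps(32, 4, gpus), "accumulation steps")
```

With the defaults, 1, 2, 4, and 8 GPUs give 8, 4, 2, and 1 accumulation steps. On 16 GPUs the check raises, because 4 × 16 = 64 samples per micro-step already exceeds the global batch of 32.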

2. FSDP Wrapper

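import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

class FSDPTrainer:
    def __init__(self, model: nn.Module, world_size: int = 4):
        self.model = model
        self.world_size = world_size
        self.setup_fsdp()

    def setup_fsdp(self):
        # FSDP requires an initialized process group; under torchrun, RANK,
        # WORLD_SIZE, and LOCAL_RANK are set for each process
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        # Bind this process to its own GPU before wrapping
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

        # Forward/backward in bf16; gradients reduced in fp32 for stability
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16
        )
        # Wrap large submodules as separate FSDP units so only one unit's full
        # parameters are gathered at a time; with no wrap policy the whole model
        # is a single unit and is fully materialized on every GPU.
        # For transformers, transformer_auto_wrap_policy keyed on the layer class is common.
        wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e8))
        self.model = FSDP(
            self.model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mp_policy,
            auto_wrap_policy=wrap_policy,
            device_id=torch.cuda.current_device(),
            use_orig_params=True
        )

    def compute_memory_per_gpu(self, model_params: int, world_size: int) -> dict:
        # Sharded model state only; excludes activations and the all-gathered
        # bf16 copy of the FSDP unit currently being computed.
        # FSDP keeps sharded params in full precision; MixedPrecision only casts for compute.
        param_memory = model_params * 4 / world_size      # fp32 sharded parameters
        grad_memory = model_params * 4 / world_size       # fp32 sharded gradients (reduce_dtype)
        optimizer_memory = model_params * 8 / world_size  # Adam first and second moments in fp32
        total = param_memory + grad_memory + optimizer_memory
        gib = 1024**3  # values below are in GiB
        return {
            "params_gb": param_memory / gib,
            "grads_gb": grad_memory / gib,
            "optimizer_gb": optimizer_memory / gib,
            "total_gb": total / gib
        }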

Key Formulas

Memory Requirement per GPU

$$M_{GPU} = \frac{M_{params} + M_{grads} + M_{optim}}{N_{GPUs}}$$

Here,

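  • $M_{params}$ = Model parameter memory
  • $M_{grads}$ = Gradient memory
  • $M_{optim}$ = Optimizer state memory (Adam: 8 bytes/param for two fp32 moments, plus 4 bytes/param for fp32 master weights in mixed-precision training)
  • $N_{GPUs}$ = Number of GPUs

The formula assumes parameters, gradients, and optimizer states are all sharded (ZeRO Stage 3 or FSDP FULL_SHARD) and excludes activation memory.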
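A worked instance of the memory formula, as a minimal sketch: it assumes 16 bytes of model state per parameter for mixed-precision Adam (16-bit weights and gradients at 2 bytes each, plus 12 bytes of fp32 master weights and Adam moments), full sharding across all GPUs, and no activation memory. The 7B-parameter model and the GPU counts are illustrative inputs, not measurements.

```python
BYTES_PARAMS = 2   # bf16/fp16 weights
BYTES_GRADS = 2    # bf16/fp16 gradients
BYTES_OPTIM = 12   # fp32 master weights (4) + Adam first moment (4) + second moment (4)


def memory_per_gpu_gb(num_params: float, num_gpus: int) -> float:
    """M_GPU = (M_params + M_grads + M_optim) / N_GPUs, in GB (1e9 bytes).

    Assumes everything is fully sharded and ignores activations and
    temporary all-gather buffers.
    """
    total_bytes = num_params * (BYTES_PARAMS + BYTES_GRADS + BYTES_OPTIM)
    return total_bytes / num_gpus / 1e9


print(memory_per_gpu_gb(7e9, 1))  # 112.0
print(memory_per_gpu_gb(7e9, 8))  # 14.0
```

The unsharded 112 GB exceeds a single 80 GB GPU, while eight-way sharding leaves 14 GB of model state per GPU, with the remainder available for activations and communication buffers.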

Communication Cost (AllReduce)

$$T_{comm} = \frac{2(N-1)}{N} \cdot \frac{M_{grads}}{B_{bandwidth}}$$

Here,

  • $N$ = Number of GPUs
  • $M_{grads}$ = Gradient message size (bytes)
  • $B_{bandwidth}$ = Inter-GPU bandwidth (bytes/s)
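The 2(N-1)/N factor comes from ring all-reduce, which runs a reduce-scatter followed by an all-gather, each moving (N-1)/N of the buffer per GPU. The sketch below evaluates only this bandwidth term, ignoring per-message latency and overlap with the backward pass; the 14 GB gradient size (7B parameters at 2 bytes) and the 100 GB/s effective bandwidth are assumed example values.

```python
def allreduce_time_s(grad_bytes: float, num_gpus: int, bandwidth_bytes_per_s: float) -> float:
    """Bandwidth term of ring all-reduce: T = 2(N-1)/N * M / B, in seconds."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    # Reduce-scatter and all-gather each move (N-1)/N of the buffer per GPU
    return 2 * (num_gpus - 1) / num_gpus * grad_bytes / bandwidth_bytes_per_s


grad_bytes = 7e9 * 2  # 7B parameters with 2-byte (bf16) gradients
for n in (2, 8, 64):
    print(n, "GPUs:", round(allreduce_time_s(grad_bytes, n, 100e9), 3), "s")
```

The estimate grows from 0.14 s on 2 GPUs to about 0.245 s on 8 and 0.276 s on 64: per-GPU traffic approaches twice the gradient size rather than growing with N, which is why ring all-reduce scales well in bandwidth terms.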

ZeRO Stage Comparison

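| Stage | Optimizer States | Gradients | Parameters | Memory Savings vs. Stage 0 (large N) |
|---|---|---|---|---|
| Stage 0 | Replicated | Replicated | Replicated | 1x |
| Stage 1 | Partitioned | Replicated | Replicated | ~4x |
| Stage 2 | Partitioned | Partitioned | Replicated | ~8x |
| Stage 3 | Partitioned | Partitioned | Partitioned | ~Nx |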
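The savings column can be derived from the per-GPU model-state formulas in the ZeRO paper, with Ψ parameters, N GPUs, and K = 12 bytes of optimizer state per parameter for mixed-precision Adam. The sketch below encodes those formulas (activations and buffers excluded) and compares each stage against Stage 0 for an example 7.5B-parameter model on 64 GPUs; the model size and GPU count are illustrative.

```python
K = 12  # optimizer bytes/param for mixed-precision Adam (fp32 master + 2 moments)


def zero_memory_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory in GB (1e9 bytes) for ZeRO stages 0-3."""
    p, n = num_params, num_gpus
    if stage == 0:    # everything replicated
        total = (2 + 2 + K) * p
    elif stage == 1:  # optimizer states partitioned
        total = (2 + 2) * p + K * p / n
    elif stage == 2:  # optimizer states and gradients partitioned
        total = 2 * p + (2 + K) * p / n
    elif stage == 3:  # optimizer states, gradients, and parameters partitioned
        total = (2 + 2 + K) * p / n
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return total / 1e9


baseline = zero_memory_gb(7.5e9, 64, 0)
for stage in range(4):
    gb = zero_memory_gb(7.5e9, 64, stage)
    print(f"Stage {stage}: {gb:.1f} GB per GPU, {baseline / gb:.1f}x savings")
```

This prints about 120.0, 31.4, 16.6, and 1.9 GB per GPU, with savings of 1.0x, 3.8x, 7.2x, and 64.0x. Stages 1 and 2 approach the table's ~4x and ~8x only as N grows, while Stage 3 savings scale linearly with N.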

Best Practices

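  1. Use mixed precision (bf16/fp16) to reduce memory and increase throughput; prefer bf16 on hardware that supports it, since it usually does not need loss scaling
  2. Start with ZeRO Stage 2 for most fine-tuning tasks
  3. Use gradient checkpointing to trade compute for memory: activations are recomputed during the backward pass, which typically slows training by roughly 30%
  4. Monitor GPU utilization with nvidia-smi during training
  5. Save checkpoints frequently, since a failure on any single rank stops the whole job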