Chapter 7 — Model Training Infrastructure and Distributed Training
Learning Objectives
Describe the GPU memory hierarchy (registers, HBM, host RAM, NVMe) and explain why intra-node NVLink versus inter-node InfiniBand bandwidth shapes every distributed training decision.
Compare data, tensor, and pipeline parallelism and explain how ZeRO/FSDP sharding combines with them to form 3D parallelism for frontier LLM training.
Choose between PyTorch DDP, FSDP, Horovod, DeepSpeed, and Ray Train based on model size, framework constraints, and orchestration needs.
Apply cost levers including BF16/FP16 mixed precision, gradient checkpointing, gradient accumulation, and elastic torchrun on spot instances to slash training cost without harming convergence.
Use PyTorch Profiler, Nsight Systems, and DCGM to identify whether a job is compute-bound, memory-bound, or communication-bound.
Section 1: The Training Compute Landscape
Pre-Reading Check — Training Compute
1. Which fact about GPU bandwidth is the single most important constraint shaping distributed training topology?
Host RAM is always slower than NVMe storage on modern servers.
Intra-node NVLink bandwidth dramatically exceeds inter-node InfiniBand or Ethernet bandwidth.
All GPU collectives must traverse the host CPU before reaching another GPU.
PCIe bandwidth between two GPUs in the same chassis is higher than NVLink.
2. Why are TPUs particularly well suited to large batched transformer workloads?
They expose a CUDA-compatible API for direct PyTorch use without compilation.
They are essentially giant systolic arrays optimized for dense linear algebra, paired with XLA and JAX.
They use higher-precision FP64 arithmetic that improves transformer convergence.
They share host memory directly with the CPU through unified memory.
3. Which workload most clearly does NOT need a GPU?
Pretraining a 13B-parameter transformer on 1T tokens.
A tabular XGBoost model on tens of millions of rows.
Fine-tuning a vision transformer at high resolution.
Training a multi-task BERT model with 24 layers.
Training a modern deep learning model is, in many ways, less about clever algorithms and more about choreographing thousands of arithmetic units, gigabytes of memory, and miles of high-speed cabling. The model is the recipe; the infrastructure is the kitchen.
GPUs and the Memory Hierarchy
NVIDIA's V100, A100, and H100 GPUs dominate production training. The A100 introduced hardware-accelerated BF16; the H100 added FP8 support and roughly 3x the matrix-math throughput per watt for transformer workloads. GPU memory has a steep hierarchy: registers and on-chip shared memory/L1 (hundreds of KB per streaming multiprocessor, lowest latency), L2 cache (tens of MB on A100/H100), HBM (16–80 GB across V100 through H100, with roughly 1–3 TB/s of bandwidth), then host RAM over PCIe, and finally NVMe storage. Each level further from the compute units is larger but slower, so every byte that has to move between levels costs time.
Inside a node, GPUs talk over NVLink and NVSwitch at hundreds of GB/s per GPU. Across nodes, clusters use InfiniBand or Ethernet at 100–400 Gb/s per link; note that these are gigabits, so a 200 Gb/s link moves only 25 GB/s. This roughly order-of-magnitude gap between intra-node and inter-node bandwidth is the single most important fact about distributed training topology.
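To make the units concrete, here is a small back-of-the-envelope sketch. The payload (BF16 gradients of a hypothetical 7B-parameter model) and the 300 GB/s per-direction NVLink figure are illustrative assumptions, not measurements; the one fixed fact is the factor of 8 between gigabits and gigabytes.

```python
# Back-of-the-envelope transfer times; bandwidth figures are illustrative assumptions.


def gbit_to_gbyte_per_s(gbit_per_s: float) -> float:
    """Convert a link speed quoted in Gb/s (gigabits) to GB/s (gigabytes)."""
    return gbit_per_s / 8.0


def transfer_seconds(num_bytes: float, gbyte_per_s: float) -> float:
    """Time to push num_bytes through a link at the given GB/s (latency ignored)."""
    return num_bytes / (gbyte_per_s * 1e9)


# Hypothetical payload: BF16 gradients of a 7B-parameter model (2 bytes per value).
grad_bytes = 7e9 * 2

nvlink_gbyte_per_s = 300.0                    # assumed per-direction NVLink bandwidth per GPU
ib_gbyte_per_s = gbit_to_gbyte_per_s(200.0)   # one 200 Gb/s InfiniBand link

t_nvlink = transfer_seconds(grad_bytes, nvlink_gbyte_per_s)
t_ib = transfer_seconds(grad_bytes, ib_gbyte_per_s)

if __name__ == "__main__":
    print(f"NVLink: {t_nvlink:.3f} s, InfiniBand: {t_ib:.3f} s, ratio {t_ib / t_nvlink:.1f}x")
```

Under these assumptions the same gradient payload takes about twelve times longer to cross the InfiniBand link, which is why communication-heavy parallelism is kept inside the NVLink domain.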
TPUs, Other Accelerators, and When CPUs Still Win
Google TPUs are built around 2D systolic arrays of multiply-accumulate units optimized for batched matmuls and BF16; they pair best with XLA and JAX. Other accelerators (AWS Trainium, Cerebras, Graphcore, Habana) have niches, but NVIDIA GPUs and TPUs dominate large-scale training. CPUs remain the better choice for tabular gradient-boosted models (XGBoost, LightGBM), classical scikit-learn pipelines, lightweight NLP fine-tunes, and feature engineering: when a workload is dominated by data movement and branching rather than dense matrix math, a GPU adds little and a CPU is fine.
Cluster Managers
Three families dominate ML training clusters:
Kubernetes (with Kubeflow Training Operator, Volcano, or KubeRay) — great for hybrid workloads sharing infrastructure.
Slurm — the HPC workload manager, excellent at gang scheduling and MPI; common in academic and pretraining clusters.
Ray — a Python-first distributed framework with Ray Train, Ray Tune, and Ray Serve.
Large organizations often layer these: Slurm/K8s at the bottom managing GPU allocations, Ray or a Kubeflow PyTorchJob on top to launch the training script.
Figure 7.1: GPU memory hierarchy and cluster wiring
GPUs are throughput machines; A100 enables hardware BF16, H100 adds FP8 and roughly 3x matmul throughput per watt.
The GPU memory hierarchy (registers → L2 → HBM → host RAM → NVMe) dictates that minimizing data movement is the highest-leverage optimization.
Intra-node NVLink bandwidth massively exceeds inter-node InfiniBand, which is the basis of the rule "keep tensor parallelism within a node."
TPUs are systolic arrays optimized for batched matmuls and BF16, ideal with XLA/JAX but require giving up the CUDA ecosystem.
CPUs still win for tabular models, small NLP, and orchestration; Kubernetes, Slurm, and Ray are the three dominant cluster managers and are commonly layered.
Section 2: Distributed Training Strategies
Pre-Reading Check — Parallelism
1. In standard data parallelism, what is communicated between GPUs each step?
Activations at every layer boundary.
Gradients, averaged across replicas via an all-reduce.
Parameters, gathered each forward pass via all-gather.
Random number generator state.
2. What does ZeRO Stage 3 (FSDP) shard across data-parallel ranks that ZeRO Stage 1 does not?
Only optimizer states.
Only gradients.
Optimizer states, gradients, and parameters.
Activation tensors only.
3. Why is tensor parallelism typically kept WITHIN a single node?
PyTorch only supports tensor parallelism on a single node.
Tensor parallelism triggers fine-grained collectives per layer, which need NVLink-level bandwidth.
Inter-node InfiniBand is faster than intra-node NVLink.
Tensor parallelism cannot use NCCL across nodes.
Data Parallelism
Every GPU holds a complete model replica. The global batch is split into per-GPU mini-batches, each replica runs forward and backward on its own shard, and the gradients are then all-reduced (averaged) across ranks so every replica applies the same optimizer step. Efficient implementations overlap gradient communication with the backward pass, so most of the sync is hidden. Data parallelism (DP) scales near-linearly to dozens of GPUs as long as the full model (parameters + optimizer state + activations) fits on one device.
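A minimal sketch of why averaging per-replica gradients is correct, using a toy linear model in plain Python rather than PyTorch. It assumes equal-sized shards (PyTorch's DistributedSampler pads the dataset to make this true when drop_last=False); with unequal shards the plain mean would need per-shard weights. The function names are hypothetical.

```python
from typing import List, Tuple


def mse_grad(w: float, b: float, xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Gradient of mean((w*x + b - y)^2) with respect to (w, b)."""
    n = len(xs)
    gw = sum(2.0 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2.0 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return gw, gb


def data_parallel_grad(w: float, b: float, xs: List[float], ys: List[float],
                       world_size: int) -> Tuple[float, float]:
    """Each simulated rank computes a gradient on its shard; then the results are averaged."""
    if len(xs) % world_size:
        raise ValueError("this sketch assumes equal-sized shards")
    shard = len(xs) // world_size
    local = [mse_grad(w, b, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
             for r in range(world_size)]
    # Stand-in for an all-reduce(SUM) followed by division by the world size.
    avg_gw = sum(g[0] for g in local) / world_size
    avg_gb = sum(g[1] for g in local) / world_size
    return avg_gw, avg_gb


xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 + (0.5 if i % 2 else -0.5) for i, x in enumerate(xs)]
full_batch = mse_grad(0.0, 0.0, xs, ys)
four_ranks = data_parallel_grad(0.0, 0.0, xs, ys, world_size=4)
```

Because every replica ends up with the same averaged gradient and starts from the same weights, the replicas stay in sync without ever exchanging parameters.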
Model and Tensor Parallelism
When the model exceeds one GPU's memory, partition it. Layer-wise model parallelism places different layers on different GPUs. Tensor parallelism splits weight matrices inside a layer (e.g., Megatron-LM splits attention heads and MLP intermediate dimensions); the partial outputs are combined with all-gather or all-reduce. Tensor parallelism is communication-heavy and should live within an NVLink island.
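The Megatron-style MLP split can be checked numerically. The plain-Python sketch below (hypothetical helper names, no real communication) splits W1 by columns and W2 by rows across tp simulated ranks. Because ReLU is elementwise, each rank can apply it to its own slice of the hidden activations without communicating, and a single sum of the partial outputs (the all-reduce) reproduces the unsharded result.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in m]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_cols(w: Matrix, parts: int) -> List[Matrix]:
    step = len(w[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]


def split_rows(w: Matrix, parts: int) -> List[Matrix]:
    step = len(w) // parts
    return [w[p * step:(p + 1) * step] for p in range(parts)]


def mlp_reference(x: Matrix, w1: Matrix, w2: Matrix) -> Matrix:
    """Unsharded two-layer MLP: relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)


def mlp_tensor_parallel(x: Matrix, w1: Matrix, w2: Matrix, tp: int) -> Matrix:
    """Simulate tp ranks: W1 split by columns, W2 split by rows, partial outputs summed."""
    assert len(w1[0]) % tp == 0, "hidden size must divide evenly across ranks"
    partials = []
    for w1_shard, w2_shard in zip(split_cols(w1, tp), split_rows(w2, tp)):
        # Each rank sees the full input x but only its slice of the hidden dimension.
        partials.append(matmul(relu(matmul(x, w1_shard)), w2_shard))
    # Stand-in for the one all-reduce(SUM) per MLP block in the forward pass.
    out = partials[0]
    for p in partials[1:]:
        out = add(out, p)
    return out


x = [[1.0, -2.0, 3.0], [0.5, 1.0, -1.0]]
w1 = [[1.0, 0.0, -1.0, 2.0], [2.0, 1.0, 0.0, -1.0], [0.0, -1.0, 1.0, 1.0]]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
```

The design point is that the column split comes first and the row split second: reversing the order would require communicating the hidden activations before the nonlinearity.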
Pipeline Parallelism
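The model is divided into S sequential stages, and each mini-batch is split into M microbatches that flow through the stages assembly-line style. The challenge is the pipeline bubble: idle time while the pipeline fills and drains, which for a GPipe-style schedule is (S-1)/(M+S-1) of each step. The main remedy is M >> S microbatches. The 1F1B (one-forward-one-backward) schedule keeps the same bubble as GPipe but limits in-flight activations to about S microbatches instead of M, freeing memory for more microbatches; interleaved 1F1B, which gives each GPU several smaller stages, shrinks the bubble further.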
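The bubble formula can be sanity-checked with a tiny simulation of the forward fill/drain grid. This sketch assumes equal-cost stages and a backward pass that mirrors the forward, so the forward grid's idle fraction equals that of the whole step; real stages are rarely perfectly balanced. Function names are hypothetical.

```python
def gpipe_bubble_fraction(stages: int, microbatches: int) -> float:
    """Idle fraction of a GPipe-style schedule with equal-cost stages."""
    return (stages - 1) / (microbatches + stages - 1)


def simulate_forward_idle_fraction(stages: int, microbatches: int) -> float:
    """Count busy (stage, time-slot) cells in the forward fill/drain grid."""
    total_slots = microbatches + stages - 1
    busy = 0
    for s in range(stages):
        for t in range(total_slots):
            m = t - s  # microbatch m reaches stage s at time slot m + s
            if 0 <= m < microbatches:
                busy += 1
    return 1 - busy / (stages * total_slots)


for m in (4, 8, 16, 32):
    print(f"S=4, M={m}: bubble {gpipe_bubble_fraction(4, m):.1%}")
```

With four stages, going from 4 to 32 microbatches cuts the idle fraction from over 40% to under 10%, which is why M >> S is the first lever to pull.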
ZeRO Sharding and FSDP
ZeRO shards training state across data-parallel ranks. ZeRO-1 shards optimizer states; ZeRO-2 also shards gradients; ZeRO-3 shards parameters too. PyTorch FSDP, in its default full-sharding mode, is PyTorch's native ZeRO-3-style implementation: an all-gather assembles each wrapped unit's parameters just before its forward (and again before its backward, because the full parameters are freed after use), then a reduce-scatter after backward leaves each rank with only its shard of the averaged gradient. The optimizer step updates only the local shard.
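The ZeRO paper's accounting for mixed-precision Adam gives a quick way to estimate per-GPU model-state memory at each stage. The sketch below assumes 16 bytes per parameter in total (2 for 16-bit weights, 2 for 16-bit gradients, and 12 for FP32 master weights, momentum, and variance) and ignores activations and communication buffers, so treat its numbers as lower bounds. With 7.5B parameters on 64 GPUs it reproduces the paper's worked example (about 120, 31.4, 16.6, and 1.9 GB). The function name is hypothetical.

```python
def model_state_gb_per_gpu(num_params: float, world_size: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state memory in GB under mixed-precision Adam (ZeRO-paper accounting).

    Assumes 2 bytes/param of 16-bit weights, 2 bytes/param of 16-bit gradients, and
    k bytes/param of FP32 optimizer state. Activations and buffers are ignored.
    """
    p, n = num_params, world_size
    if stage == 0:      # plain DDP: everything replicated
        total = (2 + 2 + k) * p
    elif stage == 1:    # optimizer states sharded
        total = 2 * p + 2 * p + k * p / n
    elif stage == 2:    # gradients sharded as well
        total = 2 * p + (2 + k) * p / n
    elif stage == 3:    # parameters sharded too (ZeRO-3 / FSDP full sharding)
        total = (2 + 2 + k) * p / n
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return total / 1e9


for stage in range(4):
    label = "DDP baseline" if stage == 0 else f"ZeRO-{stage}"
    print(f"{label}: {model_state_gb_per_gpu(7.5e9, 64, stage):.2f} GB per GPU")
```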
NCCL Ring All-Reduce
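NCCL is NVIDIA's collective communication library (all-reduce, all-gather, reduce-scatter, broadcast); it chooses between ring, tree, and hierarchical algorithms based on topology. Horovod popularized the bandwidth-optimal ring all-reduce: N GPUs form a logical ring and the tensor is split into N chunks. A scatter-reduce phase circulates chunks while each GPU adds its contribution, then an all-gather phase circulates the reduced chunks so every GPU ends with the full result. Each phase takes N-1 steps, and each GPU sends about 2(N-1)/N times the tensor size in total, almost independent of N; that is what makes the ring bandwidth-optimal, while the 2(N-1) sequential steps make its latency grow with N.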
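The two phases can be simulated in plain Python. This is a sketch of the algorithm's data movement only, with no real NCCL calls, and the function and variable names are hypothetical. It also counts the elements each rank sends, which should come out to 2(N-1)/N times the tensor size.

```python
from typing import List, Tuple


def ring_allreduce(buffers: List[List[float]]) -> Tuple[List[List[float]], List[int]]:
    """Simulate a sum all-reduce over a logical ring of len(buffers) ranks.

    Returns (per-rank results, elements sent by each rank).
    """
    n = len(buffers)
    length = len(buffers[0])
    assert all(len(b) == length for b in buffers), "all ranks need equal-sized tensors"
    # Split every rank's tensor into n contiguous chunks.
    bounds = [(i * length // n, (i + 1) * length // n) for i in range(n)]
    chunks = [[list(buf[lo:hi]) for lo, hi in bounds] for buf in buffers]
    sent = [0] * n

    # Phase 1, reduce-scatter: after n-1 steps rank r holds the full sum of chunk (r+1) % n.
    for step in range(n - 1):
        msgs = []
        for r in range(n):
            c = (r - step) % n
            msgs.append(((r + 1) % n, c, list(chunks[r][c])))
            sent[r] += len(chunks[r][c])
        for dst, c, payload in msgs:  # deliver after all sends, as if simultaneous
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]

    # Phase 2, all-gather: circulate the reduced chunks so every rank gets all of them.
    for step in range(n - 1):
        msgs = []
        for r in range(n):
            c = (r + 1 - step) % n
            msgs.append(((r + 1) % n, c, list(chunks[r][c])))
            sent[r] += len(chunks[r][c])
        for dst, c, payload in msgs:
            chunks[dst][c] = payload

    results = [[x for chunk in rank_chunks for x in chunk] for rank_chunks in chunks]
    return results, sent


# Four hypothetical ranks, each holding an 8-element "gradient" tensor.
grads = [[10.0 * r + i for i in range(8)] for r in range(4)]
reduced, sent = ring_allreduce(grads)
averaged = [[v / len(grads) for v in row] for row in reduced]  # DDP-style mean
```

With 4 ranks and 8 elements, each rank sends 12 elements, which is 2(N-1)/N = 1.5 times its tensor size.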
flowchart TD
Batch[Global Mini-Batch] --> Split{Split Across N GPUs}
Split --> S1[Shard 1]
Split --> S2[Shard 2]
Split --> S3[Shard 3]
Split --> S4[Shard 4]
S1 --> G1[GPU 1: Full Model Replica]
S2 --> G2[GPU 2: Full Model Replica]
S3 --> G3[GPU 3: Full Model Replica]
S4 --> G4[GPU 4: Full Model Replica]
G1 --> AR[NCCL All-Reduce: Average Gradients]
G2 --> AR
G3 --> AR
G4 --> AR
AR --> U1[Identical Optimizer Step on Every Replica]
Figure 7.3: ZeRO sharding stages — progressive partitioning of training state
graph TD
Base[Baseline DDP: Every Rank Holds Full State] --> Z1
subgraph Z1[ZeRO-1]
Z1A[Params: Replicated]
Z1B[Gradients: Replicated]
Z1C[Optimizer States: SHARDED]
end
Z1 --> Z2
subgraph Z2[ZeRO-2]
Z2A[Params: Replicated]
Z2B[Gradients: SHARDED]
Z2C[Optimizer States: SHARDED]
end
Z2 --> Z3
subgraph Z3[ZeRO-3 / FSDP]
Z3A[Params: SHARDED]
Z3B[Gradients: SHARDED]
Z3C[Optimizer States: SHARDED]
end
Z3 --> Off[+ CPU/NVMe Offload: Push State Beyond GPU RAM]
Figure 7.4: NCCL ring all-reduce — bandwidth-optimal aggregation
Animation A1 — Data Parallelism with Ring All-Reduce
Each GPU receives a different batch shard (amber), computes locally on its full model copy, then gradients flow around the ring (blue) so every replica ends with the averaged gradient before the optimizer step.
Animation A2 — ZeRO/FSDP Progressive Sharding
As the ZeRO stage increases, more rows switch to "sharded": first optimizer states (Stage 1), then gradients (Stage 2), then parameters (Stage 3 / FSDP). Per-GPU memory for model state drops toward 1/N of the unsharded total.
Comparing Parallelism Strategies
| Aspect | Data Parallel | Tensor / Model | Pipeline | ZeRO-3 / FSDP |
|---|---|---|---|---|
| Model fits on one GPU? | Required | Not required | Not required | Not required |
| Main memory benefit | None for params (full replica) | Per-GPU param memory shrunk | Per-stage params + activations | All states sharded |
| Communication pattern | All-reduce per step | Collectives inside layers | Activations at stage boundaries | All-gather + reduce-scatter per layer |
| Communication frequency | Once per iteration | Per TP layer | Per microbatch per boundary | Per layer |
| GPU utilization | Typically high | High if layer is large | Limited by bubbles | High with prefetching |
| Complexity | Easiest | High | High | Moderate |
| Best scale axis | Batch size | Layer width | Model depth | Total parameter count |
Frontier LLM training combines the three parallelism axes into 3D parallelism: tensor parallel within a node, pipeline parallel across nodes, and data parallel across replicas, often with ZeRO sharding applied along the data-parallel dimension.
Key Points — Parallelism
Data parallelism replicates the full model and all-reduces gradients each step — simplest, requires model to fit on one GPU.
Tensor parallelism shards weight matrices inside a layer and must be kept within an NVLink island due to per-layer collectives.
Pipeline parallelism scales model depth by streaming microbatches through sequential stages; shrink bubbles with M >> S, use 1F1B to cap activation memory, and use interleaved schedules to cut the bubble further.
ZeRO stages 1/2/3 progressively shard optimizer state, then gradients, then parameters; FSDP is PyTorch's native ZeRO-3.
NCCL ring all-reduce is bandwidth-optimal but latency-bound at large N; modern LLM training combines DP, TP, PP, and ZeRO into 3D parallelism.
Section 3: Frameworks and Tools for Distributed Training
Pre-Reading Check — Frameworks
1. When should you migrate from PyTorch DDP to FSDP?
Whenever you have more than 4 GPUs, regardless of model size.
When the model + optimizer state no longer fits on one GPU and OOM appears.
When you need stronger gradient accumulation than DDP allows.
Only when training across heterogeneous frameworks like TensorFlow.
2. What capability distinguishes DeepSpeed from pure FSDP for the very largest models?
DeepSpeed supports BF16 while FSDP does not.
CPU/NVMe offload of parameters, gradients, and optimizer states beyond GPU RAM, plus integrated pipeline and Megatron-LM tensor parallelism.
DeepSpeed uses MPI instead of NCCL for collectives.
DeepSpeed runs without a launcher.
3. Which best describes Ray Train's role?
A new gradient compression algorithm replacing NCCL.
An orchestration and abstraction layer that wraps DDP/FSDP/DeepSpeed behind a uniform Python API and handles scaling, fault tolerance, and checkpoints.
A drop-in replacement for NCCL on TPUs.
A standalone tensor parallelism implementation.
PyTorch DDP
DDP is the canonical PyTorch data-parallel implementation: each rank holds a full copy of parameters, gradients, and optimizer state; gradients are grouped into buckets and all-reduced via NCCL during backward, so communication overlaps with computation. DDP jobs are launched with torchrun, which spawns one process per GPU. It is the right tool for models that fit comfortably on one GPU, scaling up to roughly 8–16 GPUs.
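DDP's overlap comes from launching an all-reduce as soon as a bucket's gradients are ready. The sketch below is a simplified model of that idea, not DDP's actual code: it groups parameters into buckets in reverse registration order (DDP uses this order as an approximation of backward order), uses the 25 MB default of DDP's bucket_cap_mb, and reports the backward step at which each bucket becomes ready. The closing rule and helper names are assumptions; real DDP differs in details such as a smaller first bucket.

```python
from typing import List


def build_buckets(param_mb: List[float], cap_mb: float = 25.0) -> List[List[int]]:
    """Group parameter indices into buckets, walking parameters in reverse order.

    Simplified policy: close a bucket before adding a parameter would exceed cap_mb.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_mb = 0.0
    for idx in reversed(range(len(param_mb))):
        if current and current_mb + param_mb[idx] > cap_mb:
            buckets.append(current)
            current, current_mb = [], 0.0
        current.append(idx)
        current_mb += param_mb[idx]
    if current:
        buckets.append(current)
    return buckets


def bucket_ready_steps(buckets: List[List[int]], num_params: int) -> List[int]:
    """Backward step (1-based) at which every gradient in each bucket is available.

    Backward produces the last parameter's gradient first, so parameter idx is
    ready at step num_params - idx.
    """
    return [max(num_params - idx for idx in bucket) for bucket in buckets]


layer_mb = [10.0] * 5                  # hypothetical: five 10 MB parameter tensors
buckets = build_buckets(layer_mb)
ready = bucket_ready_steps(buckets, len(layer_mb))
```

Here two of the three buckets are ready before the fifth and final backward step, so their all-reduces can run while backward is still computing.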
PyTorch FSDP
Switch to FSDP when the model no longer fits. FSDP applies ZeRO-3-style full sharding: an all-gather before each wrapped unit's forward, and a reduce-scatter after its backward. It issues more NCCL calls than DDP, but per-rank memory for parameters, gradients, and optimizer state drops to roughly 1/N of the unsharded total. The tricky part is the auto-wrap policy: wrap each transformer block as its own unit. FSDP integrates with torch.distributed.checkpoint for sharded checkpoint I/O.
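Why per-block wrapping matters can be estimated with simple arithmetic. The sketch below uses a hypothetical model of 32 blocks of 200M parameters each and a deliberately simplified memory model: every rank keeps its 1/N shard of all parameters, plus one fully gathered unit while that unit runs. It ignores prefetching (which can keep a second unit gathered), gradients, optimizer state, and the root unit that typically holds embeddings. The function name is hypothetical.

```python
from typing import List


def peak_param_bytes(unit_params: List[int], world_size: int, bytes_per_param: int = 2) -> float:
    """Rough peak parameter memory per rank under FSDP-style full sharding."""
    sharded = sum(unit_params) / world_size   # permanent 1/N shard of everything
    gathered = max(unit_params)               # the largest unit, all-gathered while it runs
    return (sharded + gathered) * bytes_per_param


blocks = [200_000_000] * 32   # hypothetical: 32 transformer blocks of 200M params each
world = 8

one_unit = peak_param_bytes([sum(blocks)], world)   # whole model wrapped as one unit
per_block = peak_param_bytes(blocks, world)          # one unit per transformer block
```

Wrapping the whole model as a single unit gathers all 6.4B parameters at once (about 14.4 GB of BF16 per rank in this model), while per-block wrapping peaks near 2 GB; finer units lower the peak but add more, smaller collectives.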
Horovod
Horovod is a framework-agnostic ring-allreduce library, originally developed at Uber, that works with PyTorch, TensorFlow, and MXNet. It was the default in heterogeneous shops circa 2018–2020. In PyTorch-only environments, DDP has eclipsed it for the same use cases, and FSDP and DeepSpeed have eclipsed it for memory-bound workloads.
DeepSpeed and Megatron-LM
Microsoft's batteries-included stack for very large training. Adds:
CPU and NVMe offload for parameters, gradients, and optimizer states — train models beyond GPU RAM.
Activation checkpointing, with options to partition checkpointed activations across model-parallel GPUs or offload them to CPU memory.
Pipeline parallelism with multiple scheduling strategies.
Tensor parallelism via integration with NVIDIA's Megatron-LM (Megatron-DeepSpeed).
Configuration overhead is real: a JSON config with dozens of knobs.
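A minimal sketch of such a config, built as a Python dict and serialized to JSON. The keys shown (train_batch_size, train_micro_batch_size_per_gpu, gradient_accumulation_steps, bf16, and zero_optimization with offload_optimizer/offload_param) are standard DeepSpeed config keys, but the values are illustrative, and a real config usually adds optimizer, scheduler, and logging sections. The helper function is hypothetical; the batch-size identity it encodes is one DeepSpeed checks at startup.

```python
import json


def make_deepspeed_config(micro_batch_per_gpu: int, grad_accum_steps: int,
                          world_size: int, zero_stage: int = 3,
                          offload_to_cpu: bool = False) -> dict:
    """Build a minimal DeepSpeed-style config dict (sketch; many real knobs omitted)."""
    zero = {"stage": zero_stage}
    if offload_to_cpu:
        zero["offload_optimizer"] = {"device": "cpu", "pin_memory": True}
        if zero_stage == 3:
            # Parameter offload is only available with ZeRO-3.
            zero["offload_param"] = {"device": "cpu", "pin_memory": True}
    return {
        # Must equal micro batch * accumulation steps * world size.
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "bf16": {"enabled": True},
        "zero_optimization": zero,
    }


cfg = make_deepspeed_config(micro_batch_per_gpu=4, grad_accum_steps=8, world_size=16,
                            offload_to_cpu=True)
config_json = json.dumps(cfg, indent=2)  # e.g. written to a ds_config.json file
```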
Ray Train
An orchestration layer: you write a single training function that uses DDP, FSDP, or DeepSpeed internally, then ask Ray to run it on a cluster. Ray handles scaling, fault tolerance, Ray Tune integration, and checkpoint management, and it abstracts away the launcher (torchrun vs deepspeed vs horovodrun).
Migration Path
DDP → FSDP when OOM hits. FSDP → DeepSpeed when you need offload, pipeline parallel, or Megatron-LM TP.
Figure 7.5: Framework decision flow
flowchart TD
Start[Start: Pick a Distributed Training Framework] --> Q1{Model fits on 1 GPU with optimizer state?}
Q1 -->|Yes, ≤16 GPUs| DDP[PyTorch DDP]
Q1 -->|No| Q2{Need CPU/NVMe offload or 3D parallelism?}
Q2 -->|No - pure PyTorch is fine| FSDP[PyTorch FSDP - ZeRO-3]
Q2 -->|Yes| DS[DeepSpeed + Megatron-LM]
Q1 -->|Multi-framework / MPI shop| Horo[Horovod]
DDP --> Orch{Need cloud-native orchestration, HPO, fault tolerance?}
FSDP --> Orch
DS --> Orch
Horo --> Orch
Orch -->|Yes| Ray[Wrap with Ray Train]
Orch -->|No| Done[Launch with native launcher]
Key Points — Frameworks
PyTorch DDP is the default for small/medium models up to roughly 16 GPUs — launched with torchrun.
FSDP brings ZeRO-3 sharding into pure PyTorch; pick a transformer-block auto-wrap policy for best balance of memory and communication.
DeepSpeed adds CPU/NVMe offload, integrated pipeline parallel, and Megatron-LM tensor parallelism — reach for it when training 70B+ parameter models.
Horovod is mainly relevant in multi-framework or MPI-heavy shops; in PyTorch-only environments it has been eclipsed by DDP and FSDP.
Ray Train is an orchestration layer wrapping all of the above behind a uniform Python API with HPO and fault tolerance built in.
Section 4: Cost and Throughput Optimization
Pre-Reading Check — Cost
1. Why does BF16 typically NOT need dynamic loss scaling while FP16 does?
BF16 has more mantissa bits than FP16.
BF16 has the same exponent range as FP32 (8 bits), so gradients rarely underflow.
BF16 is computed entirely on CPU.
BF16 only works during inference.
2. Gradient accumulation is most useful for which situation?
Maintaining a target global batch size after elastic shrinkage from 8 to 4 nodes.
Increasing GPU SM utilization above 100%.
Replacing the need for an optimizer.
Reducing total compute by skipping micro-batches.
3. What does torchrun --nnodes=2:8 --max-restarts=5 enable?
Static training with exactly 5 nodes.
Elastic training that tolerates node preemptions, re-rendezvousing between 2 and 8 nodes with up to 5 restarts.
A fixed 2-node job that runs 5 times for hyperparameter search.
A debug mode that disables NCCL.
Mixed Precision and BF16
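Tensor cores run 16-bit matmuls at many times plain FP32 throughput: an A100's dense BF16/FP16 tensor-core peak is 312 TFLOPS versus 19.5 TFLOPS for standard FP32. FP16 has only 5 exponent bits, so small gradients can underflow to zero. AMP with a GradScaler therefore multiplies the loss by a scale factor before backward, unscales the gradients before the optimizer step, and skips steps whose gradients overflowed to inf/NaN while lowering the scale. BF16 uses 8 exponent bits (the same range as FP32) at the cost of mantissa bits (7 versus FP16's 10), and usually does not need loss scaling. A100/H100 have first-class BF16; H100 adds FP8.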
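The underflow problem can be reproduced with only the standard library: Python's struct module can round a float to IEEE half precision (format "e"), and BF16 can be emulated by rounding an FP32 bit pattern to its top 16 bits. The gradient value and helper names are illustrative; the scale factor 2**16 matches the default initial scale of PyTorch's GradScaler.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE half precision (struct raises OverflowError above ~65504)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round to bfloat16 by keeping the top 16 bits of the FP32 pattern (round-to-nearest-even).

    Sketch only: NaN and infinity handling is omitted.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


grad = 1e-8          # a small but meaningful gradient value
scale = 2.0 ** 16    # loss-scale factor

fp16_plain = to_fp16(grad)                    # underflows to 0.0 in FP16
bf16_plain = to_bf16(grad)                    # survives: BF16 has FP32's exponent range
fp16_scaled = to_fp16(grad * scale) / scale   # scale up, store in FP16, unscale in FP32

# The flip side: BF16's 7 mantissa bits cannot distinguish 1.001 from 1.0; FP16's 10 can.
bf16_one = to_bf16(1.001)
fp16_one = to_fp16(1.001)
```

This is the trade-off in miniature: FP16 keeps more precision but needs loss scaling to stay inside its narrow range, while BF16 keeps the range and gives up precision.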
Gradient Checkpointing and Accumulation
Gradient checkpointing drops most activations in forward and recomputes them in backward, trading 20–40% extra compute for 30–50% activation memory savings. Gradient accumulation sums gradients over accum_steps micro-batches before each optimizer step, simulating a larger global batch without raising per-GPU memory. It is particularly useful when spot preemption forces the job from 8 nodes down to 4: doubling the accumulation steps keeps the same effective global batch.
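The accumulation arithmetic, plus a check that dividing each micro-batch loss by accum_steps reproduces the full-batch gradient, can be sketched as follows. The batch sizes, GPU counts, toy model, and function names are hypothetical, and the equivalence assumes equal micro-batch sizes and no batch-dependent layers such as BatchNorm.

```python
def accum_steps_for(global_batch: int, micro_batch: int, world_size: int) -> int:
    """Accumulation steps so that micro_batch * world_size * steps == global_batch."""
    per_step = micro_batch * world_size
    if global_batch % per_step:
        raise ValueError("global batch must be a multiple of micro_batch * world_size")
    return global_batch // per_step


def mean_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps):
    """Sum micro-batch gradients, each scaled by 1/accum_steps (the loss / accum_steps pattern)."""
    if len(xs) % accum_steps:
        raise ValueError("this sketch assumes equal-sized micro-batches")
    micro = len(xs) // accum_steps
    total = 0.0
    for i in range(accum_steps):
        lo, hi = i * micro, (i + 1) * micro
        total += mean_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps
    return total


# Hypothetical run: global batch 1024, micro-batch 4, 8 GPUs per node.
steps_8_nodes = accum_steps_for(1024, 4, 8 * 8)   # 64 GPUs
steps_4_nodes = accum_steps_for(1024, 4, 4 * 8)   # 32 GPUs after losing half the nodes

xs = [0.5 * i for i in range(16)]
ys = [2.0 * x + 1.0 for x in xs]
full = mean_grad(0.3, xs, ys)
accumulated = accumulated_grad(0.3, xs, ys, 4)
```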
Spot, Preemptible, and Elastic Training
Cloud spot/preemptible instances offer 50–80% discounts but can be reclaimed on short notice (two minutes on AWS Spot, about 30 seconds on Google Cloud). The key is an elastic, fault-tolerant launch via torchrun: --nnodes=2:8 --max-restarts=5 --rdzv-backend=c10d. Membership changes then become routine, not fatal.
For elastic to work, the script must be stateless across restarts:
A single main() entry point (no mp.spawn).
On startup, load the latest snapshot from durable storage (S3/GCS/NFS) before init_process_group.
Snapshots must include model, optimizer, AMP GradScaler, LR scheduler, global step/epoch, and RNG states (not just the seeds), so a resumed job continues the same random stream.
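Use DistributedSampler and call sampler.set_epoch(epoch) at the start of every epoch, so shuffling changes between epochs while staying consistent across ranks.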
Snapshot every 5–15 minutes, and install a termination-notice handler (AWS gives two minutes of notice) that writes an emergency checkpoint.
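The checklist can be illustrated with a toy, framework-free training loop. In this sketch a float stands in for the model, random.Random stands in for framework RNGs, and a local pickle file stands in for S3/GCS; all function names are hypothetical. It shows resuming from the latest snapshot, restoring RNG state rather than reseeding, and an atomic write-then-rename so a preemption mid-save cannot corrupt the snapshot. The termination-notice handler is omitted.

```python
import os
import pickle
import random
import tempfile


def save_snapshot(path: str, state: dict) -> None:
    """Write to a temp file in the same directory, then atomically rename it into place."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)


def load_snapshot(path: str):
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


def train(path: str, total_steps: int, preempt_at=None):
    """Toy stand-in for main(): resume if a snapshot exists, else start fresh."""
    snap = load_snapshot(path)
    rng = random.Random(1234)
    step, weight = 0, 0.0
    if snap is not None:
        step, weight = snap["step"], snap["weight"]
        rng.setstate(snap["rng_state"])  # restore the RNG state, not just the seed
    while step < total_steps:
        weight += rng.random() * 0.01    # stand-in for forward/backward/optimizer.step()
        step += 1
        # Real jobs snapshot on a timer (e.g. every 5-15 minutes), not every step.
        save_snapshot(path, {"step": step, "weight": weight, "rng_state": rng.getstate()})
        if preempt_at is not None and step == preempt_at:
            return step, weight          # simulate the process being killed
    return step, weight
```

Running once uninterrupted and once with a simulated preemption followed by a restart should yield identical final state, which is the property elastic restarts depend on.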
sequenceDiagram
participant Cloud as Cloud Spot Pool
participant TR as torchrun (elastic)
participant Job as Training Workers
participant S3 as Durable Storage (S3/GCS)
TR->>Job: Rendezvous N nodes, start training
Job->>S3: Snapshot every 5-15 min (model+opt+scaler+RNG)
Cloud-->>Job: 2-minute preemption notice on Node K
Job->>S3: Emergency snapshot (termination handler)
Cloud-->>Job: Node K reclaimed
TR->>TR: Detect membership change, re-rendezvous
TR->>Job: Resume with N-1 nodes (within --nnodes=min:max)
Job->>S3: Load latest snapshot
Job->>Job: Resume from saved global step + RNG state
Cloud-->>TR: New spot capacity available
TR->>Job: Re-rendezvous, scale back up to N
Animation A3 — Mixed Precision Flow with Loss Scaling
Master weights stay FP32 (blue). They cast to FP16/BF16 (amber) for tensor-core forward and backward. FP16 multiplies loss by a scale factor before backward and unscales gradients before the optimizer step on the FP32 master — BF16 skips both because its FP32-range exponent avoids underflow.
Key Points — Cost
BF16 AMP on A100/H100 typically delivers a 1.5–2x end-to-end speedup over FP32 with no loss-scaling complexity; FP16 needs a GradScaler to avoid gradient underflow.
Gradient checkpointing trades 20–40% extra compute for 30–50% activation memory savings — the difference between OOM and training.
Gradient accumulation maintains a target global batch under elastic shrinkage without raising per-GPU memory.
Elastic torchrun (--nnodes=min:max --max-restarts=N) plus robust 5–15-minute snapshots makes spot instances safe and cuts cost 50–80%.
Profile with PyTorch Profiler, Nsight Systems, and DCGM; target SM utilization > 70% and NCCL time < 30% of step. Optimize for cost per token, not GPU-hour price.
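Cost per token folds price, throughput, and lost work into one number. The sketch below uses entirely hypothetical prices, throughputs, and goodput fractions (goodput meaning the share of paid GPU time that produces useful training progress) to show why both a cheaper spot GPU-hour and a profiling-driven throughput gain move the metric that matters. The function name is hypothetical.

```python
def cost_per_million_tokens(price_per_gpu_hour: float, tokens_per_sec_per_gpu: float,
                            goodput: float = 1.0) -> float:
    """Dollars per 1M training tokens.

    goodput: fraction of paid GPU time that yields useful progress (restarts,
    re-rendezvous, and steps lost since the last snapshot reduce it).
    """
    useful_tokens_per_hour = tokens_per_sec_per_gpu * 3600 * goodput
    return price_per_gpu_hour / useful_tokens_per_hour * 1e6


# Hypothetical numbers, for illustration only.
on_demand = cost_per_million_tokens(4.00, 3000, goodput=0.98)
spot = cost_per_million_tokens(1.60, 3000, goodput=0.85)        # 60% cheaper, more restarts
spot_tuned = cost_per_million_tokens(1.60, 3900, goodput=0.85)  # after a 30% throughput gain
```

Even with noticeably lower goodput, the spot configuration in this example costs less per token, and the throughput improvement compounds the saving.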