Chapter 7 — Model Training Infrastructure and Distributed Training


Section 1: The Training Compute Landscape

Pre-Reading Check — Training Compute

1. Which fact about GPU bandwidth is the single most important constraint shaping distributed training topology?

a. Host RAM is always slower than NVMe storage on modern servers.
b. Intra-node NVLink bandwidth dramatically exceeds inter-node InfiniBand or Ethernet bandwidth.
c. All GPU collectives must traverse the host CPU before reaching another GPU.
d. PCIe bandwidth between two GPUs in the same chassis is higher than NVLink.

2. Why are TPUs particularly well suited to large batched transformer workloads?

a. They expose a CUDA-compatible API for direct PyTorch use without compilation.
b. They are essentially giant systolic arrays optimized for dense linear algebra, paired with XLA and JAX.
c. They use higher-precision FP64 arithmetic that improves transformer convergence.
d. They share host memory directly with the CPU through unified memory.

3. Which workload most clearly does NOT need a GPU?

a. Pretraining a 13B-parameter transformer on 1T tokens.
b. A tabular XGBoost model on tens of millions of rows.
c. Fine-tuning a vision transformer at high resolution.
d. Training a multi-task BERT model with 24 layers.

Training a modern deep learning model is, in many ways, less about clever algorithms and more about choreographing thousands of arithmetic units, gigabytes of memory, and miles of high-speed cabling. The model is the recipe; the infrastructure is the kitchen.

GPUs and the Memory Hierarchy

NVIDIA's V100, A100, and H100 GPUs dominate production training. The A100 introduced BF16 tensor-core support; the H100 added FP8 and roughly 3x the A100's peak 16-bit matrix throughput. GPU memory has a steep hierarchy:

- registers and shared memory (kilobytes, nanosecond-scale latency)
- L2 cache (tens of MB)
- HBM (40-80 GB on current data center cards, with terabytes per second of bandwidth)
- host RAM over PCIe
- NVMe storage

Every byte that has to be fetched from a slower, lower level costs time.

Inside a node, GPUs talk over NVLink and NVSwitch at hundreds of GB/s. Across nodes, clusters use InfiniBand or Ethernet at 100–400 Gbps per link. That is gigabits, so only about 12–50 GB/s. This gap between intra-node and inter-node bandwidth is the single most important fact about distributed training topology.
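A back-of-the-envelope sketch makes the gap concrete by turning link speeds into transfer times for one gradient buffer. The 300 GB/s NVLink figure and the 7B-parameter BF16 gradient size are illustrative assumptions. The model also ignores latency, protocol overhead, and the multiple NICs a real node usually has.

```python
# Illustrative arithmetic only: the bandwidth figures are assumptions, not vendor specs.
GB = 1e9  # decimal gigabyte


def gbps_to_bytes_per_sec(gbps):
    """Network links are quoted in gigaBITS per second; divide by 8 to get bytes."""
    return gbps * 1e9 / 8


def transfer_seconds(num_bytes, bytes_per_sec):
    """Idealized time to push a buffer over one link (no latency, no protocol overhead)."""
    return num_bytes / bytes_per_sec


grad_bytes = 7e9 * 2                 # hypothetical 7B-parameter model, 2-byte BF16 gradients
nvlink_bw = 300 * GB                 # assumed per-direction NVLink bandwidth for one GPU
ib_bw = gbps_to_bytes_per_sec(400)   # one 400 Gbps InfiniBand/Ethernet link

t_nvlink = transfer_seconds(grad_bytes, nvlink_bw)
t_ib = transfer_seconds(grad_bytes, ib_bw)
print(f"NVLink: {t_nvlink:.3f} s, 400 Gbps link: {t_ib:.3f} s, ratio {t_ib / t_nvlink:.1f}x")
```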

TPUs, Other Accelerators, and When CPUs Still Win

Google TPUs are built around 2D systolic arrays of multiply-accumulate units, optimized for batched matmuls in BF16. They pair best with XLA and JAX. Other accelerators (AWS Trainium, Cerebras, Graphcore, Habana) have niches, but NVIDIA GPUs and TPUs dominate large-scale training.

CPUs remain the right choice for several workloads:

- tabular GBMs (XGBoost, LightGBM)
- classical scikit-learn pipelines
- lightweight NLP fine-tunes
- feature engineering

When data movement rather than dense arithmetic dominates, a CPU is fine.

Cluster Managers

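Three families dominate ML training clusters:

- Slurm, the classic HPC batch scheduler
- Kubernetes, typically extended with operators such as Kubeflow's PyTorchJob
- Ray, a Python-native distributed compute framework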

Large organizations often layer these: Slurm/K8s at the bottom managing GPU allocations, Ray or a Kubeflow PyTorchJob on top to launch the training script.

Figure 7.1: GPU memory hierarchy and cluster wiring

Memory levels, fastest to slowest: registers / shared memory (kB, ~ns) → L2 cache (tens of MB) → HBM (40-80 GB, TB/s) → host RAM (hundreds of GB, over PCIe) → NVMe SSD (TB, slowest).

Links: HBM connects to peer-GPU HBM over NVLink/NVSwitch (hundreds of GB/s). Host connects to remote nodes over InfiniBand/Ethernet (100-400 Gbps).


Section 2: Distributed Training Strategies

Pre-Reading Check — Parallelism

1. In standard data parallelism, what is communicated between GPUs each step?

a. Activations at every layer boundary.
b. Gradients, averaged across replicas via an all-reduce.
c. Parameters, gathered each forward pass via all-gather.
d. Random number generator state.

2. What does ZeRO Stage 3 (FSDP) shard across data-parallel ranks that ZeRO Stage 1 does not?

a. Only optimizer states.
b. Only gradients.
c. Optimizer states, gradients, and parameters.
d. Activation tensors only.

3. Why is tensor parallelism typically kept WITHIN a single node?

a. PyTorch only supports tensor parallelism on a single node.
b. Tensor parallelism triggers fine-grained collectives per layer, which need NVLink-level bandwidth.
c. Inter-node InfiniBand is faster than intra-node NVLink.
d. Tensor parallelism cannot use NCCL across nodes.

Data Parallelism

Every GPU holds a complete model replica; the global batch is split into per-GPU mini-batches; each replica runs forward/backward; then gradients are all-reduced across ranks. Efficient implementations overlap the gradient communication with the backward pass so the sync is mostly hidden. DP scales near-linearly to dozens of GPUs as long as the full model (params + optimizer state + activations) fits on one device.
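A toy check shows why averaging per-replica gradients is equivalent to one large-batch step. With equal-sized shards and a mean-reduced loss, the mean of the shard gradients equals the full-batch gradient. The sketch uses a one-parameter least-squares model and a plain-Python stand-in for the all-reduce. The helper names are hypothetical, not a library API, and `Fraction` keeps the arithmetic exact.

```python
from fractions import Fraction


def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)**2) over one mini-batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(local_values):
    """Stand-in for an NCCL all-reduce followed by division by world size."""
    avg = sum(local_values) / len(local_values)
    return [avg] * len(local_values)


def data_parallel_grads(w, xs, ys, world_size):
    shard = len(xs) // world_size
    assert shard * world_size == len(xs), "sketch assumes equal-sized shards"
    # Each "GPU" computes a gradient on its own shard of the global batch only.
    local = [mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
             for r in range(world_size)]
    return all_reduce_mean(local)


xs = [Fraction(i) for i in range(1, 9)]   # global mini-batch of 8 samples
ys = [3 * x + 1 for x in xs]
w = Fraction(2)

full_batch = mse_grad(w, xs, ys)
per_rank = data_parallel_grads(w, xs, ys, world_size=4)
print(full_batch, per_rank)  # every rank holds the full-batch gradient
```

Unequal shard sizes would break the equality unless each local gradient were weighted by its shard size. This is one reason data loaders pad or drop the last batch so ranks stay balanced.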

Model and Tensor Parallelism

When the model exceeds one GPU's memory, partition it. Layer-wise model parallelism places different layers on different GPUs. Tensor parallelism splits weight matrices inside a layer (e.g., Megatron-LM splits attention heads and MLP intermediate dimensions); the partial outputs are combined with all-gather or all-reduce. Tensor parallelism is communication-heavy and should live within an NVLink island.
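The two ways to split a linear layer Y = XW can be checked on small integer matrices:

- Splitting W by columns gives each rank a slice of the output, recombined with an all-gather.
- Splitting W by rows, and X by the matching columns, gives each rank a full-size partial sum, recombined with an all-reduce.

Megatron-LM chains a column split into a row split inside the MLP, so that block's forward pass needs only one all-reduce. This is a plain-Python sketch: list concatenation and summation stand in for the NCCL collectives, and the helper names are hypothetical.

```python
def matmul(a, b):
    """Naive dense matmul on nested lists: (m x k) @ (k x n)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(m, parts):
    width = len(m[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in m] for p in range(parts)]


def split_rows(m, parts):
    height = len(m) // parts
    return [m[p * height:(p + 1) * height] for p in range(parts)]


def column_parallel(x, w, ranks):
    # Each rank owns a slice of W's output columns and computes that slice of Y.
    partial = [matmul(x, w_shard) for w_shard in split_columns(w, ranks)]
    # "All-gather": concatenate the column slices row by row.
    return [sum((p[i] for p in partial), []) for i in range(len(x))]


def row_parallel(x, w, ranks):
    # Each rank owns a slice of W's input rows and the matching columns of X.
    partial = [matmul(x_shard, w_shard)
               for x_shard, w_shard in zip(split_columns(x, ranks), split_rows(w, ranks))]
    # "All-reduce": sum the full-size partial outputs elementwise.
    return [[sum(p[i][j] for p in partial) for j in range(len(w[0]))] for i in range(len(x))]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]                                # 2 tokens, hidden size 4
w = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 2, 0, 1], [1, 3, 1, 0]]    # 4 x 4 weight
print(column_parallel(x, w, 2) == matmul(x, w), row_parallel(x, w, 2) == matmul(x, w))
```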

Pipeline Parallelism

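The model is divided into S sequential stages, and M microbatches flow through them assembly-line style. The challenge is pipeline bubbles while the pipeline fills and drains. With equal stage times, a synchronous schedule sits idle for a fraction (S-1)/(M+S-1) of each step, so the main remedy is M >> S.

Schedules like 1F1B (one-forward-one-backward) keep the same bubble as GPipe but limit in-flight activations to about S microbatches, which is what makes a large M affordable. Interleaved variants, which give each GPU several smaller stages, shrink the bubble further.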
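The bubble arithmetic can be sketched directly, assuming S equal-time stages and M microbatches. The second helper (a hypothetical name) captures why 1F1B matters: it leaves the bubble fraction unchanged but caps how many microbatches' activations the first stage must hold at once.

```python
from fractions import Fraction


def bubble_fraction(stages, microbatches):
    """Idle fraction of a synchronous GPipe- or 1F1B-style schedule: (S - 1) / (M + S - 1)."""
    return Fraction(stages - 1, microbatches + stages - 1)


def peak_in_flight_microbatches(schedule, stages, microbatches):
    """Microbatches whose activations the first stage holds simultaneously."""
    if schedule == "gpipe":
        return microbatches                # all forwards finish before any backward starts
    if schedule == "1f1b":
        return min(stages, microbatches)   # warm-up fills the pipe, then one forward per backward
    raise ValueError(f"unknown schedule: {schedule}")


for m in (4, 16, 64):
    print(f"S=4, M={m}: bubble {float(bubble_fraction(4, m)):.3f}, "
          f"in flight GPipe={peak_in_flight_microbatches('gpipe', 4, m)}, "
          f"1F1B={peak_in_flight_microbatches('1f1b', 4, m)}")
```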

ZeRO Sharding and FSDP

ZeRO shards training state across data-parallel ranks. ZeRO-1 shards optimizer states; ZeRO-2 adds gradient sharding; ZeRO-3 shards parameters too.

PyTorch FSDP is the native ZeRO-3 implementation. An all-gather assembles each wrapped unit's parameters just before its forward. With full sharding, the default frees the full copy afterwards, so a second all-gather runs before the backward. After backward, a reduce-scatter leaves each rank with only its shard of the gradient, and the optimizer step updates the local shard only.
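The memory savings can be estimated with the accounting used in the ZeRO paper, assuming mixed-precision Adam at 16 bytes of model state per parameter:

- 2 bytes for 16-bit weights
- 2 bytes for 16-bit gradients
- 12 bytes for FP32 master weights plus Adam's two moments

Activations are excluded. The function name is hypothetical, and real frameworks add communication buffers and fragmentation on top.

```python
def zero_model_state_bytes(num_params, dp_ranks, stage, k_optimizer_bytes=12):
    """Per-GPU bytes for params + grads + optimizer states under mixed-precision Adam.

    Stage 0 is plain data parallelism (everything replicated). Activations are not counted.
    """
    psi, n, k = num_params, dp_ranks, k_optimizer_bytes
    if stage == 0:
        return (2 + 2 + k) * psi
    if stage == 1:   # optimizer states sharded
        return 2 * psi + 2 * psi + k * psi / n
    if stage == 2:   # + gradients sharded
        return 2 * psi + (2 + k) * psi / n
    if stage == 3:   # + parameters sharded (FSDP full sharding)
        return (2 + 2 + k) * psi / n
    raise ValueError("stage must be 0, 1, 2, or 3")


for s in range(4):
    gb = zero_model_state_bytes(7.5e9, 64, s) / 1e9
    print(f"ZeRO-{s}: {gb:.1f} GB per GPU")
```

Under these assumptions, a 7.5B-parameter model on 64 ranks needs the following model state per GPU:

- about 120 GB with plain data parallelism
- 31.4 GB with ZeRO-1
- 16.6 GB with ZeRO-2
- 1.9 GB with ZeRO-3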

NCCL Ring All-Reduce

NCCL is NVIDIA's collective communication library (all-reduce, all-gather, reduce-scatter, broadcast). It chooses between ring, tree, and hierarchical algorithms based on topology.

Horovod popularized the bandwidth-optimal ring all-reduce. N GPUs form a logical ring, and the tensor is split into N chunks. The all-reduce then runs in two phases:

1. Reduce-scatter: chunks circulate while each GPU adds its contribution, leaving each GPU with one fully reduced chunk.
2. All-gather: the reduced chunks circulate so every GPU ends with the full result.

Each GPU sends and receives about 2(N-1)/N times the tensor size, which is nearly independent of N.
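The two phases can be simulated in plain Python to check two things: every rank ends with the elementwise sum, and each rank sends 2(N-1)/N of the buffer. This is a sketch of the algorithm's data movement, not NCCL's implementation. It sums rather than averages (dividing by N afterwards gives the mean gradient) and assumes the buffer length divides evenly by the number of ranks.

```python
def ring_all_reduce(buffers):
    """Simulate a sum all-reduce over len(buffers) ranks arranged in a ring.

    Returns (results_per_rank, elements_sent_per_rank).
    """
    n = len(buffers)
    length = len(buffers[0])
    assert all(len(b) == length for b in buffers) and length % n == 0
    size = length // n
    # chunks[r][c] is rank r's current copy of chunk c.
    chunks = [[list(b[c * size:(c + 1) * size]) for c in range(n)] for b in buffers]
    sent = [0] * n

    # Phase 1, reduce-scatter: after n-1 steps rank r owns the full sum of chunk (r+1) % n.
    for step in range(n - 1):
        messages = [(r, (r - step) % n, list(chunks[r][(r - step) % n])) for r in range(n)]
        for src, c, payload in messages:          # all sends in a step happen "at once"
            dst = (src + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]
            sent[src] += len(payload)

    # Phase 2, all-gather: pass finished chunks around until every rank has all of them.
    for step in range(n - 1):
        messages = [(r, (r + 1 - step) % n, list(chunks[r][(r + 1 - step) % n])) for r in range(n)]
        for src, c, payload in messages:
            dst = (src + 1) % n
            chunks[dst][c] = payload
            sent[src] += len(payload)

    results = [[x for c in range(n) for x in chunks[r][c]] for r in range(n)]
    return results, sent


grads = [[rank * 10 + i for i in range(8)] for rank in range(4)]   # 4 ranks, 8 elements each
reduced, sent = ring_all_reduce(grads)
print(reduced[0], sent)
```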

Figure 7.2: Data parallelism — replicated model, sharded batch, all-reduced gradients

A global mini-batch is split into four shards. GPUs 1-4 each run a full model replica on one shard, an NCCL all-reduce averages their gradients, and every replica then takes an identical optimizer step.

Figure 7.3: ZeRO sharding stages — progressive partitioning of training state

- Baseline DDP: every rank holds the full state.
- ZeRO-1: parameters replicated, gradients replicated, optimizer states sharded.
- ZeRO-2: parameters replicated, gradients sharded, optimizer states sharded.
- ZeRO-3 / FSDP: parameters, gradients, and optimizer states all sharded.
- Plus CPU/NVMe offload: pushes state beyond GPU RAM.

Figure 7.4: NCCL ring all-reduce — bandwidth-optimal aggregation

GPUs 0-3, holding chunks A-D, form a logical ring: GPU 0 → GPU 1 → GPU 2 → GPU 3 → GPU 0. At each step, every GPU sends one chunk to its neighbor.

Animation A1 — Data Parallelism with Ring All-Reduce

Each of GPUs 0-3 holds a full model copy and receives a different shard of the global mini-batch. After computing locally, gradients flow around an NCCL ring all-reduce. Every replica ends with the averaged gradient before the optimizer step.

Animation A2 — ZeRO/FSDP Progressive Sharding

Each of GPUs 0-3 holds three kinds of state: parameters, gradients, and optimizer states. Climbing the ZeRO stages shards more of them:

- Stage 1 shards optimizer states.
- Stage 2 also shards gradients.
- Stage 3 / FSDP also shards parameters.

Per-GPU memory for these model states drops toward 1/N of the unsharded total; activations are handled separately.

Comparing Parallelism Strategies

| Aspect | Data Parallel | Tensor / Model | Pipeline | ZeRO-3 / FSDP |
|---|---|---|---|---|
| Model fits on one GPU? | Required | Not required | Not required | Not required |
| Main memory benefit | None for params | Per-GPU param memory shrunk | Per-stage params + activations | All states sharded |
| Communication pattern | All-reduce per step | Collectives inside layers | Activations at stage boundaries | All-gather + reduce-scatter per layer |
| Communication frequency | Once per iteration | Per TP layer | Per microbatch per boundary | Per layer |
| GPU utilization | Typically high | High if layer is large | Limited by bubbles | High with prefetching |
| Complexity | Easiest | High | High | Moderate |
| Best scale axis | Batch size | Layer width | Model depth | Total parameter count |

Frontier LLM training combines tensor, pipeline, and data parallelism into 3D parallelism:

- tensor parallel within a node
- pipeline parallel across nodes
- data parallel across replicas of the whole pipeline
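One way to lay out a 3D-parallel job is to give each global rank (dp, pp, tp) coordinates, with the tensor-parallel index varying fastest. Launchers such as torchrun typically number the GPUs of one node consecutively. Each tensor-parallel group then lands inside a single node, while pipeline stages and data-parallel replicas span nodes.

The ordering, helper names, and the 8x4x2 job below are assumptions for illustration. Frameworks like Megatron-LM define their own rank orderings.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp_rank, pp_rank, tp_rank), TP varying fastest (assumed layout)."""
    assert 0 <= rank < tp * pp * dp
    tp_rank = rank % tp
    pp_rank = (rank // tp) % pp
    dp_rank = rank // (tp * pp)
    return dp_rank, pp_rank, tp_rank


def tp_groups(tp, pp, dp):
    """Ranks that share (dp_rank, pp_rank) form one tensor-parallel group."""
    groups = {}
    for rank in range(tp * pp * dp):
        dp_rank, pp_rank, _ = rank_to_coords(rank, tp, pp, dp)
        groups.setdefault((dp_rank, pp_rank), []).append(rank)
    return list(groups.values())


TP, PP, DP, GPUS_PER_NODE = 8, 4, 2, 8   # hypothetical: 64 GPUs as 8 nodes of 8
for group in tp_groups(TP, PP, DP):
    nodes = {rank // GPUS_PER_NODE for rank in group}
    print(f"TP group ranks {group[0]}-{group[-1]} -> node(s) {sorted(nodes)}")
```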


Section 3: Frameworks and Tools for Distributed Training

Pre-Reading Check — Frameworks

1. When should you migrate from PyTorch DDP to FSDP?

a. Whenever you have more than 4 GPUs, regardless of model size.
b. When the model + optimizer state no longer fits on one GPU and OOM appears.
c. When you need stronger gradient accumulation than DDP allows.
d. Only when training across heterogeneous frameworks like TensorFlow.

2. What capability distinguishes DeepSpeed from pure FSDP for the very largest models?

a. DeepSpeed supports BF16 while FSDP does not.
b. CPU/NVMe offload of parameters, gradients, and optimizer states beyond GPU RAM, plus integrated pipeline and Megatron-LM tensor parallelism.
c. DeepSpeed uses MPI instead of NCCL for collectives.
d. DeepSpeed runs without a launcher.

3. Which best describes Ray Train's role?

a. A new gradient compression algorithm replacing NCCL.
b. An orchestration and abstraction layer that wraps DDP/FSDP/DeepSpeed behind a uniform Python API and handles scaling, fault tolerance, and checkpoints.
c. A drop-in replacement for NCCL on TPUs.
d. A standalone tensor parallelism implementation.

PyTorch DDP

DDP is the canonical PyTorch data-parallel implementation. Each rank holds a full copy of parameters, gradients, and optimizer state, and gradients are bucketed and all-reduced via NCCL during backward. It is launched with torchrun, which spawns one process per GPU. It is the right tool when the model fits comfortably on one GPU and you are scaling to roughly 8–16 GPUs.

PyTorch FSDP

Switch to FSDP when the model no longer fits. It applies ZeRO-3-style full sharding: all-gather each wrapped unit's parameters before its forward (and again before its backward), then reduce-scatter gradients after backward. That means more NCCL calls than DDP, but per-rank memory for parameters, gradients, and optimizer states drops to roughly 1/N.

The tricky part is the auto-wrap policy: wrap each transformer block as its own unit. FSDP integrates with torch.distributed.checkpoint for sharded checkpoint I/O.

Horovod

Framework-agnostic ring-allreduce library originally from Uber, works with PyTorch, TF, and MXNet. Default in heterogeneous shops circa 2018–2020. In PyTorch-only environments, DDP eclipsed it for the same use cases; FSDP and DeepSpeed eclipsed it for memory-bound workloads.

DeepSpeed and Megatron-LM

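Microsoft's batteries-included stack for very large training. On top of ZeRO stages 1–3, it adds:

- CPU/NVMe offload that pushes parameters, gradients, and optimizer states beyond GPU RAM
- built-in pipeline parallelism
- integration with Megatron-LM tensor parallelism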

Configuration overhead is real: a JSON config with dozens of knobs.
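To give a flavor of that config, the sketch below builds a minimal one as a Python dict and serializes it to JSON. The keys shown are DeepSpeed config keys: `train_batch_size`, `train_micro_batch_size_per_gpu`, `gradient_accumulation_steps`, `bf16`, and `zero_optimization` with `offload_optimizer`/`offload_param`.

The helper function and the chosen values are illustrative assumptions. A production config would add optimizer, scheduler, and communication settings.

```python
import json


def make_deepspeed_config(micro_batch_per_gpu, grad_accum_steps, world_size,
                          zero_stage=3, offload_to_cpu=False):
    """Build a minimal DeepSpeed-style config dict (sketch; real configs have many more knobs)."""
    config = {
        # DeepSpeed expects train_batch_size == micro batch * accumulation steps * world size.
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": zero_stage},
    }
    if offload_to_cpu:
        if zero_stage != 3:
            raise ValueError("parameter offload requires ZeRO stage 3")
        config["zero_optimization"]["offload_optimizer"] = {"device": "cpu"}
        config["zero_optimization"]["offload_param"] = {"device": "cpu"}
    return config


cfg = make_deepspeed_config(micro_batch_per_gpu=4, grad_accum_steps=8, world_size=16,
                            offload_to_cpu=True)
print(json.dumps(cfg, indent=2))
```

Generating the file from code like this keeps the batch-size fields consistent with each other when the world size changes.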

Ray Train

An orchestration layer: you write a single training function using DDP/FSDP/DeepSpeed internally, then ask Ray to run it on a cluster. Ray handles scaling, fault tolerance, Ray Tune integration, and checkpoint management. Abstracts away the launcher (torchrun vs deepspeed vs horovodrun).

Migration Path

DDP → FSDP when OOM hits. FSDP → DeepSpeed when you need offload, pipeline parallel, or Megatron-LM TP.

Figure 7.5: Framework decision flow

1. Does the model fit on one GPU together with its optimizer state?
   - Yes, on up to 16 GPUs: PyTorch DDP.
   - Multi-framework or MPI shop: Horovod.
   - No: go to step 2.
2. Do you need CPU/NVMe offload or 3D parallelism?
   - No, pure PyTorch is fine: PyTorch FSDP (ZeRO-3).
   - Yes: DeepSpeed + Megatron-LM.
3. Whichever framework you chose, do you need cloud-native orchestration, HPO, or fault tolerance?
   - Yes: wrap it with Ray Train.
   - No: launch it with its native launcher.


Section 4: Cost and Throughput Optimization

Pre-Reading Check — Cost

1. Why does BF16 typically NOT need dynamic loss scaling while FP16 does?

a. BF16 has more mantissa bits than FP16.
b. BF16 has the same exponent range as FP32 (8 bits), so gradients rarely underflow.
c. BF16 is computed entirely on CPU.
d. BF16 only works during inference.

2. Gradient accumulation is most useful for which situation?

a. Maintaining a target global batch size after elastic shrinkage from 8 to 4 nodes.
b. Increasing GPU SM utilization above 100%.
c. Replacing the need for an optimizer.
d. Reducing total compute by skipping micro-batches.

3. What does torchrun --nnodes=2:8 --max-restarts=5 enable?

a. Static training with exactly 5 nodes.
b. Elastic training that tolerates node preemptions, re-rendezvousing between 2 and 8 nodes with up to 5 restarts.
c. A fixed 2-node job that runs 5 times for hyperparameter search.
d. A debug mode that disables NCCL.

Mixed Precision and BF16

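Tensor cores run 16-bit matmuls at many times plain FP32 throughput: at peak, about 8x on V100 and 16x on A100.

FP16 has only 5 exponent bits, so small gradients can underflow to zero. AMP with a GradScaler handles this in three moves:

1. Multiply the loss by a scale factor before backward.
2. Unscale the gradients before the optimizer step.
3. Skip the step if any gradient is Inf/NaN.

BF16 keeps FP32's 8 exponent bits at the cost of mantissa bits (7 versus FP16's 10), so it usually does not need loss scaling. A100/H100 have first-class BF16; H100 adds FP8.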
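Python's `struct` module can round-trip values through IEEE half precision (format `'e'`), which is enough to see both the underflow problem and the loss-scaling fix. BF16 has no `struct` format, so the sketch emulates it by keeping the top 16 bits of the FP32 encoding. That is truncation, not the round-to-nearest that hardware uses. The gradient magnitude and the 2^16 scale factor are illustrative choices.

```python
import struct


def to_fp16(x):
    """Round-trip a float through IEEE 754 half precision (5 exponent bits, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Emulate bfloat16 by truncation: keep sign, 8 exponent bits, top 7 mantissa bits of FP32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


grad = 1e-8             # a small gradient value
scale = 2.0 ** 16       # loss-scale factor, as a GradScaler might use

fp16_plain = to_fp16(grad)   # below FP16's smallest subnormal (~6e-8): flushes to 0.0
bf16_plain = to_bf16(grad)   # FP32-sized exponent range: stays close to 1e-8

# Loss scaling: the scaled gradient fits in FP16; unscale in full precision before the step.
fp16_scaled = to_fp16(grad * scale)
recovered = fp16_scaled / scale

print(fp16_plain, bf16_plain, recovered)
```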

Gradient Checkpointing and Accumulation

Gradient checkpointing drops most activations during the forward pass and recomputes them during backward. Depending on which layers are checkpointed, that saves roughly 30–50% of activation memory at the cost of 20–40% extra compute.

Gradient accumulation runs accum_steps micro-batches before a single optimizer step. Each micro-batch loss is divided by accum_steps, so the accumulated gradient is an average. This simulates a larger global batch without raising per-GPU memory. It is particularly useful when spot preemption shrinks a job from 8 nodes to 4: doubling the accumulation steps keeps the effective global batch the same.
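The elastic-shrink arithmetic is simple enough to sketch. The number of accumulation steps is the target global batch divided by the samples all GPUs process in one micro-step. The function name and job settings are hypothetical.

```python
def accumulation_steps(global_batch, micro_batch_per_gpu, gpus_per_node, nodes):
    """Micro-steps per optimizer step so each update still sees global_batch samples."""
    samples_per_micro_step = micro_batch_per_gpu * gpus_per_node * nodes
    if global_batch % samples_per_micro_step:
        raise ValueError("global batch is not a multiple of samples per micro-step")
    return global_batch // samples_per_micro_step


GLOBAL_BATCH, MICRO, GPUS_PER_NODE = 1024, 4, 8   # hypothetical job settings
print(accumulation_steps(GLOBAL_BATCH, MICRO, GPUS_PER_NODE, nodes=8))  # 4 with the full allocation
print(accumulation_steps(GLOBAL_BATCH, MICRO, GPUS_PER_NODE, nodes=4))  # 8 after losing 4 spot nodes
```

Raising an error rather than silently rounding keeps the effective batch size, and therefore the learning-rate schedule, unchanged across resizes.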

Spot, Preemptible, and Elastic Training

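Cloud spot/preemptible instances offer 50–80% discounts but can be reclaimed at short notice. The key is an elastic, fault-tolerant launch via torchrun, for example --nnodes=2:8 --max-restarts=5 --rdzv-backend=c10d. With these flags:

- --nnodes=2:8 lets the job keep running with anywhere from 2 to 8 nodes.
- --max-restarts=5 allows up to 5 restarts of the worker group.
- --rdzv-backend=c10d uses PyTorch's built-in c10d rendezvous to re-form membership.

Membership changes become routine, not fatal.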

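For elastic training to work, the script must keep no state outside its checkpoint, so a restart on any set of nodes can pick up where the last snapshot left off. Each snapshot, written to durable storage (S3/GCS), should hold:

- model weights
- optimizer state
- AMP grad-scaler state
- RNG states
- the global step

Every start should load the latest snapshot if one exists.

Snapshot every 5–15 minutes. Also install a termination-notice handler that writes an emergency checkpoint when the preemption notice arrives. The notice is about 2 minutes on AWS spot; some providers give less, e.g. 30 seconds on Google Cloud.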
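Here is a minimal sketch of a restart-safe loop. It assumes a local pickle file stands in for durable storage (S3/GCS) and a toy scalar stands in for model and optimizer state. The loop follows three rules:

- Everything needed to continue (step, weights, RNG state) goes into the snapshot.
- The write is atomic, so a preemption mid-save cannot corrupt the previous snapshot.
- A SIGTERM handler sets a flag that triggers an emergency save. Delivering the notice as SIGTERM is an assumption about the platform.

The `stop_after` argument is a hypothetical hook for simulating preemption.

```python
import os
import pickle
import random
import signal
import tempfile

PREEMPTION = {"requested": False}


def _on_termination_notice(signum, frame):
    # Assumption: the platform sends SIGTERM when the preemption notice arrives.
    # Only set a flag here; the training loop performs the actual save.
    PREEMPTION["requested"] = True


def install_preemption_handler():
    """Call once from the main thread of each worker process."""
    signal.signal(signal.SIGTERM, _on_termination_notice)


def save_checkpoint(path, state):
    """Atomic write: dump to a temp file in the same directory, then rename over the old one."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


def train(path, total_steps, snapshot_every=10, stop_after=None):
    """Toy loop whose only persistent state lives in the checkpoint."""
    rng = random.Random(0)
    step, weight = 0, 0.0
    ckpt = load_checkpoint(path)
    if ckpt is not None:                    # resume: restore step, "model", and RNG stream
        step, weight = ckpt["step"], ckpt["weight"]
        rng.setstate(ckpt["rng_state"])
    while step < total_steps:
        weight += rng.uniform(-1.0, 1.0)    # stand-in for a stochastic optimizer update
        step += 1
        stopping = PREEMPTION["requested"] or (stop_after is not None and step >= stop_after)
        if stopping or step % snapshot_every == 0:
            save_checkpoint(path, {"step": step, "weight": weight, "rng_state": rng.getstate()})
        if stopping:
            break
    return step, weight
```

Running `train(path, 25, stop_after=13)` and then `train(path, 25)` ends with the same result as one uninterrupted 25-step run, because the RNG stream resumes exactly where it stopped. A real job would also store optimizer, grad-scaler, and per-device RNG state.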

Profiling

Rough targets:

- SM utilization above 70%
- NCCL time under 30% of step time
- per-step time consistent across ranks, since a single slow rank stalls every collective

GPU Selection Cheat Sheet

| GPU | Memory | Best for | Caveats |
|---|---|---|---|
| H100 | 80 GB HBM3 (SXM) | Frontier LLM training, FP8 | Most expensive, newer ecosystem |
| A100 | 40 GB HBM2 / 80 GB HBM2e | Large transformers with BF16 AMP | High $/hr; needs full utilization |
| V100 | 16/32 GB HBM2 | Medium models with FP16 AMP | No BF16, aging |
| T4 | 16 GB GDDR6 | Small models, inference, prototyping | Low memory |

The metric that matters is cost per training token or sample, not list price per GPU-hour.
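Here is a sketch of that comparison, using made-up prices and throughputs rather than benchmarks or quotes. A GPU that costs twice as much per hour but processes three times as many tokens per second comes out cheaper per token.

```python
def cost_per_million_tokens(price_per_gpu_hour, tokens_per_sec_per_gpu):
    """Dollars to process one million training tokens on one GPU at a sustained throughput."""
    tokens_per_hour = tokens_per_sec_per_gpu * 3600
    return price_per_gpu_hour / tokens_per_hour * 1e6


# Hypothetical numbers for illustration only.
cheaper_gpu = cost_per_million_tokens(price_per_gpu_hour=2.0, tokens_per_sec_per_gpu=1000)
faster_gpu = cost_per_million_tokens(price_per_gpu_hour=4.0, tokens_per_sec_per_gpu=3000)
print(f"${cheaper_gpu:.3f} vs ${faster_gpu:.3f} per million tokens")
```

The same function applies to spot pricing: plug in the discounted hourly rate, and lower the sustained throughput to account for time lost to restarts.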

Figure 7.6: Spot preemption + elastic checkpoint recovery lifecycle

Participants: the cloud spot pool, torchrun (elastic), the training workers, and durable storage (S3/GCS).

1. torchrun rendezvouses N nodes and starts training.
2. Workers snapshot model, optimizer, grad scaler, and RNG state to durable storage every 5-15 minutes.
3. The cloud sends a 2-minute preemption notice for node K. The termination handler writes an emergency snapshot.
4. Node K is reclaimed. torchrun detects the membership change and re-rendezvouses.
5. Workers resume with N-1 nodes (within --nnodes=min:max), load the latest snapshot, and continue from the saved global step and RNG state.
6. When new spot capacity becomes available, torchrun re-rendezvouses and scales back up to N nodes.

Animation A3 — Mixed Precision Flow with Loss Scaling

1. Master weights stay in FP32 and are cast to FP16/BF16 for the tensor-core forward and backward passes. The loss is computed in FP32.
2. In FP16 mode only, the loss is multiplied by a scale factor before backward. Gradients are then unscaled and checked for Inf/NaN.
3. The optimizer step updates the FP32 master weights. AdamW keeps its m and v moments in FP32, which preserves convergence behavior.

BF16 skips the scaling and unscaling because its FP32-range exponent avoids underflow.


Answer Explanations