Distributed Training for Dummies
After enjoying the Ultra Scale Playbook, I decided to write a beginner-friendly walkthrough of distributed training.
The code for this series is available in the Distributed Training For Dummies GitHub repository. The design of much of the code is inspired by the picotron
library. While the code itself contains detailed comments, here I will walk through key code blocks to provide a quick and basic understanding of how to implement these strategies. To start setting up the environment, you can visit the GitHub repository or follow the detailed walkthrough in this post.
Note: There are highly efficient libraries available, such as Megatron and DeepSpeed, which are specifically optimized for distributed training. I recommend using them whenever possible.
Just learning how many parameters today’s large language models (LLMs) have gives me chills. What fascinated me even more was discovering that many of their emergent capabilities arise simply from scaling them up. That curiosity sparked a deeper question: How have researchers and engineers actually managed to train such massive neural networks, sometimes across thousands of GPUs?
Fortunately, resources like the Ultra Scale Playbook explain these large-scale training techniques in great depth. This blog series is my attempt to make them approachable: I'll break down the basic concepts and gradually build up to the point where you can understand how these large-scale training strategies work in practice. You're also welcome to get hands-on with the accompanying code in the Distributed Training For Dummies GitHub repository.
Communication vs Computation vs Memory
Maximizing the efficiency of a distributed training setup requires carefully balancing trade-offs across three key dimensions:
- Computation throughput: the number of floating-point operations (FLOPs) your model requires, and the rate (FLOPS) at which your hardware can execute them. This is bounded by the computational capacity of your hardware (e.g., GPUs or TPUs).
- Communication overhead: the amount of data that must be exchanged across devices, constrained by the bandwidth and latency of interconnects (e.g., NVLink, Ethernet, InfiniBand).
- Memory usage: the memory required to store inputs, activations (intermediate outputs), model weights, gradients, and optimizer states. This is often the first bottleneck when scaling up model sizes or batch sizes.
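To get a sense of why memory is often the first wall you hit, here is a rough back-of-envelope estimate (my own illustrative numbers, ignoring activations and mixed-precision tricks) for a 7-billion-parameter model trained with Adam in fp32:

# Rough memory estimate for a 7B-parameter model trained in fp32 with Adam.
# Illustrative only: real setups use mixed precision, sharding, and other tricks.
params = 7e9
bytes_per_value = 4                          # fp32

weights = params * bytes_per_value           # ~28 GB
gradients = params * bytes_per_value         # ~28 GB
adam_states = 2 * params * bytes_per_value   # momentum + variance: ~56 GB

total_gb = (weights + gradients + adam_states) / 1e9
print(f"~{total_gb:.0f} GB before counting activations")  # ~112 GB

That is already far more than a single GPU offers, before activations are even counted, which is exactly why the strategies discussed below matter.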
Over the years, innovations across all three of these areas have contributed to the impressive capabilities of today’s LLMs.
For example, attention variants like multi-query attention reduce memory usage by sharing a single key/value head across all query heads, shrinking the key and value tensors that must be stored.
In this series, we'll explore strategies for training large neural networks using distributed computing, aiming to improve efficiency across computation, communication, and memory. We'll focus on three key paradigms:
- Data Parallelism
- Tensor Parallelism
- Pipeline Parallelism
Each concept will be introduced through toy examples (training a simple MLP) simulated on a single machine with distributed capabilities. These examples aim to illustrate the ideas without needing access to thousands of GPUs.
Throughput vs Latency
Two performance metrics often come up in distributed training: throughput and latency.
- Throughput refers to the volume of useful work completed per unit time, such as FLOPs, samples processed per second, or bytes transferred per second. Maximizing throughput means keeping compute units busy with meaningful operations as consistently as possible.
- Latency measures the time it takes to complete a single unit of work, for example the time taken to process one batch of data through the model. Minimizing latency involves reducing communication delays and using faster compute hardware.
In practice, distributed training pipelines aim to overlap communication with computation to maximize throughput and minimize latency.
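As a concrete illustration of overlap, PyTorch collectives accept async_op=True and return a handle you can wait on later. The sketch below (my own example, assuming a process group has already been initialized as described later in this post) launches an all-reduce and does unrelated local work while the communication is in flight:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called
# (see "Setting up the environment" below).
grads = torch.randn(1024)

# Launch the all-reduce asynchronously; the call returns a work handle immediately.
work = dist.all_reduce(grads, op=dist.ReduceOp.SUM, async_op=True)

# Do unrelated local computation while the communication is in flight.
local = torch.randn(256, 256)
activations = local @ local.T

# Block until the all-reduce has finished before using `grads`.
work.wait()
print(f"Reduced gradients are ready: {grads[:3]}")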
Communication primitives
Distributed training uses a set of standard communication operations to move data between processes—whether across cores within a node or between nodes in a cluster. These data transfers are governed by the underlying hardware (e.g., PCIe, NVLink, Ethernet), each with its own bandwidth constraints that determine how much data can be moved per unit time.
During neural network training, I find it helpful to treat these communication operations conceptually like any other torch.nn.functional
operation—they’re just transformations on data. For instance, broadcasting data to 4 processes is functionally similar to an identity operation during the forward pass: all processes receive the same data.
While the forward pass is often easy to follow, the backward pass can be less intuitive. For example, when broadcasting is used in the forward pass (i.e., copying data), the backward pass must sum gradients from all the receiving processes—since each received a copy.
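To make this forward/backward duality concrete, here is a minimal sketch (my own illustration, not code from the repository) that wraps broadcast in a torch.autograd.Function whose backward pass sums the gradients back onto the source rank. It assumes a process group is already initialized:

import torch
import torch.distributed as dist

class Broadcast(torch.autograd.Function):
    """Broadcast from src in the forward pass; sum gradients back to src in backward."""

    @staticmethod
    def forward(ctx, tensor, src):
        ctx.src = src
        output = tensor.clone()          # avoid overwriting the input in place
        dist.broadcast(output, src=src)  # every rank now holds the source rank's data
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        # Every rank received a copy in the forward pass, so the source's gradient
        # is the sum of the gradients coming from all ranks.
        dist.reduce(grad, dst=ctx.src, op=dist.ReduceOp.SUM)
        if dist.get_rank() != ctx.src:
            grad = torch.zeros_like(grad)  # non-source inputs did not affect the output
        return grad, None                  # no gradient for the src argument

Calling Broadcast.apply(x, 0) then behaves like an identity in the forward pass, while autograd routes the summed gradient back to rank 0 in the backward pass.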
Below is a list of common communication primitives used in distributed training, along with what each operation does and what to expect during the forward and backward passes:
1. One-to-One (Point-to-Point)
(In the backward pass, the data movement is also one-to-one)
- Send/Receive
- Direct communication between two processes, where one process sends a tensor and another receives it.
- PyTorch: torch.distributed.send, torch.distributed.recv
- Forward Pass: The tensor is sent from one process and received by another. If you think of this as a module, it acts as an identity operation for the sender and receiver.
- Backward Pass: Gradients are not communicated automatically; if gradients need to flow back, you must send/receive them manually in the reverse direction (the receiver sends its gradient back to the original sender).
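A minimal point-to-point sketch (my own example, assuming two ranks and an initialized process group):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called with world_size=2.
rank = dist.get_rank()

if rank == 0:
    payload = torch.arange(4, dtype=torch.float32)
    dist.send(payload, dst=1)               # blocking send to rank 1
    print(f"Rank 0 sent: {payload}")
else:
    buffer = torch.empty(4, dtype=torch.float32)
    dist.recv(buffer, src=0)                # blocking receive from rank 0
    print(f"Rank 1 received: {buffer}")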
2. One-to-Many
(In the backward pass, the data movement is typically many-to-one)
- Broadcast
- Description: One process (the root) sends the same tensor to all other processes. Commonly used to synchronize model parameters at the start of training.
- PyTorch: torch.distributed.broadcast
- Forward Pass: The root’s tensor is copied to all other processes. It is simply an identity operation.
- Backward Pass: Gradients from all processes are summed (reduced) back to the root. This ensures the root receives the total gradient from all workers.
- Scatter
- Description: The root process splits a tensor into chunks and sends each chunk to a different process. Useful for distributing input data.
- PyTorch: torch.distributed.scatter
- Forward Pass: Each process receives a different chunk of the tensor from the root. Functionally, this acts as an identity operation—but applied to individual chunks rather than the entire tensor.
- Backward Pass: Gradients from all processes are gathered back to the root, but not summed—each gradient is placed in the corresponding position in the root’s tensor. Again, this acts like an identity function over chunks.
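A minimal scatter sketch (my own example, assuming an initialized process group; the chunk sizes are illustrative):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
rank = dist.get_rank()
world_size = dist.get_world_size()

chunk = torch.empty(2, dtype=torch.float32)    # every rank receives a chunk of 2 elements

if rank == 0:
    full = torch.arange(2 * world_size, dtype=torch.float32)
    chunks = list(full.chunk(world_size))      # split the root's tensor into one chunk per rank
    dist.scatter(chunk, scatter_list=chunks, src=0)
else:
    dist.scatter(chunk, scatter_list=None, src=0)

print(f"Rank {rank} got chunk: {chunk}")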
3. Many-to-One
(In the backward pass, the data movement is typically one-to-many)
- Gather
- Description: Each process sends its local tensor (a chunk of a larger whole) to the root process, which assembles them into a single, larger tensor. This is effectively the reverse of a scatter.
- PyTorch: torch.distributed.gather
- Forward Pass: The root collects tensors from all processes and concatenates them. Functionally, this behaves like an identity operation over the constituent chunks.
- Backward Pass: The root’s gradient is split and scattered back to the original processes, so each process receives the gradient corresponding to its original input. Again, this acts like an identity over the chunks.
- Reduce
- Description: All processes send their tensor to the root, which aggregates them using a specified operation (e.g., sum, max). Commonly used for aggregating results or gradients.
- PyTorch: torch.distributed.reduce
- Forward Pass: The root receives the aggregated tensor. The aggregation operation (e.g., sum) defines the forward behavior.
- Backward Pass: The root’s gradient is broadcast to all processes, so each process receives the same gradient. This makes the backward pass functionally equivalent to an identity operation across processes.
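A minimal sketch covering both many-to-one primitives (my own example, assuming an initialized process group):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
rank = dist.get_rank()
world_size = dist.get_world_size()

local = torch.full((2,), float(rank))          # each rank contributes its own tensor

# Gather: rank 0 collects every rank's tensor into a list.
gather_list = [torch.empty(2) for _ in range(world_size)] if rank == 0 else None
dist.gather(local, gather_list=gather_list, dst=0)
if rank == 0:
    print(f"Gathered on rank 0: {torch.cat(gather_list)}")

# Reduce: rank 0 ends up with the element-wise sum of every rank's tensor.
summed = local.clone()
dist.reduce(summed, dst=0, op=dist.ReduceOp.SUM)
if rank == 0:
    print(f"Summed on rank 0: {summed}")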
4. Many-to-Many (Collective)
(In the backward pass, the data movement is also many-to-many)
- All-Gather
- Description: Similar to gather, but every process acts as the gatherer. Each process collects tensors from all others and ends up with the full concatenated result. All processes have the same output.
- PyTorch: torch.distributed.all_gather
- Forward Pass: Each process gathers tensors from all other processes and concatenates them. The subtle point is that each process already knows which chunk it contributed, and receives the others in a fixed order. Since all processes perform the same operation, they all hold the full tensor after the gather.
- Backward Pass: No communication is needed. Instead, each process slices the gradient of the full tensor to retain only the part corresponding to its original input. This is not a no-op—it involves a local slicing operation but avoids any data transfer.
- All-Reduce
- Description: Similar to reduce, but all processes participate equally. Each process contributes its tensor, and all end up with the aggregated result (e.g., sum, mean).
- PyTorch: torch.distributed.all_reduce
- Forward Pass: All processes receive the result of applying the reduction operation (e.g., sum) across all inputs.
- Backward Pass: The same reduction is applied to gradients. All processes receive the aggregated gradient. The operation is symmetric in forward and backward.
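All-reduce is the workhorse of data-parallel gradient synchronization. A minimal sketch (my own example, assuming an initialized process group) that averages per-rank gradients the way data parallelism does:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
rank = dist.get_rank()
world_size = dist.get_world_size()

# Pretend each rank computed a different local gradient.
local_grad = torch.full((4,), float(rank + 1))

# After all_reduce, every rank holds the sum of all local gradients.
dist.all_reduce(local_grad, op=dist.ReduceOp.SUM)

# Divide by world_size to turn the sum into an average.
local_grad /= world_size
print(f"Rank {rank} averaged gradient: {local_grad}")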
- All-to-All
- Description: Each process splits its tensor into chunks and sends a unique chunk to every other process. Simultaneously, it receives one chunk from each of them. Unlike all-gather or all-reduce, where all processes end up with the same result, all-to-all results in each process having a different output.
- PyTorch: torch.distributed.all_to_all
- Forward Pass: Each process sends chunks of its tensor to all other processes and receives a chunk from each. The output is a rearranged tensor composed of parts from every peer.
- Backward Pass: The pattern is reversed: each process sends back gradients for the chunks it received and receives gradients for the chunks it originally sent. In other words, the backward of an all-to-all is another all-to-all with the routing reversed.
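A minimal all-to-all sketch (my own example). Note that it requires a backend that supports all_to_all, such as NCCL on GPUs; the gloo backend used in the CPU examples below may not support this collective:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) with a backend that supports all_to_all.
rank = dist.get_rank()
world_size = dist.get_world_size()
# With NCCL, tensors must live on this rank's GPU (assuming one GPU per rank).
device = f"cuda:{rank}" if dist.get_backend() == "nccl" else "cpu"

# inputs[peer] is the chunk this rank wants to send to process `peer`.
inputs = [torch.full((2,), float(rank * 10 + peer), device=device) for peer in range(world_size)]
outputs = [torch.empty(2, device=device) for _ in range(world_size)]

dist.all_to_all(outputs, inputs)
# outputs[peer] now holds the chunk that process `peer` prepared for this rank.
print(f"Rank {rank} received: {outputs}")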
Note: The raw torch.distributed collectives operate in place and are not tracked by autograd. When gradients need to flow through a collective, you can either implement the backward pass yourself (for example with a custom torch.autograd.Function, as sketched earlier) or use the autograd-aware wrappers in torch.distributed.nn.
Summary Table
| Data Movement | Operation | PyTorch Function | Forward Pass | Backward Pass |
|---|---|---|---|---|
| One-to-One | Send/Receive | send, recv | Tensor sent from one process and received by another (identity for sender/receiver) | Gradients not automatically communicated; send/receive them manually if needed |
| One-to-Many | Broadcast | broadcast | Root's tensor copied to all other processes | Gradients from all processes are summed (reduced) to the root |
| One-to-Many | Scatter | scatter | Each process receives a chunk from the root | Gradients gathered back to the root, placed in their corresponding positions |
| Many-to-One | Gather | gather | Root collects and concatenates tensors | Root's gradient split and scattered back to the original processes |
| Many-to-One | Reduce | reduce | Root receives the aggregated tensor | Root's gradient broadcast to all processes |
| Many-to-Many | All-Gather | all_gather | Each process collects and concatenates tensors from all others | Each process keeps the gradient slice for its original chunk |
| Many-to-Many | All-Reduce | all_reduce | All processes receive the aggregated tensor | The same reduction is applied to gradients; all processes receive the result |
| Many-to-Many | All-to-All | all_to_all | Each process receives a set of chunks from all others | Each process receives gradients for the chunks it originally sent |
Setting up the environment
To simulate a distributed computing infrastructure on a single machine, there are two main options:
- Use torchrun --nproc_per_node=2 to launch your training code. When running on a CPU, this command creates two separate processes. If you are using a recent version of PyTorch, the distributed package comes bundled with it, so you don't need to install anything extra; otherwise, install PyTorch first. (A minimal sketch of this option appears right after this list.)
- Set up a Docker environment that spins up multiple worker containers, each running the same training script. By writing a training script that uses communication primitives, you can simulate distributed training with much finer control than the previous option.
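For the first option, here is a minimal sketch (the file name toy_comm.py is my own choice). torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT as environment variables for each process, so init_process_group can pick them up with no extra arguments:

# toy_comm.py -- launch with: torchrun --nproc_per_node=2 toy_comm.py
import torch
import torch.distributed as dist

def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each process,
    # so the default env:// init method needs no extra arguments.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    data = torch.zeros(4)
    if rank == 0:
        data = torch.arange(4, dtype=torch.float32)
    dist.broadcast(data, src=0)            # every rank ends up with rank 0's data
    print(f"Rank {rank}/{world_size} sees: {data}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()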
To set up the docker environment, follow the steps in this repository.
Specifically, we will simulate an environment with two machines accessible at MASTER_ADDR:MASTER_PORT
. Each worker is assigned a rank, which can be referenced in the script. All variables defined in the environment
section will be accessible in the script via the os.getenv
function. Our docker-compose.yaml
will look like this:
version: '3'
services:
worker0:
build: .
environment:
- RANK=0
- WORLD_SIZE=2
- MASTER_ADDR=worker0
- MASTER_PORT=12355
- RUN_MODE=${RUN_MODE}
networks:
- distnet
worker1:
build: .
environment:
- RANK=1
- WORLD_SIZE=2
- MASTER_ADDR=worker0
- MASTER_PORT=12355
- RUN_MODE=${RUN_MODE}
networks:
- distnet
networks:
distnet:
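Inside the training script, these values can be read with os.getenv. A minimal sketch of that setup (the default values and exact variable names are my own; the repository's script may differ slightly):

import os
import random            # used by the example snippets below

import torch
import torch.distributed as dist

# Read the values injected by docker-compose (see the `environment` section above);
# these are the variables used by the init_process_group call in Example 1 below.
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
master_addr = os.getenv("MASTER_ADDR", "localhost")
master_port = os.getenv("MASTER_PORT", "12355")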
Example 1: Broadcast and All-Gather operations
In the following example, we will simulate a scenario where:
- One worker creates a tensor of numbers.
- This tensor is broadcast to all workers.
- Each worker takes a chunk of the data, performs a simple computation, and then all results are gathered back together on all the devices.
To run the example at your end, launch the following command:
RUN_MODE=comm docker-compose up --build
Initialization: The script will initialize the Distributed Process Group (dist.init_process_group
) to set up communication between all workers so they can share data.
dist.init_process_group(
backend='gloo', # this backend is suitable for CPU-based training; for GPU-based training, use 'nccl'
init_method=f'tcp://{master_addr}:{master_port}',
rank=rank,
world_size=world_size
)
Broadcasting: We use the broadcast operation (broadcast) to send data from rank \( 0 \) to the other workers.
Only rank \( 0 \) creates the data; broadcast ensures everyone gets the same tensor.
if rank == 0:
full_data = torch.arange(4) + random.randint(0, 100)
else:
full_data = torch.zeros(4, dtype=torch.long)
dist.broadcast(full_data, src=0)
Local computation: Each worker takes a slice of the data and does a simple computation.
chunk_size = full_data.size(0) // world_size
start = rank * chunk_size
end = start + chunk_size
local_data = full_data[start:end].float()
print(f"Rank {rank} full data: {full_data}")
print(f"Rank {rank} received data: {local_data}")
local_result = local_data * (rank + 1)
print(f"Rank {rank} local result: {local_result}")
All Gather: Finally, each worker sends its result to all others, so everyone ends up with the full output (all_gather
).
gathered_results = [torch.zeros_like(local_result) for _ in range(world_size)]
dist.all_gather(gathered_results, local_result)
full_result = torch.cat(gathered_results)
print(f"Rank {rank} final gathered result: {full_result}")
Running the above script will print statements from the individual workers, which you can identify by their RANK.
distributed_training_for_dummies-worker1-1 | Rank 1 full data: tensor([37, 38, 39, 40])
distributed_training_for_dummies-worker1-1 | Rank 1 received data: tensor([39., 40.])
distributed_training_for_dummies-worker1-1 | Rank 1 local result: tensor([78., 80.])
distributed_training_for_dummies-worker1-1 | Rank 1 local_result shape: torch.Size([2])
distributed_training_for_dummies-worker1-1 | Rank 1 final gathered result: tensor([37., 38., 78., 80.])
distributed_training_for_dummies-worker0-1 | Rank 0 full data: tensor([37, 38, 39, 40])
distributed_training_for_dummies-worker0-1 | Rank 0 received data: tensor([37., 38.])
distributed_training_for_dummies-worker0-1 | Rank 0 local result: tensor([37., 38.])
distributed_training_for_dummies-worker0-1 | Rank 0 local_result shape: torch.Size([2])
distributed_training_for_dummies-worker0-1 | Rank 0 final gathered result: tensor([37., 38., 78., 80.])
distributed_training_for_dummies-worker1-1 exited with code 0
distributed_training_for_dummies-worker0-1 exited with code 0
Example 2: Using Distributed Dataloader in PyTorch for distributed training
We instantiate a distributed dataloader to enable batch distribution in distributed training. Each worker processes a unique subset of the dataset, ensuring no overlap, and communicates with other workers only as needed—typically during gradient synchronization or evaluation aggregation.
To run the example at your end, launch the following command:
RUN_MODE=train docker-compose up --build
Distributed Dataloader: After setting up a process group for communication across workers (using dist.init_process_group
), each worker loads a unique subset of the data using a distributed sampler.
from torch.utils.data import DataLoader, DistributedSampler

# Create dataset (SimpleDataset is defined in the repository's training script)
dataset = SimpleDataset()
# Create distributed sampler
sampler = DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank, # this is how the sampler distributes microbatches across devices
shuffle=True,
seed=42 # Use same seed for reproducibility
)
# Create dataloader
dataloader = DataLoader(
dataset,
batch_size=4,
sampler=sampler
)
Local computation: Each worker processes its own batch and simulates a computation.
for epoch in range(2): # Run for 2 epochs
sampler.set_epoch(epoch) # Important for shuffling!
for batch_idx, data in enumerate(dataloader):
# Simulate computation
local_result = data * (rank + 1)
All Gather: Finally, each worker sends its result to all others, so everyone ends up with the full output.
# Gather results
gathered_results = [torch.zeros_like(local_result) for _ in range(world_size)]
dist.all_gather(gathered_results, local_result)
full_result = torch.cat(gathered_results)
print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}")
print(f"Local data: {data}")
print(f"Local result: {local_result}")
print(f"Gathered result: {full_result}")
Running the above script will print statements from each worker, which you can identify by their RANK
.
distributed_training_for_dummies-worker1-1 | Rank 1, Epoch 0, Batch 0
distributed_training_for_dummies-worker1-1 | Local data: tensor([618, 68, 215, 585])
distributed_training_for_dummies-worker1-1 | Local result: tensor([1236, 136, 430, 1170])
distributed_training_for_dummies-worker1-1 | Gathered result: tensor([ 542, 816, 94, 60, 1236, 136, 430, 1170])
distributed_training_for_dummies-worker1-1 | Rank 1, Epoch 1, Batch 0
distributed_training_for_dummies-worker1-1 | Local data: tensor([148, 575, 935, 343])
distributed_training_for_dummies-worker1-1 | Local result: tensor([ 296, 1150, 1870, 686])
distributed_training_for_dummies-worker1-1 | Gathered result: tensor([ 588, 223, 689, 932, 296, 1150, 1870, 686])
distributed_training_for_dummies-worker0-1 | Rank 0, Epoch 0, Batch 0
distributed_training_for_dummies-worker0-1 | Local data: tensor([542, 816, 94, 60])
distributed_training_for_dummies-worker0-1 | Local result: tensor([542, 816, 94, 60])
distributed_training_for_dummies-worker0-1 | Gathered result: tensor([ 542, 816, 94, 60, 1236, 136, 430, 1170])
distributed_training_for_dummies-worker0-1 | Rank 0, Epoch 1, Batch 0
distributed_training_for_dummies-worker0-1 | Local data: tensor([588, 223, 689, 932])
distributed_training_for_dummies-worker0-1 | Local result: tensor([588, 223, 689, 932])
distributed_training_for_dummies-worker0-1 | Gathered result: tensor([ 588, 223, 689, 932, 296, 1150, 1870, 686])
distributed_training_for_dummies-worker1-1 exited with code 0
distributed_training_for_dummies-worker0-1 exited with code 0
In the upcoming posts, I’ll walk through the code implementations of Data Parallelism, Tensor Parallelism, and Pipeline Parallelism.