Data Parallelism
This is the second part of the Distributed Training for Dummies series.
The code for this series is available in the Distributed Training For Dummies GitHub repository. The design of much of the code is inspired by the picotron library. While the code itself contains detailed comments, here I will walk through key code blocks to provide a quick and basic understanding of how to implement some distributed training techniques. To start setting up the environment, you can visit the GitHub repository or follow the detailed walkthrough here.
Note: There are highly efficient libraries available, such as Megatron and DeepSpeed, which are specifically optimized for distributed training. I recommend using them whenever possible.
Reference section in the book: Explore HuggingFace’s Data Parallelism section here.
Data Parallelism involves replicating the model across multiple GPUs. Each GPU processes a microbatch independently, performing a forward pass and computing gradients during the backward pass. These gradients are then synchronized across devices using an all_reduce operation, after which the optimizer updates the model parameters.
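If you have not seen all_reduce before, the following minimal, self-contained sketch (not part of the repository; the file name is just an example) shows its behavior: every rank starts with a different tensor, and after the call every rank holds the same sum. You could launch it with, e.g., torchrun --nproc_per_node=2 demo_all_reduce.py.

import torch
import torch.distributed as dist

# torchrun sets MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE for us.
dist.init_process_group(backend="gloo")  # use "nccl" when running on GPUs
rank = dist.get_rank()
world_size = dist.get_world_size()

# Each rank contributes a different value ...
x = torch.tensor([float(rank + 1)])
# ... and after all_reduce every rank holds the same sum: 1 + 2 + ... + world_size.
dist.all_reduce(x, op=dist.ReduceOp.SUM)
print(f"rank {rank}: {x.item()}")

dist.destroy_process_group()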
The timing of gradient computation and communication opens up several variants of data parallelism (e.g., synchronous, asynchronous, overlapping), some of which I experimented with in the code.
A microbatch refers to the portion of the input batch processed by a single device in a single pass. This is different from gradient accumulation, where multiple microbatches are processed sequentially (on the same GPU) to simulate a larger batch size before a gradient update is applied.
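To make the distinction concrete, here is a rough single-device sketch of gradient accumulation; model, optimizer, loss_fn, and get_microbatch are placeholders, not names from the repository.

accumulation_steps = 4  # number of microbatches folded into one optimizer update

optimizer.zero_grad()
for step in range(accumulation_steps):
    inputs, targets = get_microbatch(step)   # placeholder for your data loading
    loss = loss_fn(model(inputs), targets)
    # Scale the loss so the accumulated gradient matches one large batch.
    (loss / accumulation_steps).backward()   # gradients keep accumulating in .grad
optimizer.step()                             # a single update for the simulated large batch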
Note: The toy examples provided here do not cover sharding. Parallelism involves executing computations simultaneously across multiple devices, whereas sharding refers to distributing parameters (such as model weights, gradients, or optimizer states) across devices so that each device handles different parameters. Naive Data Parallelism (DP) does not involve sharding, but the Zero Redundancy Optimizer (ZeRO) incorporates sharding alongside DP.
Naive DP
To run the example on your machine, launch the following command:
RUN_MODE=data_parallel_naive docker-compose up --build
Naively, we can synchronize gradients only after they have been fully computed for the entire microbatch on each process. However, this approach is inefficient: all GPUs must wait until every gradient is available before proceeding to the next step, resulting in idle time and poor hardware utilization.
The following code snippet accomplishes this:
dist.barrier()  # wait for all ranks to finish the current batch

# Sum each gradient across ranks and divide by the world size so every
# rank holds the same averaged gradient before the optimizer step.
for param in model.parameters():
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= dp_world_size
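For context, here is roughly where this snippet sits in a training step, assuming the process group is already initialized and each rank is fed its own microbatch (the surrounding names are illustrative):

loss = loss_fn(model(inputs), targets)
loss.backward()                      # each rank computes gradients on its own microbatch

dist.barrier()                       # wait for all ranks to finish the current batch
for param in model.parameters():
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad /= dp_world_size  # every rank now holds the averaged gradient

optimizer.step()                     # identical update on every rank
optimizer.zero_grad()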
Overlap Communication with Computation
To run the example on your machine, launch the following command:
RUN_MODE=data_parallel_overlap docker-compose up --build
To make this more efficient, we can synchronize gradients as soon as they are available, thereby overlapping communication with computation. To achieve this, we use hooks in PyTorch. A hook is a function that can be attached to a tensor (such as a model parameter), and it will be called automatically with that tensor's gradient when it is computed during backpropagation.
Attaching a hook: The following code ensures that as soon as the backward pass computes the gradient for a parameter, my_hook_function is triggered for that parameter.
for param in model.parameters():
    if param.requires_grad:
        param.register_hook(my_hook_function)
All Reduce: Inside the hook, we perform an all_reduce operation. This operation communicates the gradient across all processes, so every process ends up with the same, synchronized gradient. By doing this inside the hook, communication starts as soon as each gradient is available, overlapping with the rest of the backward computation.
def my_hook_function(grad):
    # Synchronize the gradient across all processes
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)
    grad /= world_size  # Average the gradient
    return grad
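Putting the two snippets together, a training step might look as follows; this is a sketch assuming model, optimizer, loss_fn, and world_size already exist:

# Register the hook once, right after the model (and process group) is set up.
for param in model.parameters():
    if param.requires_grad:
        param.register_hook(my_hook_function)

# Training step: each parameter's all_reduce fires as soon as its gradient is
# computed, while backward is still working on earlier layers.
loss = loss_fn(model(inputs), targets)
loss.backward()        # hooks run here
optimizer.step()       # gradients are already synchronized
optimizer.zero_grad()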
Bucket Overlap
To run this example on your machine, launch the following command:
RUN_MODE=data_parallel_bucket_overlap docker-compose up --build
The approach above is also inefficient in another way: it launches a separate all_reduce for every single parameter, and many small communication calls can become a bottleneck. When communication is not carefully coordinated, it can block GPU execution, leading to underutilized hardware and slower training. This can be prevented by dividing the parameters into "buckets" and, once their gradients are ready, initiating asynchronous communication to synchronize them. This allows us to overlap communication for some parameters with computation for others.
To synchronize gradients efficiently, we group parameters into buckets and communicate them collectively across processes. Instead of sending gradients one by one, we wait until all gradients in a bucket are ready, then aggregate and communicate them as a single tensor—reducing communication overhead.
As each parameter finishes backpropagation, we mark its gradient as ready. Once all gradients within a bucket are available, we trigger synchronization for the entire bucket in one go, ensuring better overlap of communication and computation.
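How parameters end up in buckets is an implementation detail; one common strategy, shown here as a hedged sketch rather than the exact repository code, is to greedily fill buckets up to a byte budget, walking the parameters in reverse order because the last layers finish their backward pass first:

def assign_buckets(params, bucket_size_mb=25):
    # Greedily pack parameters into buckets of roughly bucket_size_mb megabytes.
    max_bytes = bucket_size_mb * 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    for p in reversed(list(params)):       # reverse: gradients become ready back-to-front
        if not p.requires_grad:
            continue
        p_bytes = p.numel() * p.element_size()
        if current and current_bytes + p_bytes > max_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(p)
        current_bytes += p_bytes
    if current:
        buckets.append(current)
    return buckets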
Marking Bucket as Ready. Each time a parameter’s gradient is ready, we mark it as such. When all parameters in a bucket are ready, we trigger synchronization for the whole bucket.
def mark_param_as_ready(self, param):
    self.params_with_grad_ready.add(param)
    if len(self.params_with_grad_ready) == len(self.params):
        self.sync_gradients()
This function is called whenever a parameter's gradient is computed. Once all parameters in the bucket are ready, we call sync_gradients().
Syncing and Waiting (Asynchronously). When synchronizing gradients, we use an asynchronous all-reduce operation. Later, we wait for all asynchronous operations to finish before proceeding.
sync_gradients starts an asynchronous communication operation for the bucket's gradients. wait ensures that all asynchronous communications are complete before moving on (e.g., before the optimizer step).
def sync_gradients(self):
    # Start asynchronous all-reduce for the bucket's gradients.
    # Dividing before the reduce means the summed result is already the average.
    self.grads.div_(self.process_group_size)
    self.handle = dist.all_reduce(self.grads, op=dist.ReduceOp.SUM, async_op=True)

def wait(self):
    # Wait for the asynchronous operation to complete
    if self.handle is not None:
        self.handle.wait()
Purpose of main_grad and How It Enables Distributed Communication. In the bucket overlap approach, we want to efficiently synchronize gradients for groups of parameters (buckets) across processes. To do this, we need a way to collect and manage the gradients for all parameters in a bucket as a single tensor, which can then be communicated efficiently.
For each parameter, instead of using the default .grad attribute, we create a view into a larger, contiguous tensor that holds the gradients for all parameters in the bucket. This view is stored as param.main_grad.
When the backward pass computes a gradient for a parameter, it is added to param.main_grad (not .grad), so all gradients for the bucket are accumulated in the bucket's tensor. This allows us to perform a single, efficient all-reduce operation on the entire bucket's gradients, rather than many small operations for each parameter.
# For each parameter, create a view into the bucket's gradient tensor
for param in self.module.parameters():
    if param.requires_grad:
        data_start_idx, data_end_idx, bucket_idx = self._param_to_bucket_location[param]
        param.main_grad = self.grad_list[bucket_idx][data_start_idx:data_end_idx].view(param.shape)
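The snippet above assumes that self.grad_list already holds one flat gradient buffer per bucket. A rough sketch of how those buffers could be allocated (the bucket grouping and dtype here are illustrative, not the exact repository code):

# One contiguous buffer per bucket; each param.main_grad above is a view into it.
self.grad_list = []
for bucket_params in buckets:   # `buckets` = parameters grouped per bucket (assumed)
    numel = sum(p.numel() for p in bucket_params if p.requires_grad)
    self.grad_list.append(
        torch.zeros(numel, dtype=torch.float32, device=bucket_params[0].device)
    )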
Use of queue_callback and When post_backward Is Called. torch.autograd.Variable._execution_engine.queue_callback is a low-level PyTorch API that lets you schedule a callback to be run after the entire backward pass is complete.
In the bucket overlap code, we want to wait for all asynchronous all-reduce operations (for all buckets) to finish after the backward pass. We only want to do this once per backward pass, not every time a parameter's gradient is computed. During gradient accumulation, for example, we deliberately avoid syncing after each backward pass since we're just accumulating gradients locally.
So, the first time a parameter's hook is called (when require_backward_sync is True), we register the post_backward callback.
if not self._post_backward_callback_set:
    torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
    self._post_backward_callback_set = True
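In context, this registration typically lives inside the per-parameter backward hook, next to the main_grad accumulation, and is skipped during pure gradient accumulation. A simplified sketch (require_backward_sync and bucket_manager are illustrative names for the flag and for whatever object owns mark_param_as_ready):

def _make_param_hook(self, param):
    def param_hook(grad):
        param.main_grad.add_(grad)            # accumulate into the bucket's flat buffer
        if self.require_backward_sync:        # False while accumulating gradients locally
            if not self._post_backward_callback_set:
                # Schedule _post_backward to run once the whole backward pass finishes.
                torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
                self._post_backward_callback_set = True
            # May trigger the bucket's async all-reduce if this was its last gradient.
            self.bucket_manager.mark_param_as_ready(param)
    return param_hook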
The post_backward Function. This function is called after the backward pass is complete. It waits for all asynchronous all-reduce operations (one per bucket) to finish, ensuring all gradients are synchronized before the optimizer step. It also resets the state for the next iteration.
def _post_backward(self, *unused):
    # Wait for all buckets' async all-reduce to finish
    for bucket in self.buckets:
        bucket.wait()
    self._post_backward_callback_set = False
    # Reset buckets for the next backward pass
    for bucket in self.buckets:
        bucket.reset()
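Putting everything together, a single training step with bucketed overlap proceeds roughly as follows. This is a hedged sketch: dp_model stands for the wrapped module, and handing main_grad to the optimizer via param.grad is one possible choice, not necessarily how the repository does it.

loss = loss_fn(dp_model(inputs), targets)
loss.backward()
# During backward: hooks fill the buckets' main_grad views, each full bucket launches
# an async all_reduce, and the queued _post_backward waits on all of them. By this
# point every rank therefore holds identical, averaged gradients in main_grad.

for param in dp_model.parameters():
    if param.requires_grad:
        param.grad = param.main_grad            # expose the synced gradients to the optimizer

optimizer.step()                                # identical update on every rank
optimizer.zero_grad()
# The main_grad buffers also need zeroing before the next step (hooks use add_());
# here it is assumed that bucket.reset() above takes care of that.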