Tensor Parallelism

Toy setting to understand Tensor Parallelism.

July 8, 2025


This is the third part of the Distributed Training for Dummies series.

The code for this series is available in the Distributed Training For Dummies GitHub repository. The design of much of the code is inspired by the picotron library. While the code itself contains detailed comments, here I will walk through key code blocks to provide a quick and basic understanding of how to implement some distributed training techniques. To start setting up the environment, you can visit the GitHub repository or follow the detailed walkthrough here.

Note: There are highly efficient libraries available, such as Megatron and DeepSpeed, which are specifically optimized for distributed training. I recommend using them whenever possible.


Reference section in the book: Explore HuggingFace’s Tensor Parallelism section here.

Training neural networks requires storing model parameters, gradients, optimizer states, and activations in memory. Among these, activations can be particularly memory-intensive, since the outputs of all intermediate operations must be kept for backpropagation. Activation memory grows linearly with the number of layers, the hidden size, and the batch size. On top of that, the attention scores add a term that grows quadratically with sequence length.

To reduce activation memory, one option is to discard some activations during the forward pass and recompute them during the backward pass. This is called activation checkpointing. Another option is to offload activations to CPU memory or disk, at the cost of extra data-transfer overhead.

An alternative strategy is to distribute both the weights and the activations across devices by parallelizing key operations such as matrix multiplications. This is called Tensor Parallelism (TP).

Consider a typical linear layer in a neural network:

  • Let \( X \in \mathbb{R}^{S \times C_{\text{in}}} \) be the input matrix, where
    • \( S \) is the input sequence length, and
    • \( C_{\text{in}} \) is the number of input channels.
  • Let \( W \in \mathbb{R}^{C_{\text{in}} \times C_{\text{out}}} \) be the weight matrix, where \( C_{\text{out}} \) is the number of output channels.

The matrix multiplication:

\[ XW \in \mathbb{R}^{S \times C_{\text{out}}} \]

produces the output activations. However, when \( W \) is large, storing it – and the corresponding activations needed for the backward pass – can exceed the memory limits of a single device.

To address this, we shard the weight matrix \( W \) and parallelize the matmul operation across multiple devices. We can shard \( W \) in two ways – either column-wise, known as column linear parallelism, or row-wise, known as row linear parallelism.
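As a rough illustration of what sharding saves, the hypothetical helper below counts per-device bytes for a single linear layer split over \( n \) devices. It assumes 2-byte (bf16/fp16) elements and ignores gradients, optimizer states, and the input \( X \). The sizes are arbitrary examples.

```python
def per_device_bytes(S, C_in, C_out, n, bytes_per_elem=2):
    """Hypothetical helper: per-device bytes for one linear layer sharded over n devices.

    Counts only the weight shard and this layer's output activation; gradients,
    optimizer states and the input X are ignored. Assumes n divides C_in and C_out.
    """
    column = {
        "weight": C_in * (C_out // n) * bytes_per_elem,  # W_i: C_in x C_out/n
        "output": S * (C_out // n) * bytes_per_elem,     # X W_i: S x C_out/n
    }
    row = {
        "weight": (C_in // n) * C_out * bytes_per_elem,  # W_i: C_in/n x C_out
        "output": S * C_out * bytes_per_elem,            # X_i W_i: full-width partial sum
    }
    return {"column": column, "row": row}


MiB = 2**20
single = per_device_bytes(S=4096, C_in=8192, C_out=32768, n=1)
sharded = per_device_bytes(S=4096, C_in=8192, C_out=32768, n=4)
print(single["column"]["weight"] / MiB, sharded["column"]["weight"] / MiB)  # 512.0 128.0
print(sharded["column"]["output"] / MiB, sharded["row"]["output"] / MiB)    # 64.0 256.0
```

Column sharding divides both the weight and the output activation by \( n \). Row sharding divides the weight, but each device still holds a full-width partial output until the partial results are summed.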

Column Linear Parallelism

Split \( W \) along its output dimension:

\[ W = [W_1, W_2, \dots, W_n], \quad W_i \in \mathbb{R}^{C_{\text{in}} \times \frac{C_{\text{out}}}{n}} \]

Broadcast \( X \in \mathbb{R}^{S \times C_{\text{in}}} \) to all devices.

Each device computes:

\[ XW_i \in \mathbb{R}^{S \times \frac{C_{\text{out}}}{n}} \]

All-gather the results, concatenating the shards along the output dimension in rank order, to obtain \( XW \in \mathbb{R}^{S \times C_{\text{out}}} \).
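These steps can be checked numerically on one machine. The sketch below uses NumPy and simulates the \( n \) devices as a Python list, which is an assumption for illustration only: no real communication happens. Each simulated device multiplies the full \( X \) by its column shard, and concatenating the shards in rank order reproduces \( XW \).

```python
import numpy as np

rng = np.random.default_rng(0)
S, C_in, C_out, n = 4, 6, 8, 2      # n simulated devices; n must divide C_out
X = rng.standard_normal((S, C_in))
W = rng.standard_normal((C_in, C_out))

# Shard W along its output (column) dimension: W_i has shape (C_in, C_out / n)
W_shards = np.split(W, n, axis=1)

# "Broadcast": every simulated device multiplies the same full X by its own shard
local_outputs = [X @ W_i for W_i in W_shards]     # each (S, C_out / n)

# "Gather": concatenate the shards along the output dimension in rank order
Y = np.concatenate(local_outputs, axis=1)         # (S, C_out)
assert np.allclose(Y, X @ W)
```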

Row Linear Parallelism

Shard \( W \) along its input dimension:

\[ W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_n \end{bmatrix}, \quad W_i \in \mathbb{R}^{\frac{C_{\text{in}}}{n} \times C_{\text{out}}} \]

Scatter (shard) \( X \) column-wise across devices, so that device \( i \) holds \( X_i \in \mathbb{R}^{S \times \frac{C_{\text{in}}}{n}} \), and compute:

\[ X_i W_i \in \mathbb{R}^{S \times C_{\text{out}}} \]

All-Reduce: sum the partial results, \( XW = \sum_{i=1}^{n} X_i W_i \), so that every device obtains the final output.
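The same kind of single-machine NumPy check works for the row-parallel case. Again, the devices are simulated as list entries (an illustrative assumption), and the all-reduce is a plain Python `sum`. Each partial product already has the full output shape, and only their sum equals \( XW \).

```python
import numpy as np

rng = np.random.default_rng(0)
S, C_in, C_out, n = 4, 6, 8, 2      # n simulated devices; n must divide C_in
X = rng.standard_normal((S, C_in))
W = rng.standard_normal((C_in, C_out))

# Shard W along its input (row) dimension and X along its matching column dimension
W_shards = np.split(W, n, axis=0)   # each (C_in / n, C_out)
X_shards = np.split(X, n, axis=1)   # each (S, C_in / n)

# Each simulated device produces a full-size but partial result
partials = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]   # each (S, C_out)

# "All-reduce": summing the partials gives XW (block matrix multiplication)
Y = sum(partials)
assert np.allclose(Y, X @ W)
```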

Column Linear followed by Row Linear

Notice that the output of Column Linear Parallelism is split into chunks along its last (feature) dimension, and these chunks are typically gathered at the end to reconstruct the complete result. Row Linear Parallelism, conversely, begins by sharding its input tensor along that same last (feature) dimension, distributing the workload across devices from the start.

This pairing creates a natural pipeline: the column-parallel output is already sharded exactly as the row-parallel layer expects its input, so no inter-device communication is needed between the two operations.

This is what lets the linear → activation → linear pattern of MLP blocks run without communicating intermediate activations. An elementwise activation \( \sigma \) can be applied shard by shard, since \( \sigma([A_1, A_2]) = [\sigma(A_1), \sigma(A_2)] \). Multi-head attention follows the same pattern (QKV projection → per-head attention → output projection), because the heads are independent of one another. In both cases, the forward pass needs only a single all-reduce at the end of the block, which improves efficiency and reduces latency.
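The no-communication argument relies on the activation being elementwise. The small NumPy check below uses ReLU as an illustrative stand-in for \( \sigma \). It shows that applying ReLU shard by shard matches applying it to the full tensor. An operation that normalizes across the sharded dimension, such as a softmax over the features, does not have this property and would require communication.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8))      # stands in for the column-parallel output XW
A_1, A_2 = np.split(A, 2, axis=1)    # the shards held by two devices

# Elementwise: applying ReLU shard by shard matches applying it to the full tensor
relu_matches = np.allclose(np.concatenate([relu(A_1), relu(A_2)], axis=1), relu(A))

# Softmax over the sharded (last) dimension mixes all columns, so shard-wise results differ
softmax_matches = np.allclose(np.concatenate([softmax(A_1), softmax(A_2)], axis=1), softmax(A))

print(relu_matches, softmax_matches)   # True False
```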

Using Tensor Parallelism for MLP

Given: \( X \in \mathbb{R}^{S \times C_{\text{in}}} \), \( W \in \mathbb{R}^{C_{\text{in}} \times C_{\text{out}}^w} \), \( Y \in \mathbb{R}^{C_{\text{out}}^w \times C_{\text{out}}^y} \)

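In a two-layer MLP, we want to compute \( \sigma(X W) Y \) with TP across two processors. Processor \( p \) holds the column block \( W_p \in \mathbb{R}^{C_{\text{in}} \times C_{\text{out}_p}^w} \) of \( W \) and the matching row block \( Y_p \in \mathbb{R}^{C_{\text{out}_p}^w \times C_{\text{out}}^y} \) of \( Y \), where \( C_{\text{out}_1}^w + C_{\text{out}_2}^w = C_{\text{out}}^w \). The sequence of operations is: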

| Step | Communication / operation | Processor 1 | Processor 2 |
|---|---|---|---|
| 1 | Broadcast | \( X \in \mathbb{R}^{S \times C_{\text{in}}} \) | \( X \in \mathbb{R}^{S \times C_{\text{in}}} \) |
| | Shard weights for Column Linear (no communication needed; weights pre-allocated) | \( W_1 \in \mathbb{R}^{C_{\text{in}} \times C_{\text{out}_1}^w} \) | \( W_2 \in \mathbb{R}^{C_{\text{in}} \times C_{\text{out}_2}^w} \) |
| 2 | Column Linear (local matmul) | \( X W_1 \in \mathbb{R}^{S \times C_{\text{out}_1}^w} \) | \( X W_2 \in \mathbb{R}^{S \times C_{\text{out}_2}^w} \) |
| 3 | No communication (elementwise activation) | \( \sigma(X W_1) \in \mathbb{R}^{S \times C_{\text{out}_1}^w} \) | \( \sigma(X W_2) \in \mathbb{R}^{S \times C_{\text{out}_2}^w} \) |
| | Shard weights for Row Linear (no communication needed; weights pre-allocated) | \( Y_1 \in \mathbb{R}^{C_{\text{out}_1}^w \times C_{\text{out}}^y} \) | \( Y_2 \in \mathbb{R}^{C_{\text{out}_2}^w \times C_{\text{out}}^y} \) |
| 4 | Row Linear (local matmul) | \( \sigma(X W_1) Y_1 \in \mathbb{R}^{S \times C_{\text{out}}^y} \) | \( \sigma(X W_2) Y_2 \in \mathbb{R}^{S \times C_{\text{out}}^y} \) |
| 5 | All-Reduce (sum) | \( \sigma(X W_1) Y_1 + \sigma(X W_2) Y_2 \in \mathbb{R}^{S \times C_{\text{out}}^y} \) | \( \sigma(X W_1) Y_1 + \sigma(X W_2) Y_2 \in \mathbb{R}^{S \times C_{\text{out}}^y} \) |
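The two-processor MLP can be simulated on one machine with NumPy. In this sketch the "processors" are just two sets of variables, and GELU (tanh approximation) is an arbitrary choice of elementwise \( \sigma \). Both choices are assumptions made for illustration. The local steps need no communication, and summing the two partial results (an all-reduce) reproduces \( \sigma(XW)Y \). Concatenating them instead would produce an \( S \times 2C_{\text{out}}^y \) tensor, which is the wrong shape.

```python
import numpy as np

def gelu(z):
    # tanh approximation of GELU; any elementwise activation works here
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

rng = np.random.default_rng(0)
S, C_in, C_w, C_y = 4, 6, 8, 5   # C_w plays the role of C_out^w, C_y of C_out^y
X = rng.standard_normal((S, C_in))
W = rng.standard_normal((C_in, C_w))
Y = rng.standard_normal((C_w, C_y))

reference = gelu(X @ W) @ Y

# Processor p holds a column block W_p of W and the matching row block Y_p of Y
W_1, W_2 = np.split(W, 2, axis=1)   # (C_in, C_w / 2) each
Y_1, Y_2 = np.split(Y, 2, axis=0)   # (C_w / 2, C_y) each

# Column linear, activation, and row linear all run locally with no communication
partial_1 = gelu(X @ W_1) @ Y_1     # (S, C_y)
partial_2 = gelu(X @ W_2) @ Y_2     # (S, C_y)

# A single all-reduce (sum) yields the full output on every processor
output = partial_1 + partial_2
```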

Using Tensor Parallelism for Self-Attention

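Multi-head attention parallelizes naturally across heads, because each head's computation is independent of the others. The QKV projection is column-parallel: each GPU computes Q, K, and V only for its own subset of heads and then computes attention for those heads without any inter-GPU communication. The concatenated head outputs then pass through a row-parallel output projection, which leaves each GPU with a partial sum of the self-attention output. An all-reduce adds these partial sums to form the final result. Let's explore this process with two GPUs. In the table below, \( Q^{(i)}, K^{(i)}, V^{(i)} \) hold the heads assigned to GPU \( i \). Softmax attention is computed separately for each of those heads, with head dimension \( d_h = d_{\text{model}} / n_{\text{heads}} \), and the per-head results are concatenated into \( H^{(i)} \in \mathbb{R}^{S \times \frac{d_{\text{model}}}{2}} \).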

| Step | Communication / operation | Processor 1 (GPU 0) | Processor 2 (GPU 1) |
|---|---|---|---|
| 1 | Broadcast activations | \( X \in \mathbb{R}^{S \times d_{\text{model}}} \) | \( X \in \mathbb{R}^{S \times d_{\text{model}}} \) |
| | Shard weights for column-parallel QKV (pre-allocated, no communication) | \( W_{qkv}^{(0)} \in \mathbb{R}^{d_{\text{model}} \times \tfrac{3d_{\text{model}}}{2}} \) | \( W_{qkv}^{(1)} \in \mathbb{R}^{d_{\text{model}} \times \tfrac{3d_{\text{model}}}{2}} \) |
| 2 | Column Linear (local matmul) | \( [Q,K,V]^{(0)} = X\,W_{qkv}^{(0)} \) | \( [Q,K,V]^{(1)} = X\,W_{qkv}^{(1)} \) |
| 3 | No communication: scaled dot-product softmax on local heads | \( H^{(0)} = \operatorname{Softmax}\bigl(Q^{(0)}K^{(0)\top}/\sqrt{d_h}\bigr)V^{(0)} \) | \( H^{(1)} = \operatorname{Softmax}\bigl(Q^{(1)}K^{(1)\top}/\sqrt{d_h}\bigr)V^{(1)} \) |
| | Shard weights for row-parallel output (pre-allocated, no communication) | \( W_{o}^{(0)} \in \mathbb{R}^{\tfrac{d_{\text{model}}}{2} \times d_{\text{model}}} \) | \( W_{o}^{(1)} \in \mathbb{R}^{\tfrac{d_{\text{model}}}{2} \times d_{\text{model}}} \) |
| 4 | Row Linear (local matmul) | \( Z^{(0)} = H^{(0)} W_{o}^{(0)} \in \mathbb{R}^{S \times d_{\text{model}}} \) | \( Z^{(1)} = H^{(1)} W_{o}^{(1)} \in \mathbb{R}^{S \times d_{\text{model}}} \) |
| 5 | All-Reduce (sum) | \( Z = Z^{(0)} + Z^{(1)} \) | \( Z = Z^{(0)} + Z^{(1)} \) |
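A NumPy simulation of this table on one machine checks that head-sharded attention plus a summed, row-parallel output projection matches single-device attention. For readability, the sketch assumes separate \( W_q, W_k, W_v \) matrices instead of a fused \( W_{qkv} \). With a fused weight, each rank's shard must contain the Q, K, and V columns of the same heads. The sketch also assumes 4 heads over 2 simulated devices and omits the causal mask.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_heads(X, W_q, W_k, W_v, n_heads):
    """Per-head scaled dot-product attention; returns head outputs concatenated along features."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d_h = Q.shape[1] // n_heads
    heads = []
    for h in range(n_heads):
        sl = slice(h * d_h, (h + 1) * d_h)
        scores = Q[:, sl] @ K[:, sl].T / np.sqrt(d_h)
        heads.append(softmax(scores) @ V[:, sl])
    return np.concatenate(heads, axis=1)

rng = np.random.default_rng(0)
S, d_model, n_heads, n_dev = 5, 8, 4, 2
X = rng.standard_normal((S, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) for _ in range(4))

# Reference: single-device attention followed by the output projection
Z_ref = attention_heads(X, W_q, W_k, W_v, n_heads) @ W_o

# Device i owns n_heads / n_dev heads: a column block of W_q, W_k, W_v and the matching row block of W_o
cols = d_model // n_dev
Z_partials = []
for i in range(n_dev):
    block = slice(i * cols, (i + 1) * cols)
    H_i = attention_heads(X, W_q[:, block], W_k[:, block], W_v[:, block], n_heads // n_dev)
    Z_partials.append(H_i @ W_o[block, :])   # row-parallel output projection, (S, d_model)

Z = sum(Z_partials)   # all-reduce (sum)
```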

Tensor Parallelism in code

To run this example locally, launch the following command:

RUN_MODE=tensor_parallel docker-compose up --build

Let’s review the key operations in the train.py file, which employs Tensor Parallelism (TP) to train a three-layer MLP model.

Sharding the Linear Layer

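Sharding in ColumnParallelLinear: In column linear parallelism, we shard the weight matrix column-wise, so we first compute how many output features each rank owns. Note that F.linear stores the weight as (out_features, in_features), the transpose of \( W \), so splitting \( W \) column-wise means splitting the stored weight along dim=0. In this implementation, every process initializes the full master weight and keeps only its own slice, which requires all ranks to use the same random seed. Alternatively, one rank could initialize the full weight and distribute the slices with a scatter operation.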

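import math
import torch
import torch.distributed as dist

class ColumnParallelLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, tp_rank, gather_outputs=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.tp_rank = tp_rank
        self.gather_outputs = gather_outputs
        self.tp_world_size = dist.get_world_size()
        assert out_features % self.tp_world_size == 0, "out_features must be divisible by the TP world size"
        self.out_features_per_rank = out_features // self.tp_world_size
        # F.linear stores weights as (out, in), so "column" sharding of W splits dim 0 here
        self.weight = torch.nn.Parameter(torch.empty(self.out_features_per_rank, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(self.out_features_per_rank))
        self.reset_parameters()

    def reset_parameters(self):
        # Every rank builds the same full master weight (assumes the same RNG seed on all ranks;
        # the nn.Linear-style init used here is an assumption), then keeps only its own slice
        master_weight = torch.empty(self.out_features, self.in_features)
        torch.nn.init.kaiming_uniform_(master_weight, a=math.sqrt(5))
        weight_list = torch.split(master_weight, self.out_features_per_rank, dim=0)
        self.weight.data = weight_list[self.tp_rank].contiguous()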

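In row linear parallelism, the rows of \( W \) are distributed across devices, so each device holds the slice of the weights that corresponds to its slice of the input features. Sharding in RowParallelLinear mirrors ColumnParallelLinear; the main difference is the sharding dimension, which is dim=1 (in_features) of the stored (out_features, in_features) weight. The bias is not sharded, because it is added once after the partial outputs have been summed.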
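The storage layout and the bias handling can be checked with a NumPy sketch that mirrors the F.linear convention. The `linear` helper is a hypothetical stand-in for torch.nn.functional.linear (\( y = x W^{\top} + b \), with the weight stored as (out, in)), and the devices are simulated in one process. The column-parallel layer splits the stored weight and bias along dim 0. The row-parallel layer splits the weight along dim 1 and adds the full bias once, after the sum.

```python
import numpy as np

def linear(x, weight, bias=None):
    """Hypothetical stand-in for F.linear: y = x @ weight.T + bias, weight shaped (out, in)."""
    y = x @ weight.T
    return y if bias is None else y + bias

rng = np.random.default_rng(0)
S, in_f, out_f, n = 4, 6, 8, 2
x = rng.standard_normal((S, in_f))
weight = rng.standard_normal((out_f, in_f))   # stored as (out, in), like nn.Linear
bias = rng.standard_normal(out_f)
y_ref = linear(x, weight, bias)

# Column parallel: split the stored weight and the bias along dim 0, then gather on the last dim
col_w = np.split(weight, n, axis=0)
col_b = np.split(bias, n)
y_col = np.concatenate([linear(x, w, b) for w, b in zip(col_w, col_b)], axis=-1)

# Row parallel: split the stored weight along dim 1 and the input along its last dim,
# sum the partial outputs, then add the unsharded bias exactly once
row_w = np.split(weight, n, axis=1)
x_shards = np.split(x, n, axis=-1)
y_row = sum(linear(x_i, w_i) for x_i, w_i in zip(x_shards, row_w)) + bias

# Adding the bias on every rank before the sum would count it n times
y_bias_per_rank = sum(linear(x_i, w_i, bias) for x_i, w_i in zip(x_shards, row_w))
```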

Distributed Primitives in the Forward Pass

After sharding, we need to ensure that each device gets the right data and that outputs are combined correctly. This is where our distributed communication primitives come in.

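In the ColumnParallelLinear forward pass, Copy.apply(x) is an identity, since every rank already holds the full input. Its real job is in the backward pass, where it all-reduces the gradient with respect to the input.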

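def forward(self, x):
    # Every rank already holds the full input; Copy is an identity in the forward pass
    # and all-reduces the input gradient in the backward pass
    input_broadcasted = Copy.apply(x)  # the communication primitive is applied like any PyTorch op
    output = F.linear(input_broadcasted, self.weight, self.bias)  # this rank's output shard
    return output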

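In the RowParallelLinear forward pass, Reduce.apply(...) sums the partial outputs from all ranks (an all-reduce), so each rank gets the full result. The bias is added after the reduction so that it is counted only once.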

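def forward(self, input):
    # Input is already sharded along its last (feature) dimension
    output_parallel = F.linear(input, self.weight)  # full-width partial sum, no bias yet
    output = Reduce.apply(output_parallel, self.tp_rank)  # all-reduce (sum) across ranks
    # Add the bias after the reduction so it is counted once, not once per rank
    return output + self.bias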

Communication Primitives as PyTorch Autograd Functions

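In distributed training, wrapping each communication primitive in a custom torch.autograd.Function lets it be used like any other PyTorch operation, and autograd then runs the matching communication in the backward pass automatically. For instance, the Copy primitive (below) communicates only in the backward pass. Its forward function is a no-op, since every rank already holds the same input. Its backward function all-reduces the gradients, because rank \( i \) only sees the gradient of its own output shard \( XW_i \) and can therefore form only the partial input gradient \( \frac{\partial L}{\partial (XW_i)} W_i^{\top} \). The full gradient \( \frac{\partial L}{\partial X} \) is the sum of these partials over all ranks.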

import torch
import torch.distributed as dist

class Copy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Forward: identity, since every rank already holds the same input
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: each rank holds only a partial gradient w.r.t. the input
        # (from its own output shard), so sum the partials across all ranks
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        return grad_output
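The gradient argument for Copy's backward can be checked numerically. The NumPy sketch below simulates a column-parallel layer on one machine (an illustrative assumption: ranks are list entries and the all-reduce is a Python `sum`). Given an upstream gradient \( \frac{\partial L}{\partial (XW)} \), each rank's partial input gradient alone is wrong, while their sum equals the single-device gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
S, C_in, C_out, n = 4, 6, 8, 2
X = rng.standard_normal((S, C_in))
W = rng.standard_normal((C_in, C_out))
dY = rng.standard_normal((S, C_out))     # upstream gradient dL/d(XW)

# Single-device reference: dL/dX = dL/d(XW) @ W^T
dX_ref = dY @ W.T

# Column parallel: rank i owns W_i and only sees the gradient of its own output shard X @ W_i
W_shards = np.split(W, n, axis=1)
dY_shards = np.split(dY, n, axis=1)
dX_partials = [dY_i @ W_i.T for dY_i, W_i in zip(dY_shards, W_shards)]   # each (S, C_in)

# Copy's backward all-reduce sums the partials, recovering the full input gradient
dX = sum(dX_partials)
print(np.allclose(dX_partials[0], dX_ref), np.allclose(dX, dX_ref))   # False True
```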

The full code, including the Gather and Reduce primitives, model definition, and training loop, is available in the repo. Dive in to see how everything fits together!
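The repository's Gather and Reduce are not reproduced here. As a hedged illustration only, the sketch below shows one common way to write such primitives as autograd functions. Reduce is the conjugate of Copy: all-reduce in the forward pass, identity in the backward pass. Gather all-gathers shards along the last dimension in the forward pass and keeps only this rank's slice of the gradient in the backward pass. That backward assumes the computation after the gather is replicated on all ranks. The signatures are assumptions and differ from the repo's (for example, the article's Reduce.apply also receives tp_rank).

```python
import torch
import torch.distributed as dist


class Reduce(torch.autograd.Function):
    """Hypothetical sketch: all-reduce (sum) forward, identity backward (conjugate of Copy)."""

    @staticmethod
    def forward(ctx, input):
        output = input.clone()   # avoid modifying the caller's tensor in place
        dist.all_reduce(output, op=dist.ReduceOp.SUM)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank's partial enters the sum with coefficient 1,
        # so every rank receives the (replicated) upstream gradient unchanged
        return grad_output


class Gather(torch.autograd.Function):
    """Hypothetical sketch: all-gather along the last dim forward, keep own slice backward."""

    @staticmethod
    def forward(ctx, input):
        input = input.contiguous()
        chunks = [torch.empty_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(chunks, input)
        return torch.cat(chunks, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        # Assumes the computation after the gather is replicated on all ranks,
        # so each rank only needs the gradient slice for its own shard
        world_size, rank = dist.get_world_size(), dist.get_rank()
        chunk = grad_output.size(-1) // world_size
        return grad_output.narrow(-1, rank * chunk, chunk).contiguous()
```

Reduce pairs with Copy in the same way that the row-parallel layer pairs with the column-parallel one: each primitive communicates in exactly one direction of the pass.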