Understanding GPU Programming with Triton: From Hardware to Code
A practical guide to writing efficient GPU kernels with Triton, covering GPU architecture, memory hierarchy, and three progressively complex examples
GPU CUDA Cores are specialized compute units that GPUs use to parallelize computation. They are arranged in a hierarchy that lets the GPU fire off commands to all the cores at the same time, thereby achieving parallelism. In this post, I will walk through the hardware model, the software considerations, and how Triton lets us orchestrate these parallel processors through software abstractions.
I will start with a minimal Triton program that can be read (almost) like Python code. Then I will describe the hardware model, and finally I will write a few kernels and measure how fast they run. The code for these kernels can be found in this repo: https://github.com/pg2455/minimal_triton_tutorial
Triton Program
Any Triton program has the following three parts:
- Imports: import triton, import triton.language as tl.
- Kernel: Write your GPU function with the @triton.jit decorator.
- Wrapper: Write a standard Python function that allocates input tensors and launches the kernel.
Finally, we create data and call the wrapper.
In the example of an add kernel below, observe:
- kernel_fn is wrapped with triton.jit, which compiles it with the Triton compiler.
- The wrapper launches the kernel function on the GPU using a special syntax: kernel_fn[grid](args).
- [grid] tells the GPU how many workers to hire:
  - If grid = (10,), you hire 10 blocks of workers.
  - If grid = (10, 5), you hire a 2D grid of \( 10 \times 5 \) blocks.
- (args): These are the tools you give to every single worker.
import torch
import triton
import triton.language as tl
device = torch.device("cuda")
# 1. THE KERNEL (Runs on GPU)
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Get the program ID (which block of the grid this instance is)
pid = tl.program_id(axis=0)
# Create offsets (The "Spatial" logic)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to prevent out-of-bounds access
mask = offsets < n_elements
# Load data
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# Do math
output = x + y
# Store data
tl.store(output_ptr + offsets, output, mask=mask)
# 2. THE LAUNCHER (Runs on CPU)
def add(x: torch.Tensor, y: torch.Tensor):
output = torch.empty_like(x)
n_elements = x.numel()
# Define the Grid (How many workers?)
# We need enough blocks to cover all elements
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# Launch!
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
# 3. TEST IT
x = torch.rand(1000, device=device)
y = torch.rand(1000, device=device)
output = add(x, y)
print(output)
Now, observe operations such as tl.load and tl.store, which are all about memory management. They move data from the GPU's HBM to the memory close to the CUDA cores, where the work (the math) can happen. This example shows the basic structure, but to write efficient kernels, we need to understand what's happening under the hood. Let's look into the GPU hardware model.
Hardware Model
What is the Memory Hierarchy? Data moves through multiple layers of memory, each with different size and speed characteristics. Listed from fastest/closest to slowest/farthest from the compute cores:
- Thread Registers: Small, fast memory (~255 × 32 bits per thread) dedicated to each thread for storing operands, loop variables, and intermediate results.
- SRAM/L1 Cache: Each SM has dedicated SRAM (164-228 KB). Data from HBM is copied here, and threads within a block can share data through this memory.
- L2 Cache: Shared across all SMs (40-50 MB). Caches frequently accessed data to reduce HBM access latency.
- GPU HBM (High Bandwidth Memory): Main GPU memory (40-80 GB) where all kernel data must reside. All explicit loads/stores in your kernel access this memory (unless cached).
- CPU RAM: System memory (256 GB - 2 TB). Data must be transferred to GPU HBM via PCIe before kernels can operate on it.
Note: Moving data between these layers takes hundreds of cycles. Efficient kernels minimize data movement by maximizing compute on data already in fast memory.
What does the GPU’s compute layout look like?
- A GPU has a 1D arrangement of Streaming Multiprocessors (SMs), which are collections of compute units.
- Each SM hosts the software-side thread blocks and can execute multiple warps concurrently (typically 32-64 warps per SM depending on architecture).
- Each Warp consists of 32 threads that execute in lockstep, performing the same instruction on different data (SIMD).
- Each thread knows how to do one fundamental operation: fused multiply-add (FMA). If a,b,c are scalars, each thread reads a,b,c into their dedicated register, computes a*b + c, and writes the result to their output register.
What compute units are operating in parallel? Every warp in a block runs the same kernel function, and every thread within a warp executes the same line of the kernel at the same time.
Can the threads only do fused multiply-add (FMA)? Essentially, yes: CUDA cores just do this elementwise. However, there are dedicated units that run multiple FMAs to perform more complex operations:
- There are Special Function Units (SFUs) that compute complex expressions such as sin(x), cos(x), exp(x), etc. These take several FMAs and are therefore not as fast. This is a profound application of series in mathematics (e.g., the Taylor series), which show how to approximate a function with a polynomial, thereby turning it into a series of FMAs. Note that these operations can be heavily optimized by using lookup tables as part of their series decomposition.
- Operations like 1/x or sqrt(x) require an iterative technique such as the Newton-Raphson method, and hence they too reduce to a series of multiplies and adds.
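To make the last point concrete, here is a rough sketch in plain Python (not Triton, and not how the hardware actually seeds its estimate; the initial guess and iteration count are my own) of computing 1/x with nothing but multiplies and adds:
def reciprocal(x: float, iters: int = 5) -> float:
    # Newton-Raphson for f(y) = 1/y - x gives the update y <- y * (2 - x * y),
    # which is just multiplies and adds (FMAs).
    y = 0.1  # crude initial guess; real hardware seeds this from a lookup table
    for _ in range(iters):
        y = y * (2.0 - x * y)
    return y

print(reciprocal(4.0))  # converges towards 0.25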
Memory Layout: In GPUs, memory is laid out as a 1D arrangement. Matrices and tensors are described with multiple dimensions, but that is only for our convenience. Internally, these objects are defined by their strides. For example, a 3x2 (row-major) matrix has strides [2, 1], i.e., the innermost stride is 1 and the outermost stride is 2. This means that to get to the next element along the innermost dimension, you move the memory pointer by 1, and to get to the next row, you move it by 2. If the tensor were 4x3x2, the strides would be [6, 2, 1].
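You can check this directly in PyTorch; a quick sketch:
import torch

a = torch.arange(6).reshape(3, 2)       # 3x2, row-major
print(a.stride())                       # (2, 1)

b = torch.arange(24).reshape(4, 3, 2)   # 4x3x2
print(b.stride())                       # (6, 2, 1)

# Element b[i, j, k] lives at flat offset i*6 + j*2 + k*1 from the base pointer.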
Memory Coalescing: When a warp reads data from HBM, the per-thread loads are issued as wide memory instructions, typically:
- 128-bit loads (LDG.128), which fetch 4 floats per thread per instruction, or
- 32-bit loads (LDG.32), which fetch 1 float per thread per instruction.
If the addresses accessed by a warp are next to each other, these loads coalesce: a single memory transaction serves the whole warp (e.g., 32 consecutive floats, i.e., 128 bytes). Otherwise, more transactions are needed. Remember that each load takes multiple cycles, and if we load more than we need and then discard the rest, our effective bandwidth and compute efficiency go down. This is why a contiguous layout (what the .contiguous() function produces) matters: neighbouring threads end up reading neighbouring addresses.
PyTorch tries its best to keep the memory layout contiguous. When it is not, it internally calls .contiguous(), which clones the tensor into a new buffer with a contiguous layout; once that copy is no longer needed, the memory is freed. This creates a spike in peak memory usage: for a split second, the tensor exists twice. If this exceeds the available memory, the program errors out.
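A small sketch of how this shows up in practice:
import torch

a = torch.arange(6).reshape(2, 3)   # contiguous, strides (3, 1)
b = a.t()                           # transposed view: same memory, strides (1, 3)
print(b.is_contiguous())            # False
c = b.contiguous()                  # allocates a new buffer and copies the data
print(c.is_contiguous())            # True
print(b.stride(), c.stride())       # (1, 3) vs (2, 1)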
How expensive is loading memory? A memory load is not a single-cycle operation. Here is how each kernel_fn might work:
- tl.load(ptr) issues a command to load the data at the ptr addresses.
- Once the command is issued, the data takes another ~400 cycles (as an example) to be copied into the thread registers.
- Essentially, tl.load reads from the GPU's High Bandwidth Memory (HBM) and copies that data into the SRAM (dedicated to the block), and from there it is copied to the local thread registers.
- Finally, when the data becomes available, the threads operate on it in lockstep and write the results.
Notice that the picture above applies to a whole warp: all threads in a warp walk in lockstep. They all call tl.load at the same time, and they all wait together. However, the remaining warps in the block can be on a different line of the kernel function, depending on how fast their data becomes available. This trick is exploited by FlashAttention-3 on H100 GPUs, where a few warps issue instructions to load data while others do the math.
Programming Triton
With the knowledge of the hardware, we can jump straight into manipulating the hardware to achieve compute efficiency. Let’s break down the syntax of kernel_fn and kernel_fn[grid](args):
kernel_fn
The function kernel_fn has two types of arguments:
- Variables: values that may differ between calls to the kernel. When a tensor is passed in the wrapper, Triton actually passes the base pointer (memory address) of the tensor. Spatial calculations must be done inside the kernel function to make sure the right chunk of memory is read.
- tl.constexpr: values that are constants for the kernel_fn. These are not written to the warps' registers; instead, they are baked into the machine code. For example, if a loop's iteration count is defined by a tl.constexpr, the compiler can emit machine code with the loop explicitly unrolled. Consequently, every new value of a tl.constexpr produces a new compiled kernel.
- Constraint in designing the block size: While deciding the block size, it is crucial to check whether the data fits in SRAM; if not, you pay the cost of constantly loading and flushing the SRAM. This is where Triton's autotuning functionality helps, by trying various configurations and using the one with the minimum latency.
- Autotuning: This lets Triton tune the kernel over various configurations and pick the fastest one. For example:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def add_kernel(...):
    ...
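With the autotuner in place, the launch no longer passes BLOCK_SIZE explicitly; Triton picks it from the configs (re-tuning whenever the key, n_elements, changes):
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements)  # BLOCK_SIZE is supplied by the autotuner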
grid
A usual way to define how the kernel is launched using the grid is:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
The way that Triton reads it is as follows:
- Triton reads the grid = ... line first and stores the function definition.
- Then it reads the add_kernel[grid](...) line and notes the constants being passed. It already knows from the kernel definition that BLOCK_SIZE is a tl.constexpr, so it puts it into the meta dictionary.
- Then it computes the grid dimensions, which can be 1D (e.g., (10,)), 2D (e.g., (5, 2)), and so on. This distinction is purely a software convenience; for example, it might be easier to refer to a 2D grid while dealing with matrices.
- As discussed above, on the hardware side there is essentially a 1D arrangement of streaming multiprocessors, and each block of the grid gets scheduled onto one. Each block further has warps, each a collection of 32 threads; typically, there are 4 warps per block.
- Once the grid is computed, kernel_fn is launched on the total number of blocks requested. Each block has a program id, which we can access using tl.program_id(axis=0). This tells us which block of the requested grid this kernel_fn instance is executing on. tl.program_id(axis=0) reads a register, so we can only query one axis at a time; if we want the y coordinate, we have to ask for tl.program_id(1).
- Finally, we use the location of the block in the grid to determine how to slice and dice the data and do the math on it, as the sketch below illustrates.
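Here is a minimal, hypothetical kernel (not from the repo; tile_id_kernel and its arguments are made up for illustration) that writes each block's tile id into a 2D output, which makes the grid-to-data mapping visible:
import torch
import triton
import triton.language as tl

@triton.jit
def tile_id_kernel(out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Position of this block in the 2D grid
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # Row/column indices covered by this block
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row-major 1D addresses of the tile, plus a mask for the edges
    ptrs = out_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Flatten the 2D block coordinates into a single tile id and store it
    tile_id = pid_m * tl.num_programs(axis=1) + pid_n
    vals = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + tile_id
    tl.store(ptrs, vals, mask=mask)

out = torch.empty(8, 8, device="cuda", dtype=torch.int32)
grid = (triton.cdiv(8, 4), triton.cdiv(8, 4))  # a 2x2 grid of blocks
tile_id_kernel[grid](out, 8, 8, BLOCK_M=4, BLOCK_N=4)
print(out)  # four 4x4 tiles filled with 0, 1, 2, 3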
args
For each tensor that we pass in this line, Triton passes the memory address of its very first element. The spatial calculations done in each block are based on these base pointers.
Spatial Calculations to iterate through memory
Performing spatial calculations is the most important part of writing the kernel function. It ensures that the right data is read from memory. To understand how to manipulate memory addresses to get the correct data, we need to understand how to iterate through them. There are two options here:
- Manual: We do the spatial math ourselves to locate the memory addresses. In our first example above, we used the base pointer and offsets to load the data. This is fine, but it gets messy when we have to loop over a sliding window.
- Automatic sliding windows: tl.make_block_ptr creates a block pointer that iterates through memory in chunks, and tl.load loads the data from HBM to the thread registers, where the math is done on the whole chunk at once.
tl.make_block_ptr takes the following arguments:
- base: the address in memory where the data begins. Remember, memory is a 1D layout, so if we are representing it as a 2D matrix, we need to take care of the strides.
- shape: the overall shape of the data structure. This defines the boundaries so that data outside the limits is not accessed.
- strides: tells Triton how to navigate this memory, i.e., how many addresses to move to advance to the next element in each dimension.
- offsets: how far from the base pointer to move to reach the current block. Remember that each kernel instance is launched on a block, so these offsets define the piece of data that the block operates on.
- block_shape: these pointers are often used as a sliding window passed over the data layout; this argument defines the shape of that window.
- order: the memory order of the original data (e.g., (1, 0) for a row-major 2D tensor).
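As a rough, hypothetical sketch (a fragment of a @triton.jit kernel, assuming a row-major (M, N) tensor at a_ptr, constexpr BLOCK_M/BLOCK_N, and a 1D grid over row-blocks), here is a block pointer next to the equivalent manual arithmetic:
# Inside a @triton.jit kernel; BLOCK_M and BLOCK_N are tl.constexpr powers of two
pid = tl.program_id(axis=0)
a_block_ptr = tl.make_block_ptr(
    base=a_ptr,                      # address of element [0, 0]
    shape=(M, N),                    # logical bounds of the full tensor
    strides=(N, 1),                  # row-major: the next row is N elements away
    offsets=(pid * BLOCK_M, 0),      # this block starts at row pid * BLOCK_M
    block_shape=(BLOCK_M, BLOCK_N),  # the window loaded by one tl.load
    order=(1, 0),                    # row-major order, as in the matmul example below
)
a_tile = tl.load(a_block_ptr)

# Roughly the same thing with manual pointer arithmetic:
rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
cols = tl.arange(0, BLOCK_N)
ptrs = a_ptr + rows[:, None] * N + cols[None, :]
a_tile_manual = tl.load(ptrs, mask=(rows[:, None] < M) & (cols[None, :] < N))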
What happens when BLOCK_SIZE, or the amount of data requested, is more than what can fit in the thread registers? The Triton compiler looks at the data requested at compile time, and if it is more than the thread registers can hold, it will raise an error.
Performance Testing using triton.testing.do_bench
The reason we use triton.testing.do_bench instead of a simple Python time.time() loop is that it automatically flushes the L2 cache between measurements.
Internally, do_bench typically does something like this before every iteration:
- Allocates a large “garbage” tensor (larger than the GPU’s L2 cache size).
- Writes random data to it (forcing the L2 cache to fill up with garbage and evict your kernel’s data).
- Runs your kernel.
This ensures that every single run is a “cold” run, forcing the GPU to fetch your data from HBM every time. This accurately simulates a real training step, where new data is constantly streaming in.
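A rough sketch of the same idea (my own illustration, not do_bench's actual implementation), in case you want to benchmark by hand:
import torch

def bench_cold(fn, flush_bytes=256 * 1024 * 1024, reps=20):
    # A buffer larger than L2, so that writing to it evicts our data from the cache
    cache_filler = torch.empty(flush_bytes, dtype=torch.uint8, device="cuda")
    times_ms = []
    for _ in range(reps):
        cache_filler.zero_()                       # clobber L2 between runs
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()                   # wait for the GPU to finish
        times_ms.append(start.elapsed_time(end))   # elapsed_time returns milliseconds
    return sum(times_ms) / len(times_ms)

# e.g., bench_cold(lambda: add(x, y))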
Let’s look at some examples:
- Add using manual memory management
- Vector add using automatic memory management
- Matmul using automatic memory management
Example 1: Add Kernel Using Manual Memory Management
Let’s revisit our add kernel from earlier and analyze its performance characteristics.
import torch
import triton
import triton.language as tl
device = torch.device("cuda")
# See the kernel code from earlier example above
# Using BLOCK_SIZE=1024 for testing
x = torch.rand(1000, device=device)
y = torch.rand(1000, device=device)
latency_ms = triton.testing.do_bench(lambda : add(x, y), warmup=25, rep=100)
print(f"Latency: {latency_ms} ms")
Latency: 0.006688000168651342 ms
Observations:
- The Triton kernel achieves very low latency for this simple operation
- PyTorch’s version is typically 10x slower due to Python overhead in interpreting inputs and allocating output tensors
- This slowness is merely due to CPU overhead in memory management, not GPU compute efficiency
- When using pre-allocated output buffers, PyTorch’s performance is similar to our custom kernel
Time for Data Transfer: How much time does it take to transfer data from HBM to thread registers (for a and b), and how much time does it take to transfer the results (c) back to the HBM?
For N = 2**21 floats, each tensor occupies 2**23 bytes, i.e., 8 MB, and we move three such tensors in total (load a and b, store c). For Ampere GPUs the bandwidth is about 1500 GB/s, so the total time for this data movement is \( \frac{3 \times 8\ \text{MB}}{1500 \times 2^{10}\ \text{MB/s}} \approx 0.016 \text{ ms} \).
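The same back-of-the-envelope estimate in code:
N = 2**21                              # elements per tensor
bytes_per_tensor = N * 4               # float32
total_bytes = 3 * bytes_per_tensor     # load a, load b, store c
bandwidth = 1500 * 2**30               # ~1500 GB/s for an Ampere GPU, in bytes/s
print(total_bytes / bandwidth * 1e3)   # ~0.016 ms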
Example 2: Vector Add Using Automatic Memory Management
The previous example used manual pointer arithmetic. Now let’s see how tl.make_block_ptr simplifies memory management while achieving similar performance.
import torch
import triton
import numpy as np
import timeit
import triton.language as tl
@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, N, BLOCK_SIZE: tl.constexpr):
# Find the program id of this particular block, i.e., position in the grid
pid = tl.program_id(axis=0)
# Memory pointers are managed by Triton
a_block_ptr = tl.make_block_ptr(
base=a_ptr, shape=(N,), strides=(1,),
offsets=(pid*BLOCK_SIZE,), block_shape=(BLOCK_SIZE,), order=(0,),
)
b_block_ptr = tl.make_block_ptr(
base=b_ptr, shape=(N,), strides=(1,),
offsets=(pid*BLOCK_SIZE,), block_shape=(BLOCK_SIZE,), order=(0,)
)
c_block_ptr = tl.make_block_ptr(
base=c_ptr, shape=(N,), strides=(1,),
offsets=(pid*BLOCK_SIZE,), block_shape=(BLOCK_SIZE,), order=(0,)
)
# Load the data from HBM on to the thread registers
a_chunk = tl.load(a_block_ptr)
b_chunk = tl.load(b_block_ptr)
# Do Math
c_chunk = a_chunk + b_chunk
# Store this data back in the HBM
tl.store(c_block_ptr, c_chunk)
def vector_add(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x)
N = x.shape[0]
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
vector_add_kernel[grid](x, y, output, N, BLOCK_SIZE=1024)
return output
N = 2**21
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
latency = triton.testing.do_bench(lambda: vector_add(x,y)) # outputs ms
print(f"Latency (ms): {latency}")
## How does it look with Pytorch default kernels?
times = []
for _ in range(10):
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
torch.cuda.synchronize()
start = timeit.default_timer()
o = x + y
torch.cuda.synchronize()
times.append(timeit.default_timer() - start)
print(f"Manual time benchmark: {np.mean(times) * 1000: 0.4f}")
## Prove that the above manual time benchmark is merely the CPU Overhead
out = torch.empty_like(x, device='cuda')
times = []
for _ in range(10):
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
torch.cuda.synchronize()
start = timeit.default_timer()
torch.add(x, y, out=out)
torch.cuda.synchronize()
times.append(timeit.default_timer() - start)
print(f"Manual time benchmark (with out tensor): {np.mean(times) * 1000: 0.4f}")
Latency (ms): 0.032200057059526443
Manual time benchmark (ms; without out tensor): 0.8297
Manual time benchmark (ms; with out tensor): 0.0270
Example 3: Matrix Multiplication with 2D Grids
Matrix multiplication demonstrates more advanced Triton concepts:
- 2D grid programming (using both pid_m and pid_n)
- Tiled computation to fit data in SRAM
- Sliding window pattern with tl.advance()
- Accumulation pattern for reduction operations
Important Note: If BLOCK_SIZE is too large to fit in SRAM (~164-228 KB), the kernel will hang. This example uses BLOCK_SIZE=64 to balance performance and memory constraints.
import torch
import numpy as np
import timeit
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE: tl.constexpr):
    # A: M x K; B: N x K (B is stored transposed); C: M x N
    # Each block computes one BLOCK_SIZE x BLOCK_SIZE tile of C.
    # We build this tile by sliding a window along the rows of A and the rows of B.
    # Location: where is this block located in the grid?
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# Mark the pointers
a_block_ptr = tl.make_block_ptr(
base=a_ptr, shape=(M, K), strides=(K, 1),
offsets=(pid_m*BLOCK_SIZE, 0),
block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(1, 0),
)
b_block_ptr = tl.make_block_ptr(
base=b_ptr, shape=(N, K), strides=(K, 1),
offsets=(pid_n*BLOCK_SIZE, 0),
block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(1, 0),
)
c_block_ptr = tl.make_block_ptr(
base=c_ptr, shape=(M, N), strides=(N, 1),
offsets=(pid_m*BLOCK_SIZE, pid_n*BLOCK_SIZE),
block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(1, 0),
)
accumulator = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE):
# LOAD
a_chunk = tl.load(a_block_ptr)
b_chunk = tl.load(b_block_ptr)
        # DO MATH: accumulate this K-chunk's contribution to the C tile.
        # b_chunk holds rows of B (slices along K), so transpose it before the dot.
        accumulator = tl.dot(a_chunk, tl.trans(b_chunk), accumulator)
# Move the window
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE))
        b_block_ptr = tl.advance(b_block_ptr, (0, BLOCK_SIZE))
tl.store(c_block_ptr, accumulator)
def matmul(x: torch.Tensor, y: torch.Tensor):
M = x.shape[0]
N = y.shape[0]
K = x.shape[1]
out = torch.empty((M, N), device=x.device)
grid = lambda meta: (
triton.cdiv(M, meta['BLOCK_SIZE']),
triton.cdiv(N, meta['BLOCK_SIZE'])
)
matmul_kernel[grid](x, y, out, M, N, K, BLOCK_SIZE=64)
return out
M, N, K = 4096, 4096, 4096
x = torch.randn(M, K, device='cuda')
y = torch.randn(N, K, device='cuda')
latency = triton.testing.do_bench(lambda : matmul(x, y))
print(f"Latency (ms): {latency}")
Latency (ms): 2.7577507495880127
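Since y is stored as N x K (i.e., transposed), the PyTorch reference for this computation is x @ y.T. A quick sanity check along these lines:
ref = x @ y.T
out = matmul(x, y)
# tl.dot may use TF32 on Ampere+ GPUs, so compare with a loose tolerance
print((out - ref).abs().max().item())  # expect a small value relative to the entries' magnitude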
Key Takeaways
Writing efficient Triton kernels requires understanding the interplay between:
- Hardware constraints: SRAM size, memory bandwidth, warp execution
- Memory management: Coalesced access, minimizing HBM traffic
- Compute patterns: Tiling, fusion, and maximizing arithmetic intensity
The three examples progressively demonstrated:
- Manual memory management with explicit offsets
- Automatic memory management with block pointers
- Multi-dimensional grids and tiled computation
For production kernels, always:
- Use @triton.autotune to find optimal block sizes
- Profile with do_bench to measure real performance
- Ensure data fits in SRAM to avoid thrashing
Prateek Gupta