Building PyTorch DDP From Scratch
February 2, 2026
This post is (hopefully) the first in a series of posts about distributed training. Today we will build distributed data parallel (DDP) training from scratch.
The complete implementation is available on GitHub.
Motivation: Why Build DDP From Scratch?
DDP is the first step towards distributed training, and many people neglect it because of its apparent simplicity: you just import torch.distributed, modify a few lines of code, and launch your training script with torchrun.
Okay, maybe that's partly true, but we will see that there is more to it than that. In this post we will go from a naive implementation to an optimized one that achieves up to 95% of the performance of PyTorch DDP.
Sharded Data Loading
In Distributed Data Parallel (DDP) training, data is typically sharded across ranks so that each process works on a different subset of the dataset. This sharding is handled by the DataLoader and its sampler, not by DDP itself.
In PyTorch, this is commonly done with torch.utils.data.distributed.DistributedSampler. Here we will implement a custom sampler that does essentially the same thing; you can find the implementation here.
The implementation is quite straightforward; the important things to remember are:
- PyTorch's DistributedSampler performs an interleaved (strided) slicing of the dataset indices. For example, with 4 ranks and 16 samples, rank 0 receives [0, 4, 8, 12], rank 1 receives [1, 5, 9, 13], rank 2 receives [2, 6, 10, 14], and rank 3 receives [3, 7, 11, 15].
- The drop_last parameter controls how the sampler handles a dataset whose size is not divisible by the number of ranks. If drop_last=True, the tail of the index list is dropped so that every rank receives the same number of samples; with 4 ranks and 15 samples (no shuffling), indices 12, 13, and 14 are dropped and each rank receives 3 samples. If drop_last=False, the sampler pads the list of dataset indices by repeating indices from the beginning of the list (after optional shuffling) so that the total number of samples is divisible by the number of ranks. For example, with 4 ranks and 15 samples (no shuffling), rank 0 receives [0, 4, 8, 12], rank 1 receives [1, 5, 9, 13], rank 2 receives [2, 6, 10, 14], and rank 3 receives [3, 7, 11, 0].
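The two rules above (strided slicing plus drop_last handling) can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the implementation linked above: shard_indices and ShardedSampler are hypothetical names, the sampler takes the dataset length rather than the dataset itself, and shuffling uses Python's random module instead of torch.randperm, so shuffled orders will not match PyTorch's. Without shuffling, it reproduces the examples in the list.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, shuffle=False, seed=0, epoch=0, drop_last=False):
    """Return the dataset indices assigned to `rank` (hypothetical helper)."""
    indices = list(range(dataset_len))
    if shuffle:
        # All ranks must build the same permutation, so the RNG is seeded with
        # seed + epoch on every rank. Assumption: Python's `random` stands in
        # for torch.randperm, so the order differs from PyTorch's sampler.
        random.Random(seed + epoch).shuffle(indices)

    if drop_last:
        # Drop the tail so every rank gets floor(N / W) samples.
        num_samples = dataset_len // num_replicas
        total_size = num_samples * num_replicas
        indices = indices[:total_size]
    else:
        # Pad by repeating indices from the start of the (possibly shuffled)
        # list until the total is divisible by the number of ranks.
        num_samples = math.ceil(dataset_len / num_replicas)
        total_size = num_samples * num_replicas
        padding = total_size - len(indices)
        if padding > 0:
            # Repeat whole copies first in case padding exceeds the list length.
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]

    # Interleaved (strided) slicing: rank r takes positions r, r + W, r + 2W, ...
    return indices[rank:total_size:num_replicas]


class ShardedSampler:
    """DistributedSampler-like sampler (hypothetical; takes a length, not a dataset)."""

    def __init__(self, dataset_len, num_replicas, rank, shuffle=False, seed=0, drop_last=False):
        self.dataset_len = dataset_len
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

    def set_epoch(self, epoch):
        # Call on every rank at the start of each epoch to reshuffle.
        self.epoch = epoch

    def __iter__(self):
        return iter(shard_indices(self.dataset_len, self.num_replicas, self.rank,
                                  self.shuffle, self.seed, self.epoch, self.drop_last))

    def __len__(self):
        # Number of samples this rank yields per epoch.
        if self.drop_last:
            return self.dataset_len // self.num_replicas
        return math.ceil(self.dataset_len / self.num_replicas)


if __name__ == "__main__":
    # 15 samples, 4 ranks, no shuffling: rank 3 gets the padded index 0.
    for r in range(4):
        print(r, shard_indices(15, 4, r))
```

Because PyTorch's DataLoader accepts any iterable of indices as its sampler argument, an object like this could be passed as DataLoader(dataset, batch_size=..., sampler=ShardedSampler(len(dataset), world_size, rank)). As with DistributedSampler, set_epoch should be called on every rank before each epoch; otherwise every epoch reuses the same permutation.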