ZeRO stage-1
Exploring ZeRO
This post explores the ZeRO optimizer (ZeroRedundancyOptimizer) in PyTorch. It is useful when you train with DDP and are bottlenecked by GPU memory. ZeRO shards the optimizer states across devices, which reduces the memory used on each device. The ZeRO paper goes into more detail. Here we explore ZeRO stage-1.
There is really no reason not to shard the optimizer states: it is relatively straightforward to do with the ZeRO API. There are some limitations, which will go away when we migrate to FSDP, but the point of this post is to show that adding a few lines of code can already be quite beneficial.
Before going into the code, let's review what it means to shard the optimizer states and how much memory we can expect to save.
I will use the small GPT config from nanoGPT, which has \(124M\) params, as an example. If you load it up and do a training run, you will see that the memory usage is around \(12.3\) GiB per GPU.
Memory usage
Let’s break down this memory usage. A good reference is here. GPU memory comprises the following:
- M_p: The model parameters
- M_p32: A copy of the parameters in full precision, when using mixed precision training
- M_g: Gradients
- M_opt: Optimizer states (momentum in SGD, momentum and variance in Adam)
- M_act: Activations
- Other things like CUDA kernels and temporary buffers
The figure below depicts the memory usage:
Memory usage breakdown.
Apart from the activations, the memory for each component depends only on the parameter count and can therefore be computed in an architecture-agnostic manner. Let's use the observed total memory usage of the \(124M\) model to back out the memory used by activations (and the miscellaneous overhead). This will let us estimate the memory usage after sharding the optimizer states.
```python
N = 124_000_000     # total parameters
M = 12.34 * 2**30   # measured GPU memory per device, in bytes (the 12.3 GiB above, unrounded)
M_p = 2 * N         # it takes two bytes to store 1 parameter in float16 / bfloat16
M_g = 2 * N         # gradients, also two bytes per parameter
M_p32 = 4 * N       # in mixed precision training, we keep a copy of the weights in float32
M_opt = 8 * N       # AdamW momentum + variance, 4 bytes each
```
The memory for activations (and overhead) is:
```python
M_act = M - M_p - M_g - M_p32 - M_opt
# M_act ≈ 10.49 * 2**30 bytes
```
If we have d GPUs and shard the optimizer states across them, the expected memory usage per GPU is:
```python
d = 8  # number of GPUs
M_zero1 = M_act + M_p + M_g + M_p32 + (M_opt / d)
# M_zero1 ≈ 11.53 * 2**30 bytes
```
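The memory savings depend on d and on how large M_opt is relative to the total: only the optimizer states shrink, by a factor of d. For the \(124M\) model on 8 GPUs, this predicts savings of about 6.5% (roughly 0.8 GiB).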
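The arithmetic above can be wrapped in a small helper to try other values of d or other model sizes. This is a sketch that assumes the same 16 bytes per parameter (2 + 2 + 4 + 8) used above and treats everything else as a fixed activation/overhead term backed out from a single unsharded measurement.

```python
GiB = 2**30


def estimate_zero1_memory(n_params: int, measured_bytes: float, d: int) -> dict:
    """Back out activation/overhead memory from one unsharded measurement,
    then predict per-GPU memory when only the optimizer states are sharded."""
    m_p = 2 * n_params      # half-precision parameters
    m_g = 2 * n_params      # half-precision gradients
    m_p32 = 4 * n_params    # float32 master copy of the weights
    m_opt = 8 * n_params    # AdamW momentum + variance in float32
    m_act = measured_bytes - m_p - m_g - m_p32 - m_opt
    m_zero1 = m_act + m_p + m_g + m_p32 + m_opt / d
    return {
        "act_gib": m_act / GiB,
        "zero1_gib": m_zero1 / GiB,
        "saving_pct": 100 * (measured_bytes - m_zero1) / measured_bytes,
    }


est = estimate_zero1_memory(124_000_000, 12.34 * GiB, d=8)
print({k: round(v, 2) for k, v in est.items()})
# {'act_gib': 10.49, 'zero1_gib': 11.53, 'saving_pct': 6.55}
```

Under these assumptions the savings grow with d but can never exceed M_opt / M, which for this model is 8N / M, or about 7.5%.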
After sharding, the optimizer states are now split across the devices. In the following diagram, for 2 devices, the shaded regions show the optimizer states which are no longer replicated across the two GPUs.
Optimizer state sharding across GPUs.
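To make the sharding concrete: the PyTorch docs describe ZeroRedundancyOptimizer as packing parameters onto ranks with a sorted-greedy algorithm, where each parameter belongs to exactly one rank. The toy sketch below is a simplified take on that idea (largest tensor first, onto the least-loaded rank); the exact ordering and tie-breaking inside PyTorch may differ. Each rank then keeps AdamW states (8 bytes per element) only for the parameters it owns. The tensor sizes are hypothetical.

```python
import heapq


def partition_params(sizes: list[int], d: int) -> list[list[int]]:
    """Greedily assign parameter indices to d ranks, largest first,
    always onto the rank that currently owns the fewest elements."""
    heap = [(0, rank) for rank in range(d)]  # (elements owned, rank)
    owners: list[list[int]] = [[] for _ in range(d)]
    for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
        load, rank = heapq.heappop(heap)
        owners[rank].append(idx)
        heapq.heappush(heap, (load + sizes[idx], rank))
    return owners


def optimizer_state_bytes(sizes: list[int], owners: list[list[int]], bytes_per_elem: int = 8) -> list[int]:
    """AdamW keeps momentum + variance (2 x float32) for each owned element."""
    return [bytes_per_elem * sum(sizes[i] for i in idx) for idx in owners]


sizes = [4096, 4096, 1024, 1024, 512, 512, 64, 64]  # toy parameter tensor sizes
owners = partition_params(sizes, d=2)
print(owners)                                # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(optimizer_state_bytes(sizes, owners))  # [45568, 45568]
```

With two ranks, each holds half of the 91,136 bytes of optimizer state that plain DDP would replicate on both.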
So how do we do this in code? PyTorch makes it really easy to instantiate the ZeRO version of an optimizer.
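```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO

# requires an initialized process group (e.g. torchrun + init_process_group)
optimizer = ZeRO(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    parameters_as_bucket_view=True,
    overlap_with_ddp=False,  # we will get to this later
    lr=learning_rate,        # forwarded to AdamW
    betas=betas,             # forwarded to AdamW
)
```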
Any additional keyword arguments are forwarded to the constructor of optimizer_class; here we pass lr and betas for AdamW.
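Since the extra keyword arguments go straight to optimizer_class, they show up on the local optimizer, which ZeroRedundancyOptimizer exposes through its optim attribute. The sketch below runs on CPU in a single process (world size 1, gloo backend) purely so it can execute without GPUs; the tiny nn.Linear model and the lr/betas values are hypothetical stand-ins for the nanoGPT model, learning_rate and betas, and free_port is a helper introduced here.

```python
import socket

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO


def free_port() -> int:
    # Ask the OS for an unused TCP port for the rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def zero_single_step():
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{free_port()}", rank=0, world_size=1
    )
    try:
        torch.manual_seed(0)
        model = torch.nn.Linear(16, 4)  # hypothetical stand-in model
        optimizer = ZeRO(
            model.parameters(),
            optimizer_class=torch.optim.AdamW,
            parameters_as_bucket_view=True,
            lr=3e-4,            # forwarded to AdamW
            betas=(0.9, 0.95),  # forwarded to AdamW
        )
        before = model.weight.detach().clone()
        model(torch.randn(8, 16)).pow(2).mean().backward()
        optimizer.step()  # updates this rank's shard, then syncs the parameters
        local_group = optimizer.optim.param_groups[0]
        changed = not torch.equal(before, model.weight.detach())
        return local_group["lr"], local_group["betas"], changed
    finally:
        dist.destroy_process_group()


print(zero_single_step())
```

In a real run the same constructor call is made on every rank launched by torchrun, and each rank's optim only holds the states for its own shard.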
Let us now instantiate models of varying sizes and check the memory usage with and without sharding the optimizer states. The sequence length is set to \(1024\) tokens and the batch size is \(12\).
| Memory | 124M | 353M | 772M |
|---|---|---|---|
| without sharding | 12.34 GB | 23.20 GB | OOM |
| with sharding | 11.15 GB | 21.26 GB | 35.93 GB |
The measured usage is in line with our estimate: for the \(124M\) model the sharded run uses 11.15 GB, slightly below the predicted 11.53. Note that with the batch size kept fixed, the \(772M\) model runs out of memory without sharding.
So is this a free lunch? Not quite. Since the optimizer states are split across devices, each device only updates the parameters whose states it owns and then broadcasts them to the other devices. So every optimization step now involves an extra synchronization of the weights.
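A rough way to size that extra traffic: after each step, a rank receives the updated values of the parameters owned by the other d - 1 ranks, i.e. about (d - 1) / d of the model, on top of DDP's usual gradient all-reduce. The helper below is a back-of-envelope sketch that assumes a perfectly balanced partition; the byte width of the synced parameters (4 for float32, 2 for bfloat16) is left as a parameter because it depends on how the model is stored.

```python
def extra_sync_bytes_per_rank(n_params: int, d: int, bytes_per_param: int) -> float:
    """Bytes one rank receives per step from the post-step parameter broadcast,
    assuming each rank owns ~1/d of the parameters."""
    return (d - 1) / d * n_params * bytes_per_param


for n in (124_000_000, 353_000_000, 772_000_000):
    gib = extra_sync_bytes_per_rank(n, d=8, bytes_per_param=4) / 2**30
    print(f"{n / 1e6:.0f}M params: ~{gib:.2f} GiB received per rank per step")
```

This ignores latency and bucketing, so it only indicates how the extra communication scales with model size and d.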
Communication cost
Let's expand the table from above with the time per optimization step and the model FLOPs utilization (MFU), measured on 8 A100 GPUs.
| Metric | ZeRO | 124M | 353M | 772M |
|---|---|---|---|---|
| Memory | ✗ | 12.34 GB | 23.20 GB | OOM |
| Memory | ✓ | 11.15 GB | 21.26 GB | 35.93 GB |
| Time/step | ✗ | 345.83 ms | 907.96 ms | — |
| Time/step | ✓ | 367.77 ms | 962.39 ms | 1943.65 ms |
| MFU | ✗ | 48.61% | 52.96% | — |
| MFU | ✓ | 46.05% | 49.95% | 52.68% |
As used so far, ZeRO is slightly slower than the non-sharded version (about 6% more time per step). The good news is that this communication can be overlapped with DDP's gradient communication. This needs a bit of extra work: first, set overlap_with_ddp=True when creating the optimizer; then register a DDP communication hook built with hook_with_zero_step, and do not call the optimizer step directly, since the hook runs it.¹
```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,  # alternative variant, not used below
)
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as dh

# model must already be wrapped in DistributedDataParallel
optimizer = ZeRO(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    parameters_as_bucket_view=True,
    overlap_with_ddp=True,  # the step is driven by the DDP comm hook below
    lr=learning_rate,
    betas=betas,
)

# run the sharded optimizer step as part of DDP's gradient all-reduce
zero_hook = hook_with_zero_step(dh.allreduce_hook, model, optimizer)
model.register_comm_hook(state=None, hook=zero_hook)
```
After enabling overlap, the time per step is almost the same as without sharding (about 1% slower, versus about 5% without overlap), while the memory savings are kept.
| Metric | No sharding | ZeRO - no overlap | ZeRO + overlap |
|---|---|---|---|
| Memory | 35.00 GB | 28.52 GB | 28.52 GB |
| Time/step | 2251.12 ms | 2358.40 ms | 2275.07 ms |
| MFU | 54.58% | 52.31% | 54.35% |
Saving the optimizer state
Since the optimizer states are sharded, before saving the state_dict we need to consolidate the states from all devices onto a single rank. In PyTorch this can be done as:
```python
import torch

# called on all ranks, to=0 indicates that rank 0
# will have the consolidated optimizer states
optimizer.consolidate_state_dict(to=0)
if torch.distributed.get_rank() == 0:
    torch.save(optimizer.state_dict(), 'optim_state.pth')
```
This is a blocking step and it takes some time to sync among the devices. The time taken depends on the hardware setup, specifically the device interconnect.
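For a self-contained check of the save path, the sketch below does the same thing in a single CPU process (world size 1, gloo backend), with a hypothetical tiny model and a temporary directory standing in for a real checkpoint location; free_port and save_zero_checkpoint are helpers introduced here. Calling state_dict() without consolidating first raises an error, so consolidate_state_dict is still required even with one rank.

```python
import os
import socket
import tempfile

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO


def free_port() -> int:
    # Ask the OS for an unused TCP port for the rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def save_zero_checkpoint(path: str) -> dict:
    dist.init_process_group(
        "gloo", init_method=f"tcp://127.0.0.1:{free_port()}", rank=0, world_size=1
    )
    try:
        model = torch.nn.Linear(16, 4)  # hypothetical stand-in model
        optimizer = ZeRO(model.parameters(), optimizer_class=torch.optim.AdamW, lr=1e-3)
        model(torch.randn(8, 16)).sum().backward()
        optimizer.step()  # creates the AdamW state for each parameter

        optimizer.consolidate_state_dict(to=0)  # collective: every rank must call it
        state = optimizer.state_dict()          # only valid on rank `to` afterwards
        if dist.get_rank() == 0:
            torch.save(state, path)
        return state
    finally:
        dist.destroy_process_group()


path = os.path.join(tempfile.mkdtemp(), "optim_state.pth")
state = save_zero_checkpoint(path)
print(sorted(state.keys()), os.path.exists(path))
```

The consolidated dict uses the usual optimizer layout (state and param_groups), so it can be handed back to load_state_dict when resuming.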
Summary
The optimizer states take up a significant amount of memory. However, they can be sharded, and with the right implementation details the communication cost involved is negligible. This enables training configurations that would otherwise consume too much memory per device. As such, if an efficient implementation is available, sharding should be enabled by default.
However, this implementation of ZeRO stage-1 has some limitations. For example, we are not able to use fused AdamW, and the code to add communication hooks is a bit cumbersome. The gradients are also not sharded. Fortunately, we have FSDP2, which provides ZeRO stage-1 and stage-2. We will explore this in the next post.
References: