Shard Optimizer States with ZeroRedundancyOptimizer

Note

ZeroRedundancyOptimizer was introduced in PyTorch 1.8 as a prototype feature. This API is subject to change.

In this recipe, you will learn what ZeroRedundancyOptimizer is and how to use it in distributed training.

What is ZeroRedundancyOptimizer?

The idea behind ZeroRedundancyOptimizer comes from the DeepSpeed/ZeRO project and Marian, which shard optimizer states across distributed data-parallel processes to reduce the per-process memory footprint. The Getting Started With Distributed Data Parallel tutorial shows how to use DistributedDataParallel (DDP) to train models. In that tutorial, each process keeps a dedicated replica of the optimizer. Because DDP has already synchronized gradients in the backward pass, all optimizer replicas operate on the same parameter and gradient values in every iteration, and this is how DDP keeps model replicas in the same state.

Optimizers often maintain local states as well. For example, the Adam optimizer uses per-parameter exp_avg and exp_avg_sq states, so its memory consumption is at least twice the model size. Given this observation, we can reduce the optimizer memory footprint by sharding optimizer states across DDP processes. More specifically, instead of creating per-parameter states for all parameters, the optimizer instance in each DDP process keeps optimizer states for only one shard of the model parameters. The optimizer step() function updates only the parameters in its shard and then broadcasts the updated parameters to all peer DDP processes, so that all model replicas still land in the same state.
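The shard, update, and broadcast sequence described above can be illustrated with a small simulation in plain Python. This is only a sketch under stated assumptions, not the library's implementation: SGD with momentum stands in for Adam, parameters are scalars, the ranks are simulated in a single process, and the helper names partition and sharded_step are hypothetical. Shards are assigned round-robin here, and the real ZeroRedundancyOptimizer may partition parameters differently.

```python
# Pure-Python simulation of a sharded optimizer step (no PyTorch required).
# Assumptions: SGD with momentum stands in for Adam, parameters are scalars,
# and shards are assigned round-robin. These helpers are hypothetical.

def partition(num_params, world_size):
    """Assign parameter indices to ranks round-robin."""
    return [list(range(rank, num_params, world_size)) for rank in range(world_size)]


def sharded_step(replicas, grads, states, shards, lr=0.1, momentum=0.9):
    """Run one optimizer step in which each rank updates only its own shard."""
    world_size = len(replicas)
    # 1. Local update: rank r touches only the parameters in shards[r] and
    #    keeps momentum state only for those parameters.
    for rank in range(world_size):
        for i in shards[rank]:
            buf = momentum * states[rank].get(i, 0.0) + grads[i]
            states[rank][i] = buf
            replicas[rank][i] -= lr * buf
    # 2. Broadcast: the owner of each shard sends its updated values to peers.
    for owner in range(world_size):
        for i in shards[owner]:
            for rank in range(world_size):
                replicas[rank][i] = replicas[owner][i]


world_size = 2
initial = [1.0, 2.0, 3.0, 4.0]
grads = [0.5, -1.0, 0.25, 2.0]  # identical on every rank after gradient sync

replicas = [list(initial) for _ in range(world_size)]
states = [{} for _ in range(world_size)]  # per-rank optimizer state
shards = partition(len(initial), world_size)

sharded_step(replicas, grads, states, shards)

# Reference: a single unsharded optimizer taking the same first step.
reference = [p - 0.1 * g for p, g in zip(initial, grads)]

print("replicas agree:", replicas[0] == replicas[1])
print("matches unsharded step:", replicas[0] == reference)
print("state entries per rank:", [len(s) for s in states])
```

In this sketch each simulated rank stores optimizer state for only half of the parameters, yet both replicas finish the step with the same values as an unsharded optimizer would produce.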

How to use ZeroRedundancyOptimizer?

The code below demonstrates how to use ZeroRedundancyOptimizer. Most of the code is similar to the simple DDP example presented in the Distributed Data Parallel notes. The main difference is the if-else clause in the example function that wraps optimizer construction, toggling between ZeroRedundancyOptimizer and the Adam optimizer.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    # Report peak CUDA memory from rank 0 only to avoid duplicate output.
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")

    # release process group resources before the worker exits
    dist.destroy_process_group()

def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

The output is shown below. With ZeroRedundancyOptimizer enabled, the increase in peak memory caused by the optimizer step() is roughly half of that with vanilla Adam (1361 MB - 992 MB = 369 MB versus 1697 MB - 992 MB = 705 MB). This agrees with our expectation, as we are sharding Adam optimizer states across two processes. The output also shows that, with ZeroRedundancyOptimizer, the model parameters still end up with the same values after one iteration (the parameter sum is the same with and without ZeroRedundancyOptimizer).

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
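
The measured numbers above can be compared with a back-of-the-envelope estimate. The sketch below assumes float32 parameters (4 bytes each), exactly two Adam state tensors per parameter (exp_avg and exp_avg_sq), and an even split of states across processes. The helper names linear_param_count and adam_state_mb are hypothetical and are not part of PyTorch.

```python
# Rough estimate of Adam state memory for the example model.
# Assumptions: float32 parameters (4 bytes each), two Adam state tensors
# (exp_avg, exp_avg_sq) per parameter, states split evenly across processes.
# Temporary buffers and allocator overhead are ignored.

def linear_param_count(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


def adam_state_mb(num_params, world_size=1, bytes_per_param=4):
    """Approximate per-process Adam state size in MB when sharded evenly."""
    return 2 * num_params * bytes_per_param / world_size / 1e6


# The example model stacks 20 Linear(2000, 2000) layers.
num_params = 20 * linear_param_count(2000, 2000)
model_mb = num_params * 4 / 1e6
full_mb = adam_state_mb(num_params)
sharded_mb = adam_state_mb(num_params, world_size=2)

print(f"parameters: {num_params}")
print(f"model size: {model_mb:.2f} MB")
print(f"Adam state, unsharded: {full_mb:.2f} MB")
print(f"Adam state, 2 shards: {sharded_mb:.2f} MB")
```

Under these assumptions the Adam states take roughly 640 MB unsharded and roughly 320 MB per process with two shards. The measured increases of 705 MB and 369 MB are somewhat larger, plausibly because of temporary buffers and allocator overhead that the estimate ignores, but the ratio of about two to one is consistent with the sharding.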
