PyTorch GPU memory management in Jupyter

notes
PyTorch
Jupyter
Lessons learned about managing GPU memory when using PyTorch in Jupyter

TL;DR

To clean up cached memory in PyTorch:

del x                         # drop the reference so the tensor itself can be collected
with torch.no_grad():
    torch.cuda.empty_cache()  # release unused cached blocks back to the driver

Note that torch.cuda.empty_cache() only releases memory that PyTorch has cached but no longer uses; it cannot free memory still held by live tensors. For it to be effective we must handle variables and contexts carefully, as in the example above where the variable x is deleted first.
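
To see the difference between allocated and cached memory, torch.cuda.memory_allocated() and torch.cuda.memory_reserved() can be compared before and after the clean-up; a minimal sketch (the tensor and the sizes are only illustrative):

import torch

x = torch.randn(1024, 1024, device="cuda")   # a live tensor of roughly 4 MB
print(torch.cuda.memory_allocated())          # bytes currently held by live tensors
print(torch.cuda.memory_reserved())           # bytes held by the caching allocator

del x                                         # the tensor becomes unreferenced
torch.cuda.empty_cache()                      # cached blocks are returned to the driver
print(torch.cuda.memory_allocated())          # drops back towards zero
print(torch.cuda.memory_reserved())           # drops once the cache is released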

Also, remember to use the torch.no_grad() context manager when running the network in inference mode, e.g.:

x, y = next(iter(dl))    # grab a single batch from the DataLoader
with torch.no_grad():    # no computation graph is built for the forward pass
    pred = net(x)

Running the code above without the context manager builds a computation graph during the forward pass; the graph keeps all intermediate activations referenced, which makes the corresponding GPU memory difficult to free.
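
If a prediction has already been produced without the context manager, one workaround (a sketch, assuming pred is the only remaining reference to the output) is to detach it from the graph before keeping it around:

pred = net(x)               # without no_grad(), pred carries a grad_fn that keeps the graph alive
pred = pred.detach().cpu()  # keep only the values; the old output and its graph become unreferenced
torch.cuda.empty_cache()    # the intermediate activations can now be released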

To check the global namespace in a Jupyter environment:

%who

which lists all variables living in the notebook's global namespace (%whos additionally shows their types).
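
Stale variables spotted this way can then be deleted explicitly before emptying the cache; the names below are hypothetical examples of leftovers reported by %who:

del pred, x, y            # leftover tensors reported by %who (hypothetical names)
import gc
gc.collect()              # collect anything that is now unreferenced
torch.cuda.empty_cache()  # then release the cached blocks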

A more advanced approach is to ask the garbage collector which CUDA tensors it is tracking and who still refers to them:

import gc
import torch

# CUDA tensors the garbage collector currently tracks, and what still refers to the first one
objs_in_gpu = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda]
referrers = gc.get_referrers(objs_in_gpu[0])
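
From there, the resident tensors and whatever still refers to them can be inspected, for example:

for t in objs_in_gpu:
    print(type(t).__name__, tuple(t.shape), t.element_size() * t.nelement(), "bytes")

for r in referrers:
    print(type(r))   # a dict here is often a namespace (a module, or the notebook globals)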

Breakdown

  1. Python decides when a variable (and the memory behind it) can be freed mainly through reference counting, with a garbage collector handling reference cycles. Extra references are easy to create without noticing, and as long as one exists the memory is not freed, even though we believe the variable is gone.
  2. A common error is storing references to PyTorch variables that are still attached to a computation graph (as described here). The graph, together with the intermediate activations it holds, then stays in GPU memory and eventually triggers OOM errors or slows processing down; see the sketch after this list.
  3. A key pattern to avoid many of these errors is using context managers such as torch.no_grad() or torch.cuda.device(), which guarantee that whatever they set up is undone when the block exits, even if an exception is raised, so the block does not leave extra references (such as a recorded computation graph) behind.
with torch.cuda.device(0):
    # Operations here run on GPU 0; the previous device is restored on exit
    ...
  4. Functions and exception blocks (try-except-finally) serve a similar purpose, and it is often a good idea to combine them in functions that perform heavy processing and might be interrupted. If an interruption is not handled, the exception traceback keeps the function's local variables (and any tensors they hold) alive. The following is a useful snippet:
def func(params):
    try:
        ...                          # Heavy processing here
    except Exception:                # Use a specific exception if one is expected
        torch.cuda.empty_cache()     # Ensure cached memory is freed if the exception triggers
        raise                        # Re-raise so the error is not silently swallowed
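
As an illustration of points 1 and 2 above, accumulating raw loss tensors keeps every iteration's computation graph alive, while accumulating plain Python floats does not. The model and loop below are a minimal, hypothetical sketch:

import torch

net = torch.nn.Linear(512, 512).cuda()
losses_bad, losses_ok = [], []

for _ in range(10):
    x = torch.randn(64, 512, device="cuda")
    loss = net(x).pow(2).mean()

    losses_bad.append(loss)        # keeps a reference to a tensor with a grad_fn,
                                   # so the whole graph for this step stays in GPU memory
    losses_ok.append(loss.item())  # a plain float holds no reference to the graph

print(torch.cuda.memory_allocated())   # grows with len(losses_bad), not with len(losses_ok)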

Resources

  • Post on iifx.dev discussing several methods to handle out-of-memory errors associated with PyTorch.
  • Post by Neerja Aggarwal describing OOM cases in the context of CUDA, PyTorch and Jupyter notebooks, with tips on avoiding them.
  • Official PyTorch documentation on memory management when using CUDA.
  • Post by Bartosz Mikulski discussing memory management in Jupyter notebooks.
  • Stack Overflow thread with additional discussion.