Simple GPU memory allocation experiments every ML Engineer should do

Akhilez · 12 min read · Aug 23, 2024
Given a model architecture, dtype, input shape and optimizer, can you figure out how much GPU memory will be required for a forward and backward pass? This is important to know because GPU memory can be expensive.

To answer that question, we need to break the pipeline into fundamental components and understand the memory requirements from the bottom up. The following experiments, which you can run on Google Colab, will help you understand the core concepts. We’ll use PyTorch for all the experiments below.

Reservation vs Allocation

PyTorch reserves more memory, but allocates only what’s needed. This is done so that when more memory is needed, it can be allocated quickly instead of costly reservation operations. For our purposes, we will only care about memory allocation and not reservation.
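
All the snippets below share a small amount of setup. Here is a minimal version; the device index 0 is an assumption, so adjust it for your machine.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

device_id = 0                              # assumption: use the first GPU
device = torch.device(f"cuda:{device_id}")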

def test_reservation_vs_allocation():
    print(f"Base memory reserved: {torch.cuda.memory_reserved(device_id)}")
    print(f"Base memory allocated: {torch.cuda.memory_allocated(device_id)}")

    # Allocate some memory
    x = torch.randn((1024,), dtype=torch.float32, device=device)
    print(f"Memory after allocation (reserved): {torch.cuda.memory_reserved(device_id)}")
    print(f"Memory after allocation (allocated): {torch.cuda.memory_allocated(device_id)}")

    # Cleanup
    del x
    print(f"Memory after cleanup (reserved): {torch.cuda.memory_reserved(device_id)}")
    print(f"Memory after cleanup (allocated): {torch.cuda.memory_allocated(device_id)}")

    torch.cuda.empty_cache()
    print(f"Memory after empty_cache (reserved): {torch.cuda.memory_reserved(device_id)}")
    print(f"Memory after empty_cache (allocated): {torch.cuda.memory_allocated(device_id)}")

"""
Output:

Base memory reserved: 0
Base memory allocated: 0
Memory after allocation (reserved): 2097152
Memory after allocation (allocated): 4096
Memory after cleanup (reserved): 2097152
Memory after cleanup (allocated): 0
Memory after empty_cache (reserved): 0
Memory after empty_cache (allocated): 0
"""

When deleting the variable x or when x goes out of scope, x’s memory is deallocated, but it’s still reserved for future use. Only when torch.cuda.empty_cache() is called does it release reserved memory.

Note that torch.cuda.memory_allocated() only reports the memory allocated by PyTorch in the current process. If another process is using some of the GPU memory, that usage won’t show up here; if this process hasn’t allocated anything, it will still return 0. To get the true GPU-wide memory usage, we can use the function below. We won’t need it for the experiments, but it’s something you might want to play with.

import subprocess


def get_gpu_memory_used(gpu_id):
    """
    Returns the amount of memory used on the specified GPU in bytes.

    Parameters:
        gpu_id (int): The ID of the GPU (e.g., 0 for "cuda:0", 1 for "cuda:1").

    Returns:
        int: The amount of memory used on the GPU in bytes.
    """
    try:
        # Run the nvidia-smi command to get memory usage
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader", f"--id={gpu_id}"],
            stdout=subprocess.PIPE,
            text=True,
        )

        # Get the used memory in MiB from the result
        used_memory_mib = int(result.stdout.strip())

        # Convert MiB to bytes (1 MiB = 1024 * 1024 bytes)
        used_memory_bytes = used_memory_mib * 1024 * 1024

        return used_memory_bytes

    except Exception as e:
        print(f"Error occurred: {e}")
        return None

Dtypes

float32 requires 4 bytes of memory, bfloat16 requires 2 bytes and so on. Let’s plot the memory needed for a few dtypes.
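
If you just want the bytes per element without touching the GPU, element_size() on an empty tensor reports the same numbers; a minimal check:

# Bytes per element for each dtype, no GPU required
for dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64, torch.uint8, torch.int8]:
    print(dtype, torch.empty((), dtype=dtype).element_size())
# torch.float32 -> 4, torch.float16 -> 2, torch.bfloat16 -> 2, torch.int32 -> 4,
# torch.int64 -> 8, torch.uint8 -> 1, torch.int8 -> 1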

Figure 1: Memory allocation for different data types
def test_dtype_memory_allocation():
    dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64, torch.uint8, torch.int8, torch.uint16]
    memories = []
    for dtype in dtypes:
        base_memory = get_gpu_memory_used(device_id)
        x = torch.ones((1024,), dtype=dtype, device=device)
        memory_after_allocation = get_gpu_memory_used(device_id)
        memories.append((memory_after_allocation - base_memory) // 1024)
        del x
        torch.cuda.empty_cache()
    fig = plt.figure(figsize=(7, 4))
    fig.set_tight_layout(True)
    plt.bar([str(d) for d in dtypes], memories)
    plt.xlabel("Data type")
    plt.ylabel("Bytes per element")
    plt.title("Memory allocation for different data types")
    plt.xticks(rotation=45)
    plt.show()

Memory Chunks

Memory is allocated in chunks of 512 bytes: when a tensor is created, its allocation is rounded up to the next multiple of 512. For a float32 tensor of shape (800,), instead of 800 * 4 = 3200 bytes, 3584 bytes (512 * 7) are allocated.
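
You can verify the rounding rule directly with torch.cuda.memory_allocated; a minimal sketch:

next_chunk = lambda n: (n + 511) // 512 * 512  # round up to the next multiple of 512 bytes

base = torch.cuda.memory_allocated(device_id)
x = torch.randn((800,), dtype=torch.float32, device=device)
print(torch.cuda.memory_allocated(device_id) - base)  # 3584, not 800 * 4 = 3200
print(next_chunk(800 * 4))                            # 3584
del x
torch.cuda.empty_cache()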

Figure 2: Memory allocation for different tensor sizes.
def test_memory_allocation_relationship():
    """
    For different sizes of tensors, check the memory allocated on GPU.
    """
    memories = []
    sizes = 1050
    for i in tqdm(range(sizes)):
        base_memory = get_gpu_memory_used(device_id)
        x = torch.randn((i,), dtype=torch.float32, device=device)
        memory_after_allocation = get_gpu_memory_used(device_id)
        memories.append(memory_after_allocation - base_memory)
        del x
        torch.cuda.empty_cache()
    plt.plot(memories)
    plt.xlabel("Size of float32 tensor")
    plt.ylabel("Memory allocated (bytes)")
    plt.title("Memory allocation for different tensor sizes")
    plt.show()

Trainable Params (Single Linear Layer Forward)

Next, we’ll look at a single linear layer. We’ll do a forward pass, and figure out the memory needed.

def test_single_linear_layer_forward_allocation():
    # Disable cublas
    # import os; os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":0:0"

    print(f"Base memory: {torch.cuda.memory_allocated(device_id)}")

    model = nn.Linear(256, 250, device=device, dtype=torch.float32)
    print(f"Memory after model allocation: {torch.cuda.memory_allocated(device_id)}")

    x = torch.randn((1, 256,), dtype=torch.float32, device=device)
    print(f"Memory after input allocation: {torch.cuda.memory_allocated(device_id)}")

    y = model(x)
    final_memory = torch.cuda.memory_allocated(device_id)
    print(f"Memory after forward pass: {final_memory}")

    # Memory calculations
    w_mem = len(model.weight.flatten()) * model.weight.dtype.itemsize
    # Get the higher multiple of 512
    w_mem_as_chunks = (w_mem + 511) // 512 * 512
    print(f"{model.weight.shape=}, {w_mem=}, {w_mem_as_chunks=}")

    b_mem = len(model.bias) * model.bias.dtype.itemsize
    b_mem_as_chunks = (b_mem + 511) // 512 * 512
    print(f"{model.bias.shape=}, {b_mem=}, {b_mem_as_chunks=}")

    x_mem = (len(x.flatten()) * x.dtype.itemsize + 511) // 512 * 512
    y_mem = (len(y.flatten()) * y.dtype.itemsize + 511) // 512 * 512
    print(f"{x_mem=}, {y_mem=}")

    total_memory_expected = w_mem_as_chunks + b_mem_as_chunks + x_mem + y_mem

    cublas_workspace_size = 8519680
    memory_with_cublas = total_memory_expected + cublas_workspace_size
    print(f"{total_memory_expected=}, {memory_with_cublas=}")

    assert final_memory == memory_with_cublas

    del model, x, y
    torch.cuda.empty_cache()
    print(f"Memory after cleanup: {torch.cuda.memory_allocated(device_id)}")

    torch._C._cuda_clearCublasWorkspaces()
    print(f"Memory after clearing cublas workspace: {torch.cuda.memory_allocated(device_id)}")

"""
Output:
Base memory: 0
Memory after model allocation: 257024
Memory after input allocation: 258048
Memory after forward pass: 8778752
model.weight.shape=torch.Size([250, 256]), w_mem=256000, w_mem_as_chunks=256000
model.bias.shape=torch.Size([250]), b_mem=1000, b_mem_as_chunks=1024
x_mem=1024, y_mem=1024
total_memory_expected=259072, memory_with_cublas=8778752
Memory after cleanup: 8519680
Memory after clearing cublas workspace: 0
"""

The model has a float32 weight matrix of shape (250, 256), which takes up 250 * 256 * 4 = 256,000 bytes; this is an exact multiple of the 512-byte chunk size (512 * 500 = 256,000). The bias holds 250 float32 values, which is 250 * 4 = 1,000 bytes, rounded up to the next multiple of 512, i.e. 1,024 bytes. x has shape (1, 256) and y has shape (1, 250), so each occupies 1,024 bytes after chunk rounding. Total memory = weight + bias + x + y.

When we add up everything, we should get 259,072 bytes (256,000 + 1,024 + 1,024 + 1,024). However, the size we observe is 8,778,752 bytes. The extra 8,519,680 bytes come from allocating a cuBLAS workspace.

This is a memory space reserved for fast matrix multiplication operations; for some matmul operations, a new block of 8,519,680 bytes is allocated. The size may vary with the GPU and Python environment. More about it here. The cuBLAS memory doesn’t go away when torch.cuda.empty_cache() is called; it takes torch._C._cuda_clearCublasWorkspaces() to actually clear it. You can set the environment variable os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":0:0" to disable the cuBLAS workspace. This could be one way to optimize memory at the cost of slower execution; try it yourself and let me know in the comments :)
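
Here’s a minimal sketch of that experiment; note that the environment variable has to be set before the process runs its first CUDA matmul, otherwise the workspace is already allocated.

import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":0:0"  # must be set before the first matmul

model = nn.Linear(256, 250, device=device, dtype=torch.float32)
x = torch.randn((1, 256), dtype=torch.float32, device=device)
y = model(x)
# With the workspace disabled, this should print roughly 259,072 bytes
# (weights + bias + x + y), without the extra 8,519,680-byte cuBLAS block.
print(torch.cuda.memory_allocated(device_id))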

Gradients (Single Linear Layer Backward)

Use the same model, but also run loss.backward(). For simplicity, let’s just use loss = y.sum().

def test_single_linear_layer_backward_allocation():
    print(f"Base memory: {torch.cuda.memory_allocated(device_id)}")

    model = nn.Linear(256, 250, device=device, dtype=torch.float32)
    x = torch.randn((1, 256,), dtype=torch.float32, device=device)
    y = model(x)

    print(f"Memory after forward pass: {torch.cuda.memory_allocated(device_id)}")
    y.sum().backward()
    final_memory = torch.cuda.memory_allocated(device_id)
    print(f"Memory after backward pass: {final_memory}")

    # Memory calculations
    next_chunk = lambda n: (n + 511) // 512 * 512
    units = model.weight.dtype.itemsize  # 4 bytes for float32
    mem = next_chunk(len(model.weight.flatten()) * units)
    mem += next_chunk(len(model.bias) * units)
    print(f"Expected model memory: {mem}")

    x_mem = next_chunk(len(x.flatten()) * units)
    y_mem = next_chunk(len(y.flatten()) * units)
    print(f"{x_mem=}, {y_mem=}")
    mem += x_mem + y_mem

    # Gradient memory
    w_grad_mem = next_chunk(len(model.weight.grad.flatten()) * units)
    b_grad_mem = next_chunk(len(model.bias.grad.flatten()) * units)
    print(f"{model.weight.grad.shape=}, {w_grad_mem=}")
    print(f"{model.bias.grad.shape=}, {b_grad_mem=}")
    mem += w_grad_mem + b_grad_mem

    mem += 2 * 8519680  # cublas_size doubled
    print(f"Total memory expected: {mem}")
    assert final_memory == mem

    del model, x, y
    torch.cuda.empty_cache()
    print(f"Memory after cleanup: {torch.cuda.memory_allocated(device_id)}")

    torch._C._cuda_clearCublasWorkspaces()
    print(f"Memory after clearing cublas workspace: {torch.cuda.memory_allocated(device_id)}")

"""
Output:
Base memory: 0
Memory after forward pass: 8778752
Memory after backward pass: 17555456
Expected model memory: 257024
x_mem=1024, y_mem=1024
model.weight.grad.shape=torch.Size([250, 256]), w_grad_mem=256000
model.bias.grad.shape=torch.Size([250]), b_grad_mem=1024
Total memory expected: 17555456
Memory after cleanup: 17039360
Memory after clearing cublas workspace: 0
"""

Every model parameter with requires_grad=True gets a .grad member that stores the gradient for the underlying tensor, so the memory taken by the model effectively doubles.

Notice that this time two blocks of cuBLAS workspace memory are allocated, presumably one for the forward pass and one for the backward pass. Exactly when a new cuBLAS block gets allocated is still unclear at this point.
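
A quick way to confirm the doubling is to sum the gradient buffers directly; a small sketch, assuming it runs right after backward() (i.e., before the cleanup above):

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
grad_bytes = sum(p.grad.numel() * p.grad.element_size()
                 for p in model.parameters() if p.grad is not None)
print(param_bytes, grad_bytes)  # both 257,000 bytes here (before the 512-byte chunk rounding)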

Intermediate Tensors (Multi-Layer Feed Forward)

When the model runs in inference mode, there’s no autograd graph and no need to store intermediate tensors, so the total memory is simply the sum of the memory of each layer.

However, in training mode, where we need to track the computational graph, it’s a bit different. When multiple operations are applied in series, as in a feedforward or any deep network, the autograd graph needs to remember the intermediate tensors of those operations. What exactly must be stored depends on the partial derivatives involved. These intermediate tensors are cleared from memory during the backward pass. Let’s look at a few examples; x is the input and w is a parameter that requires grad (w.requires_grad = True).

  • x @ w requires no additional storage: the partial derivative with respect to w is x, which is already stored. But when x is itself an intermediate output, e.g. x = u * w1, then x needs to be kept as well.
  • x + w also requires no extra storage, since the partial derivative with respect to w is a constant.
  • (x * 2) @ w requires storing the operand x * 2, since it is needed to compute the gradient with respect to w (see the measurement sketch after this list).
  • (((x + 2) @ w1) + 3) * w2 is an interesting case that mimics 2 layers.
    - For partial derivative wrt w1, we need to store x + 2
    - For partial derivative wrt w2, we need to store ((x + 2) @ w1) + 3
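
The third bullet is easy to measure: after a warm-up matmul (so the cuBLAS workspace is already allocated), compare how much memory stays allocated for x @ w versus (x * 2) @ w. This is a sketch; the tensor sizes are arbitrary.

def retained_bytes(fn):
    # Bytes that remain allocated after fn() runs (its output is kept alive).
    base = torch.cuda.memory_allocated(device_id)
    out = fn()
    return out, torch.cuda.memory_allocated(device_id) - base

x = torch.randn((1024, 1024), device=device)
w = torch.randn((1024, 1024), device=device, requires_grad=True)
with torch.no_grad():
    _ = x @ w  # warm-up: make sure the cuBLAS workspace is already allocated
del _

y1, mem1 = retained_bytes(lambda: x @ w)        # only the output is kept; x is already stored
y2, mem2 = retained_bytes(lambda: (x * 2) @ w)  # output + the saved (x * 2) operand
print(mem1, mem2)  # mem2 is about 1024 * 1024 * 4 bytes (one extra tensor) larger than mem1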

Let’s look at the implementation of a deeper network:

def test_multi_layer_forward():
    print(f"Base memory: {torch.cuda.memory_allocated(device_id)}")

    inference_mode = False
    n_layers = 1
    model = nn.Sequential(*[
        nn.Sequential(
            nn.Linear(200, 100),
            nn.ReLU(),     # No trainable params
            nn.Linear(100, 200),
            nn.Sigmoid(),  # No trainable params
        )
        for _ in range(n_layers)
    ]).to(device_id)
    batch_size = 5
    x = torch.randn((batch_size, 200), device=device_id)
    with torch.inference_mode(inference_mode):
        y = model(x)

    final_memory = torch.cuda.memory_allocated(device_id)
    print(f"Memory after forward pass: {final_memory}")

    # Computed memory
    next_chunk = lambda n: (n + 511) // 512 * 512
    mem = 0
    unit = model[0][0].weight.dtype.itemsize
    for block in model:
        for layer in block:
            if isinstance(layer, nn.Linear):
                mem += next_chunk(len(layer.weight.flatten()) * unit)
                mem += next_chunk(len(layer.bias) * unit)
                if not inference_mode:
                    # Gotta store the input
                    mem += next_chunk(layer.in_features * batch_size * unit)
    mem += next_chunk(len(y.flatten()) * unit)
    mem += 8519680  # cublas_size
    if inference_mode:
        mem += next_chunk(len(y.flatten()) * unit)
    print(f"Total memory expected: {mem}")
    assert final_memory == mem

In normalization layers like BatchNorm1d, LayerNorm and RMSNorm, there’s an operation on the input x before multiplying with the parameter w, something like (x - x.mean()) / (x.std() + 1e-6) * w. The operand (x - x.mean()) / (x.std() + 1e-6) is an intermediate output that needs to be stored. There may also be other state, like running_mean and running_var, or intermediate tensors created inside the forward() method that need to be accounted for. Some of these intermediate tensors are not accessible to us, so we can’t be sure of exactly what’s going on. This becomes more complex when batch sizes are included.

def test_layer_norm():
    print(f"Base memory: {torch.cuda.memory_allocated(device_id)}")
    x = torch.rand((10,), device=device_id)
    w = torch.rand((10,), requires_grad=True, device=device_id)
    # Layer Norm
    y = (x - x.mean()) / (x.std() + 1e-6) * w
    final_memory = torch.cuda.memory_allocated(device_id)
    print(f"Memory after forward pass: {final_memory}")

    # Memory calculations
    next_chunk = lambda n: (n + 511) // 512 * 512
    mem = next_chunk(len(x.flatten()) * x.dtype.itemsize)
    mem += next_chunk(len(w.flatten()) * w.dtype.itemsize)
    mem += next_chunk(len(y.flatten()) * y.dtype.itemsize)
    mem += next_chunk(len(x.flatten()) * x.dtype.itemsize)  # intermediate
    print(f"Total memory expected: {mem}")
    assert final_memory == mem

Backward pass is very similar, but there are a few changes:

  • Model size is doubled because of gradient storage.
  • Intermediate tensors are all cleared by the end.
  • A new cublas workspace is allocated.
def test_multi_layer_backward():
    print(f"Base memory: {torch.cuda.memory_allocated(device_id)}")

    n_layers = 1
    model = nn.Sequential(*[
        nn.Sequential(
            nn.Linear(200, 100),
            nn.ReLU(),     # No trainable params
            nn.Linear(100, 200),
            nn.Sigmoid(),  # No trainable params
        )
        for _ in range(n_layers)
    ]).to(device_id)
    batch_size = 5
    x = torch.randn((batch_size, 200), device=device_id)
    y = model(x)
    print(f"Memory after forward pass: {torch.cuda.memory_allocated(device_id)}")
    y.sum().backward()
    final_memory = torch.cuda.memory_allocated(device_id)
    print(f"Memory after backward pass: {final_memory}")

    # Computed memory
    next_chunk = lambda n: (n + 511) // 512 * 512
    mem = 0
    unit = model[0][0].weight.dtype.itemsize
    for block in model:
        for layer in block:
            if isinstance(layer, nn.Linear):
                mem += next_chunk(len(layer.weight.flatten()) * unit) * 2  # Weights and gradients
                mem += next_chunk(len(layer.bias) * unit) * 2              # Biases and gradients
                # mem += next_chunk(layer.in_features * batch_size * unit) # Intermediate tensors are cleared
    mem += next_chunk(len(y.flatten()) * unit)
    mem += 2 * 8519680  # cublas_size doubled
    mem += next_chunk(len(y.flatten()) * unit)
    print(f"Total memory expected: {mem}")
    assert final_memory == mem

Optimizers (Single Linear Layer Backprop)

First, let’s observe the memory allocation for a few optimization steps.

def test_single_linear_layer_with_optimizer():
    # Disable cublas
    import os; os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":0:0"

    memory_timeline_real = []
    add = lambda e: memory_timeline_real.append({"event": e, "memory": torch.cuda.memory_allocated(device_id)})
    add("baseline")

    in_size = 256
    out_size = 250
    batch_size = 100
    model = nn.Linear(in_size, out_size, device=device, dtype=torch.float32)
    add("model_allocation")

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    add("optimizer_init")

    x = torch.randn((batch_size, in_size,), dtype=torch.float32, device=device)
    add("input_allocation")

    def step(n):
        optimizer.zero_grad()
        add(f"optim_zero_grad_{n}")

        y = model(x)
        add(f"forward_{n}")

        y.sum().backward()
        add(f"backward_{n}")

        optimizer.step()
        del y
        add(f"optim_step_{n}")

    for i in range(4):
        step(i + 1)

    # Bar chart with event name on x-axis and total memory on y-axis
    fig = plt.figure(figsize=(15, 7))
    fig.set_tight_layout(True)
    plt.ylim((0, 1_300_000))
    plt.bar([event["event"] for event in memory_timeline_real], [event["memory"] for event in memory_timeline_real])
    plt.xlabel("Event")
    plt.ylabel("Total memory allocated (bytes)")
    plt.title(f"Memory allocation during training ({type(optimizer)})")
    plt.xticks(rotation=45)
    plt.show()
Figure 3: Memory allocation at various stages during training with SGD optimizer
Figure 4: Memory allocation at various stages during training with Adam optimizer

Up until backward_1, the memory allocation is as expected. When optimizer.step() ends, in this particular code we delete y, so that memory is released. Under the hood, however, the optimizer grabs additional memory (equal to the trainable params) to update them and releases it after the update; that spike is not shown in this graph. A more detailed graph over time can be seen in Figure 5, taken from https://pytorch.org/blog/understanding-gpu-memory-1/.

Adam, however, keeps a first-order moment and a second-order moment for every trainable parameter, so it always holds 2x the model size in memory. This is the most memory-consuming part of training in this code.

Figure 5: Memory allocation by time in milliseconds. Source: https://pytorch.org/blog/understanding-gpu-memory-1/
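
To confirm the 2x, you can walk optimizer.state after the first optimizer.step() in the code above; a small sketch:

state_bytes = 0
for param_state in optimizer.state.values():
    for buf in param_state.values():
        if torch.is_tensor(buf):
            state_bytes += buf.numel() * buf.element_size()

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
# For Adam, state_bytes is ~2x param_bytes (exp_avg + exp_avg_sq, plus a tiny step counter)
print(param_bytes, state_bytes)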

Now let’s try to manually calculate these memory requirements:

# Memory calculations (continuing from previous code block)
units = model.weight.dtype.itemsize
memory_timeline = []
all_keys = ["trainable_params", "input", "output", "gradient", "intermediate_tensors", "optimizer_state"]

def update_memory(event: str, update: dict):
    prev_state = memory_timeline[-1] if memory_timeline else {k: 0 for k in all_keys}
    new_state = {k: prev_state.get(k, 0) + update.get(k, 0) for k in all_keys}
    new_state["event"] = event
    memory_timeline.append(new_state)

next_chunk = lambda n: (n + 511) // 512 * 512

update_memory("baseline", {})

# Model memory
model_mem = next_chunk(len(model.weight.flatten()) * units)
model_mem += next_chunk(len(model.bias) * units)
update_memory("model_allocation", {"trainable_params": model_mem})
update_memory("optimizer_init", {})

# Input memory
x_mem = next_chunk(len(x.flatten()) * units)
update_memory("input_allocation", {"input": x_mem})
update_memory("optim_zero_grad_1", {})

# Forward
y_mem = next_chunk(batch_size * out_size * units)
# Add any intermediate tensors here.
update_memory("forward_1", {"output": y_mem})  # , "intermediate_tensors": ...})

# Backward
grad_mem = next_chunk(len(model.weight.grad.flatten()) * units)
grad_mem += next_chunk(len(model.bias.grad.flatten()) * units)
# Clear any intermediate tensors here.
update_memory("backward_1", {"gradient": grad_mem})  # "intermediate_tensors": ...})

# Optimizer memory
if isinstance(optimizer, torch.optim.SGD):
    # SGD has parameters in memory. They are cleared after each step.
    optimizer_mem = 0
elif isinstance(optimizer, torch.optim.Adam):
    # Adam has parameters and 2 momentum buffers. Parameters are cleared after each step.
    optimizer_mem = 2 * model_mem
else:
    raise NotImplementedError("Only SGD and Adam are handled here")
update_memory("optim_step_1", {"optimizer_state": optimizer_mem, "output": -y_mem})

for step in range(2, 5):
    update_memory(f"optim_zero_grad_{step}", {"gradient": -grad_mem})
    update_memory(f"forward_{step}", {"output": y_mem})
    update_memory(f"backward_{step}", {"gradient": grad_mem})
    update_memory(f"optim_step_{step}", {"output": -y_mem})

# Make totals
for event in memory_timeline:
    event["total"] = sum([v for v in event.values() if isinstance(v, int)])

# Plot memory timeline
import pandas as pd
df = pd.DataFrame(memory_timeline, columns=all_keys + ["event"])
df.set_index("event", inplace=True, drop=True)
df.plot(kind='bar', stacked=True, figsize=(15, 7), ylim=(0, 1_300_000), xlabel="Event", ylabel="Total memory allocated (bytes)", title=f"Memory allocation expected ({type(optimizer)})")
plt.tight_layout()
plt.xticks(rotation=45)
plt.show()

# Compare the two timelines
for i, (real, expected) in enumerate(zip(memory_timeline_real, memory_timeline)):
    assert real["memory"] == expected["total"], f"Memory mismatch at {real['event']}: {real['memory']} != {expected['total']}"
Figure 6: Segmentation of memory usage at different stages in training using SGD optimizer
Figure 7: Segmentation of memory usage at different stages in training using Adam optimizer

After manually calculating the memory allocations, we match the observations. This time we can actually see how the allocated memory breaks down across the various tensors. We see, for example, that Adam’s state takes up twice the model size. We also see the gradients (in red) coming and going. As an exercise, you may try adding more layers to this model, adding the intermediate tensors and deleting them at the appropriate times. That should create another segment within these bar graphs representing the intermediate tensors.

Putting things together

Combine every concept from above to answer the main question (a rough estimator sketch follows this list):

  • Trainable Parameters: Fixed model size
  • Memory chunking: It only comes in 512-byte chunks
  • Cublas Memory: One block for forward, one block for backward
  • Gradients: Same as the model size
  • Intermediate Tensors: Most trickiest
  • Optimizer: Allocates at least one times the model size
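
Putting those bullets together, here’s a rough estimator for a plain stack of Linear layers trained with Adam. The chunk rounding, the cuBLAS constant and the “one stored input per Linear layer” rule are carried over from the experiments above, so treat it as a sketch rather than a general formula:

def estimate_training_memory(layer_sizes, batch_size, itemsize=4, cublas_block=8519680):
    """Rough peak-memory estimate (bytes) for a stack of Linear layers trained with Adam."""
    next_chunk = lambda n: (n + 511) // 512 * 512
    params = grads = optim_state = intermediates = 0
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        w = next_chunk(fan_in * fan_out * itemsize)
        b = next_chunk(fan_out * itemsize)
        params += w + b
        grads += w + b                # .grad mirrors every parameter
        optim_state += 2 * (w + b)    # Adam: exp_avg + exp_avg_sq
        # Stored input of each layer (the first one is the input batch itself)
        intermediates += next_chunk(batch_size * fan_in * itemsize)
    output = next_chunk(batch_size * layer_sizes[-1] * itemsize)
    return params + grads + optim_state + intermediates + output + 2 * cublas_block

# Example: the 200 -> 100 -> 200 model from above with a batch size of 5
print(estimate_training_memory([200, 100, 200], batch_size=5))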

One elephant in the room is that we’ve only dealt with feed-forward layers, what about CNNs, Transformers, RNNs, etc? Now that you’ve got the hang of calculating memory requirements for feed-forward layers, I’m confident you’ll be able to tackle these on your own too!
