
I want to understand how pin_memory works in DataLoader.

According to the documentation:

pin_memory (bool, optional) – If True, the data loader will copy tensors into CUDA pinned memory before returning them.

Below is a self-contained code example.

import torchvision
import torch

print('torch.cuda.is_available()', torch.cuda.is_available())
train_dataset = torchvision.datasets.CIFAR10(root='cifar10_pytorch', download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, pin_memory=True)
x, y = next(iter(train_dataloader))
print('x.device', x.device)
print('y.device', y.device)

Producing the following output:

torch.cuda.is_available() True
x.device cpu
y.device cpu

But I was expecting something like this, because I specified the pin_memory=True flag in the DataLoader:

torch.cuda.is_available() True
x.device cuda:0
y.device cuda:0

I also ran a benchmark:

import torchvision
import torch
import time
import numpy as np

pin_memory=True
train_dataset = torchvision.datasets.CIFAR10(root='cifar10_pytorch', download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, pin_memory=pin_memory)
print('pin_memory:', pin_memory)
times = []
n_runs = 10
for i in range(n_runs):
    st = time.time()
    for bx, by in train_dataloader:
        bx, by = bx.cuda(), by.cuda()
    times.append(time.time() - st)
print('average time:', np.mean(times))

I got the following results.

pin_memory: False
average time: 6.5701503753662

pin_memory: True
average time: 7.0254474401474

So pin_memory=True only makes things slower. Can someone explain this behaviour to me?

Ivan Belonogov
  • I've edited my answer to respond to your benchmark. Next time please leave a comment, because it's only by chance that I noticed your question has been updated – Jatentaki Apr 08 '19 at 15:17

1 Answer


The documentation is perhaps overly laconic, given that the terms used are fairly niche. In CUDA terms, pinned memory does not mean GPU memory but non-paged CPU memory. The benefits and rationale are provided here, but the gist of it is that this flag allows the x.cuda() operation (which you still have to execute as usual) to avoid one implicit CPU-to-CPU copy, which makes it a bit more performant. Additionally, with pinned memory tensors you can use x.cuda(non_blocking=True) to perform the copy asynchronously with respect to the host. This can lead to performance gains in certain scenarios, namely if your code is structured as

  1. x.cuda(non_blocking=True)
  2. perform some CPU operations
  3. perform GPU operations using x.

Since the copy initiated in 1. is asynchronous, it does not block 2. from proceeding while the copy is underway, and thus the two can happen side by side (which is the gain). Since step 3. requires x to already have been copied over to the GPU, it cannot be executed until 1. is complete; therefore only 1. and 2. can overlap, and 3. will definitely take place afterwards. The duration of 2. is therefore the maximum time you can expect to save with non_blocking=True. Without non_blocking=True your CPU would be waiting idle for the transfer to complete before proceeding with step 2.

Note: perhaps step 2. could also consist of GPU operations, as long as they do not require x - I am not sure if this is true and please don't quote me on that.
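
To make the overlap concrete, here is a minimal sketch outside of the DataLoader, assuming a CUDA device is available; the cpu_work() helper is made up and simply stands in for step 2.:

import torch

def cpu_work():
    # hypothetical stand-in for step 2.: host-side work that does not need x
    return sum(i * i for i in range(100_000))

# allocate the tensor in non-paged (pinned) CPU memory
x = torch.randn(64, 3, 32, 32).pin_memory()

x_gpu = x.cuda(non_blocking=True)  # step 1.: asynchronous copy, returns immediately
result = cpu_work()                # step 2.: runs on the CPU while the copy is in flight
y = x_gpu * 2                      # step 3.: GPU op, ordered after the copy on the same stream
torch.cuda.synchronize()           # wait for the GPU work to finish before using/timing y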

Edit: I believe you're missing the point with your benchmark. There are three issues with it:

  1. You're not using non_blocking=True in your .cuda() calls.
  2. You're not using multiprocessing in your DataLoader, which means that most of the work is done synchronously on the main thread anyway, dwarfing the memory transfer costs.
  3. You're not performing any CPU work in your data loading loop (aside from the .cuda() calls), so there is no work to be overlapped with the memory transfers.

A benchmark closer to how pin_memory is meant to be used would be

import torchvision, torch, time
import numpy as np
 
pin_memory = True
batch_size = 1024 # bigger memory transfers to make their cost more noticeable
n_workers = 6 # parallel workers to free up the main thread and reduce data decoding overhead
train_dataset = torchvision.datasets.CIFAR10(
    root='cifar10_pytorch',
    download=True,
    transform=torchvision.transforms.ToTensor()
)   
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    pin_memory=pin_memory,
    num_workers=n_workers
)   
print('pin_memory:', pin_memory)
times = []
n_runs = 10

def work():
    # emulates the CPU work done
    time.sleep(0.1)

for i in range(n_runs):
    st = time.time()
    for bx, by in train_dataloader:
        bx, by = bx.cuda(non_blocking=pin_memory), by.cuda(non_blocking=pin_memory)
        work()
    times.append(time.time() - st)
print('average time:', np.mean(times))

which gives an average of 5.48s for my machine with memory pinning and 5.72s without.

Jatentaki
  • does this mean extra RAM usage? When should we NOT use it? thanks – Shihab Shahriar Khan Apr 08 '19 at 13:59
  • I don't know the technical details and exact consequences. I don't think any extra RAM is used, but since pinned memory can't be paged out, the OS may not be able to page out your program and could OOM in a situation from which it would normally be able to recover. – Jatentaki Apr 08 '19 at 14:09
  • Do you know the expected behaviour of `.to(non_blocking=True)` when `pin_memory==False`? – user27182 Sep 18 '19 at 11:03
  • What I don't get is: if the `.cuda` operation runs side by side with the CPU operations, how can we guarantee that the `x` sent to `.cuda` is the processed `x`, not the original one? – R. Zhu Feb 10 '20 at 02:07
  • 1
    It's not extra memory usage, but it's a block of memory that the OS can't move around, swap to disk if memory is running low, etc. So it makes the OS's job more difficult, and there's a limit to how much memory you can pin. – Christian Hudon Aug 18 '20 at 15:06
  • Thank you for the answer. It is quite non-intuitive. To use `pin_memory`, the tensor has to be on the CPU. Even when `pin_memory_device` is listed as `cuda`, the data is still on the CPU. When the data is retrieved from the DataLoader, use `data.to(device)` to actually move it to the GPU. Together with the use of `num_workers > 0` in the DataLoader, this does result in a speedup from GPU processing. The pinning and the num_workers apparently allow the `to(device)` call to be completed efficiently. – rodin Aug 29 '23 at 19:35