Speeding up PyTorch models with multiple GPUs

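A large proportion of machine learning models these days, particularly in NLP, are published in PyTorch. This article covers the following:

  • Setting up a Google Cloud machine with PyTorch (for procuring a Google Cloud machine use this link)
  • Testing parallelism on a multi-GPU machine
  • Code changes to make a model utilize multiple GPUs for training and inference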

Setting up a Google Cloud machine with PyTorch

Running the scripts in this GitHub repo and following its instructions should set up a Google Cloud machine with drivers for single or multiple Nvidia GPUs.

Testing parallelism on a multi-GPU machine

Enabling the model to process data in parallel takes a single line of code:

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # enabling data parallelism
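For readers who want to try this without the linked code, here is a minimal self-contained sketch of such a toy test. It is not the author's script: the `RandomDataset` class, the single-layer `Model`, and the sizes used are assumptions chosen for illustration. It prints the tensor sizes seen inside the model (per GPU replica) and outside it (full batch), and it also runs on a CPU-only machine, where no wrapping takes place.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    """Hypothetical dataset of random vectors, used only for this test."""

    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class Model(nn.Module):
    """Hypothetical one-layer model that reports the batch it receives."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        out = self.fc(x)
        # With DataParallel, this runs once per GPU on that GPU's chunk.
        print("  in model: input", tuple(x.size()), "output", tuple(out.size()))
        return out


# Illustrative sizes (assumptions, not taken from the article's script).
input_size, output_size, batch_size, data_size = 5, 2, 30, 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # enabling data parallelism
model.to(device)

loader = DataLoader(RandomDataset(input_size, data_size), batch_size=batch_size)
output_shapes = []
for batch in loader:
    output = model(batch.to(device))
    output_shapes.append(tuple(output.size()))
    print("outside: input", tuple(batch.size()), "output", tuple(output.size()))
```

Outside the model the full batch is always visible; only the "in model" lines change with the number of GPUs.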

The full code for the toy test is listed here. The output of this example (python multi_gpu.py) on an 8-GPU machine is shown below:

The batch size is 30, so the first 7 GPUs each process 4 samples, while the 8th GPU processes the remaining 2.
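The arithmetic behind this split can be checked with a few lines of plain Python. This is a sketch under the assumption that the batch is divided the way `torch.chunk` divides a tensor (chunks of size ceil(batch / devices), with a smaller last chunk); `split_sizes` is a hypothetical helper, not a PyTorch function. Seven chunks of 4 plus one of 2 is what this rule gives for 30 samples on 8 devices.

```python
import math


def split_sizes(batch_size, n_devices):
    """Chunk sizes when a batch is split the way torch.chunk would split it."""
    chunk = math.ceil(batch_size / n_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


print(split_sizes(30, 8))  # [4, 4, 4, 4, 4, 4, 4, 2]
print(split_sizes(32, 8))  # [4, 4, 4, 4, 4, 4, 4, 4]
print(split_sizes(10, 8))  # [2, 2, 2, 2, 2]
```

The last call shows that a small final batch may leave some GPUs idle: 10 samples occupy only 5 of the 8 devices.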

Code changes to make model utilize multiple GPUs for training and inference

First, we create a device handle and count the available GPUs; both are used below:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

Then, to enable data parallelism for both training and inference:

model = Model(input_size, output_size)
if n_gpu > 1:
    model = nn.DataParallel(model)  # enabling data parallelism

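In addition to enabling data parallelism, training requires averaging the loss across GPUs. When the loss is computed inside the model's forward pass, nn.DataParallel returns one loss value per GPU, so it has to be reduced to a single scalar: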

if n_gpu > 1:
    loss = loss.mean()
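To see where the per-GPU losses come from, here is a minimal sketch of one training step. It assumes a hypothetical `ModelWithLoss` whose forward pass returns the loss itself, a pattern common in NLP models; the class, sizes, and random data are illustrative and not from the article's code.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical classifier whose forward pass returns the loss."""

    def __init__(self, input_size, num_labels):
        super().__init__()
        self.fc = nn.Linear(input_size, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, labels):
        return self.loss_fn(self.fc(x), labels)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

model = ModelWithLoss(input_size=5, num_labels=2)
if n_gpu > 1:
    model = nn.DataParallel(model)
model.to(device)

# Random stand-in data for a single batch.
x = torch.randn(30, 5).to(device)
labels = torch.randint(0, 2, (30,)).to(device)

# Under DataParallel each replica returns its own scalar loss, and the
# results are gathered into a vector with one entry per GPU.
loss = model(x, labels)
if n_gpu > 1:
    loss = loss.mean()  # reduce the per-GPU losses to a single scalar
loss.backward()
```

Note that when the chunks are uneven, the mean of the per-GPU losses weights the smaller chunk slightly more per sample than a mean over the whole batch would.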

The input data (tensors) also needs to be moved from the CPU to the CUDA device. An example is shown below:

input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
with torch.no_grad():
    logits = model(input_ids, segment_ids, input_mask)
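The inference step above can be tried end to end with a small stand-in model. This is a sketch, not the model the article refers to: `TinyClassifier`, its sizes, and the random inputs are assumptions, and only the argument order (`input_ids`, `segment_ids`, `input_mask`) is taken from the snippet above.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a BERT-style classifier."""

    def __init__(self, vocab_size=100, hidden=16, num_labels=2):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, hidden)
        self.seg = nn.Embedding(2, hidden)
        self.out = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, segment_ids, input_mask):
        h = self.tok(input_ids) + self.seg(segment_ids)  # [batch, seq, hidden]
        mask = input_mask.unsqueeze(-1).float()          # [batch, seq, 1]
        # Average the embeddings over the non-padded positions.
        pooled = (h * mask).sum(1) / mask.sum(1).clamp(min=1.0)
        return self.out(pooled)                          # [batch, num_labels]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TinyClassifier()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)
model.eval()

# Random stand-in inputs: 8 sequences of 12 tokens each.
input_ids = torch.randint(0, 100, (8, 12)).to(device)
segment_ids = torch.zeros(8, 12, dtype=torch.long).to(device)
input_mask = torch.ones(8, 12, dtype=torch.long).to(device)

with torch.no_grad():  # no gradients are needed for inference
    logits = model(input_ids, segment_ids, input_mask)

# Move the predictions back to the CPU for further processing.
predictions = logits.argmax(dim=-1).cpu()
```

The gathered `logits` tensor covers the whole batch, so the code after the forward pass does not need to know how many GPUs were used.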

This link shows a detailed example

GitHub link for GPU drivers/PyTorch setup
