Speeding up PyTorch models with multiple GPUs

A large proportion of machine learning models these days, particularly in NLP, are published in PyTorch. This article covers the following:

  • Setting up a Google Cloud machine with PyTorch (to procure a Google Cloud machine, use this link)
  • Testing parallelism on multi-GPU machines with a toy example
  • Code changes required to make a model use multiple GPUs for both training and inference

Setting up a Google Cloud machine with PyTorch

Running the scripts in this GitHub repo and following the instructions should set up a Google Cloud machine with drivers for single or multiple NVIDIA GPUs.
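Once the setup scripts finish, it is worth confirming that PyTorch can actually see the GPUs. The following is a minimal sanity check (a sketch, not part of the linked repo); describe_gpus is a hypothetical helper name, and on a CPU-only machine it simply reports zero GPUs.

```python
import torch


def describe_gpus():
    """Collect what PyTorch can see of the CUDA setup (hypothetical helper)."""
    gpu_count = torch.cuda.device_count()  # 0 on CPU-only machines
    return {
        "torch_version": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only PyTorch builds
        "cuda_available": torch.cuda.is_available(),
        "gpu_count": gpu_count,
        "gpu_names": [torch.cuda.get_device_name(i) for i in range(gpu_count)],
    }


if __name__ == "__main__":
    info = describe_gpus()
    print(f"PyTorch {info['torch_version']} (CUDA build: {info['cuda_build']})")
    print(f"CUDA available: {info['cuda_available']}")
    print(f"{info['gpu_count']} GPU(s): {info['gpu_names']}")
```

On a correctly configured 8-GPU machine the last line should list 8 device names.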

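Testing parallelism on a multi-GPU machine

Enabling a model to process data in parallel takes essentially one line of code: wrap it in nn.DataParallel, which splits each input batch across the visible GPUs, runs a replica of the model on each GPU, and gathers the outputs back on the first GPU.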

import torch
import torch.nn as nn

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # enabling data parallelism
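The snippet above assumes that Model, input_size and output_size are defined elsewhere. A minimal self-contained sketch in the spirit of the toy test (not the author's exact script; the layer sizes and the batch size of 30 are illustrative assumptions) makes the model print the shape of every input it receives, so a multi-GPU run prints one line per replica:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """Toy model: one linear layer that reports the size of the batch it receives."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        # Under nn.DataParallel each replica only sees its own slice of the batch.
        print(f"  In Model: input size {tuple(x.shape)}")
        return self.fc(x)


def run_toy_test(input_size=5, output_size=2, batch_size=30):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(input_size, output_size)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # enabling data parallelism
    model.to(device)
    x = torch.randn(batch_size, input_size, device=device)
    output = model(x)
    # DataParallel gathers the per-GPU outputs back into one full batch.
    print(f"Outside: input size {tuple(x.shape)}, output size {tuple(output.shape)}")
    return output


if __name__ == "__main__":
    run_toy_test()
```

On a single GPU or a CPU-only machine the model is not wrapped, so it prints a single line with the full batch.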

The full code for the toy test is listed here. Running this example (python multi_gpu.py) on an 8-GPU machine shows how each batch is split across the GPUs:

The batch size is 30, so the first 7 GPUs each process 4 samples, while the 8th GPU processes the remaining 2 (7 × 4 + 2 = 30).
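The per-GPU counts follow from how the batch is split: as far as we understand, nn.DataParallel's default scatter divides the batch along dimension 0 the same way torch.chunk does, giving each GPU ceil(batch_size / num_gpus) samples and leaving the remainder for the last one. A stdlib-only sketch of that rule (chunk_sizes is a hypothetical helper, not a PyTorch API):

```python
import math


def chunk_sizes(batch_size, num_gpus):
    """Sizes of the slices produced when a batch is split the way torch.chunk does it.

    Every slice gets ceil(batch_size / num_gpus) samples except the last one,
    which gets whatever is left (so fewer than num_gpus slices can result).
    """
    step = math.ceil(batch_size / num_gpus)
    return [min(step, batch_size - start) for start in range(0, batch_size, step)]


if __name__ == "__main__":
    print(chunk_sizes(30, 8))  # [4, 4, 4, 4, 4, 4, 4, 2]
    print(chunk_sizes(32, 8))  # [4, 4, 4, 4, 4, 4, 4, 4]
```

A small final batch can occupy fewer GPUs than are available: chunk_sizes(10, 8) returns [2, 2, 2, 2, 2], so only 5 of the 8 GPUs receive work.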

Code changes to make a model use multiple GPUs for training and inference

First, create a device handle and count the available GPUs; both are used below:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

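Then move the model to that device and enable data parallelism. This is needed for both training and inference:

model = Model(input_size, output_size)
model.to(device)  # DataParallel expects the parameters on the first GPU (cuda:0)
if n_gpu > 1:
    model = torch.nn.DataParallel(model)  # enabling data parallelism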
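The setup steps above can be collected into one helper. This is a sketch under the assumption that all visible GPUs should be used; prepare_model is a hypothetical name, and nn.Linear(5, 2) stands in for the article's Model:

```python
import torch
import torch.nn as nn


def prepare_model(model):
    """Move a model to the best available device and wrap it for multi-GPU use."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    # nn.DataParallel requires the parameters on the first GPU (device_ids[0])
    # when forward() runs, so move the model before using the wrapper.
    model.to(device)
    if n_gpu > 1:
        model = nn.DataParallel(model)  # replicate across all visible GPUs
    return model, device, n_gpu


if __name__ == "__main__":
    model, device, n_gpu = prepare_model(nn.Linear(5, 2))  # stand-in for Model(...)
    x = torch.randn(30, 5, device=device)
    y = model(x)
    print(f"device={device}, n_gpu={n_gpu}, output shape={tuple(y.shape)}")
```

Returning device and n_gpu alongside the model keeps the later steps (moving batches, averaging the loss) consistent with the way the model was set up.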

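In addition to enabling data parallelism, training requires averaging the loss across GPUs when the model computes its loss inside forward(). In that case nn.DataParallel gathers one loss value per GPU into a vector, which has to be reduced to a scalar before calling backward():

if n_gpu > 1:
    loss = loss.mean()  # average the per-GPU losses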
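To see where the per-GPU loss vector comes from, here is a sketch of one training step for a hypothetical model (ModelWithLoss) that returns its loss from forward(), as many NLP training scripts do. The model, sizes and SGD settings are assumptions for illustration:

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical model that returns its loss from forward()."""

    def __init__(self, input_size=5, num_labels=3):
        super().__init__()
        self.fc = nn.Linear(input_size, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, labels):
        return self.loss_fn(self.fc(x), labels)  # scalar loss for this replica's slice


def train_step(model, optimizer, x, labels, n_gpu):
    model.train()
    loss = model(x, labels)
    if n_gpu > 1:
        # DataParallel gathers one scalar per GPU into a 1-D tensor; average it.
        loss = loss.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model = ModelWithLoss().to(device)
    if n_gpu > 1:
        model = nn.DataParallel(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(30, 5, device=device)
    labels = torch.randint(0, 3, (30,), device=device)
    for step in range(3):
        print(f"step {step}: loss {train_step(model, optimizer, x, labels, n_gpu):.4f}")
```

One design caveat: loss.mean() weights every GPU equally, so when the last GPU gets a smaller slice (for example 2 samples instead of 4) the result is close to, but not exactly, the mean loss over the whole batch.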

The input data (tensors) also need to be moved from the CPU to the GPU with .to(device). An example for inference, where torch.no_grad() turns off gradient tracking:

input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)

with torch.no_grad():
    logits = model(input_ids, segment_ids, input_mask)
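A slightly fuller inference sketch, assuming a hypothetical TinyClassifier that is called as model(input_ids, segment_ids, input_mask) like the snippet above, and batches that arrive as (input_ids, input_mask, segment_ids) tuples; label_ids are left out because they are only needed for computing metrics:

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a model called as model(input_ids, segment_ids, input_mask)."""

    def __init__(self, vocab_size=100, hidden=16, num_labels=2):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, hidden)
        self.seg = nn.Embedding(2, hidden)
        self.out = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, segment_ids, input_mask):
        h = self.tok(input_ids) + self.seg(segment_ids)
        mask = input_mask.unsqueeze(-1).float()
        # Mean-pool over the unmasked tokens only.
        pooled = (h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.out(pooled)


@torch.no_grad()
def predict(model, batches, device):
    """Move each batch to the device and collect predicted labels on the CPU."""
    model.eval()  # disable dropout and similar training-only behavior
    predictions = []
    for input_ids, input_mask, segment_ids in batches:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        logits = model(input_ids, segment_ids, input_mask)
        predictions.append(logits.argmax(dim=-1).cpu())
    return torch.cat(predictions)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyClassifier().to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    batches = [
        (
            torch.randint(0, 100, (8, 12)),  # input_ids
            torch.ones(8, 12, dtype=torch.long),  # input_mask: every token is real
            torch.zeros(8, 12, dtype=torch.long),  # segment_ids: single segment
        )
        for _ in range(2)
    ]
    print(predict(model, batches, device))
```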

This link shows a detailed example.

GitHub link for GPU drivers/PyTorch setup

Machine learning practitioner
