I have a large ML codebase that I've been writing for a few months, and I've started the process of trying to parallelize the data side of things to work with multiple GPUs. To start, the code works perfectly when using a single GPU; the issue only appears when using multiple GPUs.
The error is as follows: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
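For context, from what I can tell the error itself just means that an indexing operation received its source tensor and its index tensor on different GPUs. A minimal sketch (not from my code, just assuming two GPUs are visible) that triggers the same message:

import torch

# Hypothetical repro: the source tensor lives on cuda:0 while the index
# tensor lives on cuda:1, so index_select raises the same RuntimeError.
src = torch.randn(10, 5, device='cuda:0')
idx = torch.tensor([0, 2, 4], device='cuda:1')
out = torch.index_select(src, 0, idx)  # RuntimeError: Expected all tensors to be on the same device ...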
Snippets of relevant code can be found below:
Model file
import torch
from torch import nn
from functools import partial
import copy
from ..mlp import MLP
from ..basis import gaussian, bessel
from ..conv import GatedGCN
class Encoder(nn.Module):
    """ALIGNN/ALIGNN-d Encoder.

    The encoder must take a PyG graph object `data` and output the same `data`
    with additional fields `h_atm`, `h_bnd`, and `h_ang` that correspond to the atom, bond, and angle embeddings.
    The input `data` must have three fields `x_atm`, `x_bnd`, and `x_ang` that describe the atom type
    (as one-hot vectors), the bond lengths, and the bond/dihedral angles (in radians).
    """
    def __init__(self, num_species, cutoff, dim=128, dihedral=False):
        super().__init__()
        self.num_species = num_species
        self.cutoff = cutoff
        self.dim = dim
        self.dihedral = dihedral

        self.embed_atm = nn.Sequential(MLP([num_species, dim, dim], act=nn.SiLU()), nn.LayerNorm(dim))
        self.embed_bnd = partial(bessel, start=0, end=cutoff, num_basis=dim)
        self.embed_ang = self.embed_ang_with_dihedral if dihedral else self.embed_ang_without_dihedral

    def embed_ang_with_dihedral(self, x_ang, mask_dih_ang):
        cos_ang = torch.cos(x_ang)
        sin_ang = torch.sin(x_ang)

        h_ang = torch.zeros([len(x_ang), self.dim], device=x_ang.device)
        h_ang[~mask_dih_ang, :self.dim // 2] = gaussian(cos_ang[~mask_dih_ang], start=-1, end=1, num_basis=self.dim // 2)

        h_cos_ang = gaussian(cos_ang[mask_dih_ang], start=-1, end=1, num_basis=self.dim // 4)
        h_sin_ang = gaussian(sin_ang[mask_dih_ang], start=-1, end=1, num_basis=self.dim // 4)
        h_ang[mask_dih_ang, self.dim // 2:] = torch.cat([h_cos_ang, h_sin_ang], dim=-1)
        return h_ang

    def embed_ang_without_dihedral(self, x_ang, mask_dih_ang):
        cos_ang = torch.cos(x_ang)
        return gaussian(cos_ang, start=-1, end=1, num_basis=self.dim)

    def forward(self, data):
        # Embed atoms
        data.h_atm = self.embed_atm(data.x_atm)

        # Embed bonds
        data.h_bnd = self.embed_bnd(data.x_bnd)

        # Embed angles
        data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)

        return data
class Processor(nn.Module):
    """ALIGNN Processor.

    The processor updates atom, bond, and angle embeddings.
    """
    def __init__(self, num_convs, dim):
        super().__init__()
        self.num_convs = num_convs
        self.dim = dim

        self.atm_bnd_convs = nn.ModuleList([copy.deepcopy(GatedGCN(dim, dim)) for _ in range(num_convs)])
        self.bnd_ang_convs = nn.ModuleList([copy.deepcopy(GatedGCN(dim, dim)) for _ in range(num_convs)])

    def forward(self, data):
        edge_index_G = data.edge_index_G
        edge_index_A = data.edge_index_A

        for i in range(self.num_convs):
            data.h_bnd, data.h_ang = self.bnd_ang_convs[i](data.h_bnd, edge_index_A, data.h_ang)
            data.h_atm, data.h_bnd = self.atm_bnd_convs[i](data.h_atm, edge_index_G, data.h_bnd)

        return data
class Decoder(nn.Module):
    def __init__(self, node_dim, out_dim):
        super().__init__()
        self.node_dim = node_dim
        self.out_dim = out_dim
        self.decoder = MLP([node_dim, node_dim, out_dim], act=nn.SiLU())

    def forward(self, data):
        return self.decoder(data.h_atm)

class ALIGNN(nn.Module):
    """ALIGNN model.

    Can optionally encode dihedral angles.
    """
    def __init__(self, encoder, processor, decoder):
        super().__init__()
        self.encoder = encoder
        self.processor = processor
        self.decoder = decoder

    def forward(self, data):
        data = self.encoder(data)
        data = self.processor(data)
        return self.decoder(data)
Training file
import torch
from torch import nn
from torch_geometric.loader import DataLoader  # assuming PyG's DataLoader, since follow_batch is used below
from tqdm.notebook import trange
from datetime import datetime
import glob
import sys
import os
def train(loader, model, parameters, PIN_MEMORY=False):
    model.train()
    total_loss = 0.0

    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()
    #model = model.to(parameters['device'])
    optimizer = torch.optim.AdamW(model.module.processor.parameters(), lr=parameters['LEARN_RATE'])
    #model = model.to(parameters['device'])
    loss_fn = torch.nn.MSELoss()

    for i, data in enumerate(loader, 0):
        optimizer.zero_grad(set_to_none=True)
        #data = data.to(parameters['device'], non_blocking=PIN_MEMORY)
        data = data.cuda()

        #encoding = model.encoder(data)
        #proc = model.processor(encoding.module)
        #atom_contrib, bond_contrib, angle_contrib = model.decoder(proc.module)
        atom_contrib, bond_contrib, angle_contrib = model(data)

        all_sum = atom_contrib.sum() + bond_contrib.sum() + angle_contrib.sum()
        loss = loss_fn(all_sum, data.y[0][0])

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)
def run_training(data, parameters, model):
    follow_batch = ['x_atm', 'x_bnd', 'x_ang'] if hasattr(data['training'][0], 'x_ang') else ['x_atm']
    loader_train = DataLoader(data['training'], batch_size=parameters['BATCH_SIZE'], shuffle=True, follow_batch=follow_batch)
    loader_valid = DataLoader(data['validation'], batch_size=parameters['BATCH_SIZE'], shuffle=False)

    L_train, L_valid = [], []
    min_loss_train = 1.0E30
    min_loss_valid = 1.0E30

    stats_file = open(os.path.join(parameters['model_dir'], 'loss.data'), 'w')
    stats_file.write('Training_loss Validation_loss\n')
    stats_file.close()

    for ep in range(parameters['num_epochs']):
        stats_file = open(os.path.join(parameters['model_dir'], 'loss.data'), 'a')
        print('Epoch ', ep, ' of ', parameters['num_epochs'])
        sys.stdout.flush()

        loss_train = train(loader_train, model, parameters)
        L_train.append(loss_train)

        loss_valid = test_non_intepretable(loader_valid, model, parameters)
        L_valid.append(loss_valid)

        stats_file.write(str(loss_train) + ' ' + str(loss_valid) + '\n')

        if loss_train < min_loss_train:
            min_loss_train = loss_train

        if loss_valid < min_loss_valid:
            min_loss_valid = loss_valid
            if parameters['remove_old_model']:
                model_name = glob.glob(os.path.join(parameters['model_dir'], 'model_*'))
                if len(model_name) > 0:
                    os.remove(model_name[0])
            now = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            print('Min train loss: ', min_loss_train, ' min valid loss: ', min_loss_valid, ' time: ', now)
            torch.save(model.state_dict(), os.path.join(parameters['model_dir'], 'model_' + str(now)))

        stats_file.close()

        if loss_train < parameters['train_tolerance'] and loss_valid < parameters['train_tolerance']:
            print('Validation and training losses satisfy the set tolerance... exiting training loop...')
            break
There are many other files, but I think these are the relevant ones for this specific problem; I'm happy to include more code if needed. My data is stored as a graph and batched by a DataLoader, and it looks like this (when batched two at a time):
Graph_DataBatch(atoms=[2], edge_index_G=[2, 89966], edge_index_A=[2, 1479258], x_atm=[5184, 5], x_atm_batch=[5184], x_atm_ptr=[3], x_bnd=[89966], x_bnd_batch=[89966], x_bnd_ptr=[3], x_ang=[1479258], x_ang_batch=[1479258], x_ang_ptr=[3], mask_dih_ang=[1479258], atm_amounts=[6], bnd_amounts=[6], ang_amounts=[6], y=[179932, 1])
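For reference, a batch like the one above comes out of a loader call along these lines (a sketch assuming PyG; graphs is a hypothetical list of torch_geometric Data objects with the same field names as my data):

from torch_geometric.loader import DataLoader

# Sketch: follow_batch is what adds the x_*_batch / x_*_ptr vectors
# seen in the printout above.
loader = DataLoader(graphs, batch_size=2, shuffle=True,
                    follow_batch=['x_atm', 'x_bnd', 'x_ang'])
batch = next(iter(loader))
print(batch)  # Graph_DataBatch(atoms=[2], edge_index_G=[2, ...], x_atm=[..., 5], ...)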
The problematic line is in the Encoder during its forward function:
def forward(self, data):
    # Embed atoms
    data.h_atm = self.embed_atm(data.x_atm)

    # Embed bonds
    data.h_bnd = self.embed_bnd(data.x_bnd)

    # Embed angles
    data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)

    return data
Specifically, data.h_atm = self.embed_atm(data.x_atm). To briefly explain what I'm trying to do: I am loading a bunch of graphs into the DataLoader, which are batched and then fed into the model for training:
def train(loader, model, parameters, PIN_MEMORY=False):
    model.train()
    total_loss = 0.0

    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()
    #model = model.to(parameters['device'])
    optimizer = torch.optim.AdamW(model.module.processor.parameters(), lr=parameters['LEARN_RATE'])
    #model = model.to(parameters['device'])
    loss_fn = torch.nn.MSELoss()

    for i, data in enumerate(loader, 0):
        optimizer.zero_grad(set_to_none=True)
        #data = data.to(parameters['device'], non_blocking=PIN_MEMORY)
        data = data.cuda()

        #encoding = model.encoder(data)
        #proc = model.processor(encoding.module)
        #atom_contrib, bond_contrib, angle_contrib = model.decoder(proc.module)
        atom_contrib, bond_contrib, angle_contrib = model(data)
My understanding is that I have sent my batched data to the GPU, my model parameters are on the GPU, and DataParallel should take care of splitting my data up and sending everything to each GPU automatically.
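To illustrate what I mean, here is a toy sketch (plain tensors only, nothing to do with my graph batch) of the splitting behaviour I was expecting from DataParallel:

import torch
from torch import nn

class Toy(nn.Module):
    def forward(self, x):
        # x is whatever shard DataParallel scattered to this replica,
        # so the device should differ between the two copies.
        print('replica sees', tuple(x.shape), 'on', x.device)
        return x.sum(dim=1, keepdim=True)

model = nn.DataParallel(Toy(), device_ids=[0, 1]).cuda()
x = torch.randn(8, 4).cuda()   # a plain tensor gets split 4/4 along dim 0
out = model(x)                 # prints cuda:0 for one half, cuda:1 for the other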
My question can be broken into a few parts: (1) Is this understanding correct? (2) Does my code actually seem like it's doing this? And (3) does this error have anything to do with that, and if not, what is this error trying to tell me? I don't expect anyone to fix my code for me, but I would like to understand why this error is happening, because I think I'm misunderstanding the underlying logic of how DataParallel takes my data and sends it to each GPU. I'm happy to provide any details you might need to better understand this problem.
I have tried to better understand the line that breaks, data.h_atm = self.embed_atm(data.x_atm), by printing out where data.x_atm actually is when inside the forward function, which should be after DataParallel has partitioned the data, and I get this for all tensors:
tensor([[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
...,
[0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0.]], device='cuda:0')
I think this is telling me that all of my data is on GPU 0, despite the model being on both GPU 0 and GPU 1 (when running on 2 GPUs), which I've confirmed by using nvidia-smi and observing both GPUs at about half memory consumption. I have also tried various combinations of calling X.to('cuda') or X.cuda(), where X is my graph data, but nothing seems to make any difference to the tensor printout.
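For completeness, this is roughly the kind of check I've been printing (a hypothetical debugging helper, not part of the model):

import torch

def report_devices(data, module):
    # Print which device each graph field and the module's first parameter
    # are on, to spot exactly where the mismatch happens.
    for name in ('x_atm', 'x_bnd', 'x_ang', 'mask_dih_ang'):
        t = getattr(data, name, None)
        if torch.is_tensor(t):
            print(name, '->', t.device)
    print('first parameter ->', next(module.parameters()).device)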
If anyone is seeing this in the future: the issue was that the graph data was never being sent to the correct GPU when the model's forward functions were called; it was only being sent to the default GPU by the data = data.cuda() call inside the DataLoader loop. The solution was to call data = data.cuda() inside the model's forward function, which is the point where DataParallel actually has access to the per-replica data, so the data automatically ends up on the correct GPU.
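Concretely, the change was along these lines (a sketch of the idea rather than the full file; inside forward each DataParallel replica runs under its own CUDA device context, so .cuda() puts the batch on that replica's GPU instead of always on cuda:0):

class Encoder(nn.Module):
    ...
    def forward(self, data):
        # Moved here from the training loop: each replica's forward runs on its
        # assigned GPU, so this lands the batch on the right device.
        data = data.cuda()

        data.h_atm = self.embed_atm(data.x_atm)
        data.h_bnd = self.embed_bnd(data.x_bnd)
        data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)
        return data

(PyG also ships torch_geometric.nn.DataParallel together with DataListLoader, which scatters whole graphs to each replica instead of relying on tensor splitting; I didn't go that route, but it's another option.)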