Tags: python, pytorch, pytorch-geometric, dataparallel

Issue with pytorch tensors and multiple GPUs when using DataParallel


I have a large ML codebase that I've been writing for a few months, and I've started trying to parallelize the data side of things so that it works with multiple GPUs. To start, the code works perfectly on a single GPU; the issue only appears when using multiple GPUs.

The error is as follows: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Snippets of relevant code can be found below:

Model file

import torch
from torch import nn
from functools import partial
import copy

from ..mlp import MLP
from ..basis import gaussian, bessel
from ..conv import GatedGCN


class Encoder(nn.Module):
    """ALIGNN/ALIGNN-d Encoder.
    The encoder must take a PyG graph object `data` and output the same `data`
    with additional fields `h_atm`, `h_bnd`, and `h_ang` that correspond to the atom, bond, and angle embedding.

    The input `data` must have three fields `x_atm`, `x_bnd`, and `x_ang` that describe the atom type
    (as one-hot vectors), the bond lengths, and bond/dihedral angles (in radians).
    """

    def __init__(self, num_species, cutoff, dim=128, dihedral=False):
        super().__init__()
        self.num_species = num_species
        self.cutoff = cutoff
        self.dim = dim
        self.dihedral = dihedral

        self.embed_atm = nn.Sequential(MLP([num_species, dim, dim], act=nn.SiLU()), nn.LayerNorm(dim))
        self.embed_bnd = partial(bessel, start=0, end=cutoff, num_basis=dim)
        self.embed_ang = self.embed_ang_with_dihedral if dihedral else self.embed_ang_without_dihedral

    def embed_ang_with_dihedral(self, x_ang, mask_dih_ang):
        cos_ang = torch.cos(x_ang)
        sin_ang = torch.sin(x_ang)

        h_ang = torch.zeros([len(x_ang), self.dim], device=x_ang.device)
        h_ang[~mask_dih_ang, :self.dim // 2] = gaussian(cos_ang[~mask_dih_ang], start=-1, end=1,
                                                        num_basis=self.dim // 2)

        h_cos_ang = gaussian(cos_ang[mask_dih_ang], start=-1, end=1, num_basis=self.dim // 4)
        h_sin_ang = gaussian(sin_ang[mask_dih_ang], start=-1, end=1, num_basis=self.dim // 4)
        h_ang[mask_dih_ang, self.dim // 2:] = torch.cat([h_cos_ang, h_sin_ang], dim=-1)

        return h_ang

    def embed_ang_without_dihedral(self, x_ang, mask_dih_ang):
        cos_ang = torch.cos(x_ang)
        return gaussian(cos_ang, start=-1, end=1, num_basis=self.dim)

    def forward(self, data):
        # Embed atoms
        data.h_atm = self.embed_atm(data.x_atm)

        # Embed bonds
        data.h_bnd = self.embed_bnd(data.x_bnd)

        # Embed angles
        data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)

        return data


class Processor(nn.Module):
    """ALIGNN Processor.
    The processor updates atom, bond, and angle embeddings.
    """

    def __init__(self, num_convs, dim):
        super().__init__()
        self.num_convs = num_convs
        self.dim = dim

        self.atm_bnd_convs = nn.ModuleList([copy.deepcopy(GatedGCN(dim, dim)) for _ in range(num_convs)])
        self.bnd_ang_convs = nn.ModuleList([copy.deepcopy(GatedGCN(dim, dim)) for _ in range(num_convs)])

    def forward(self, data):
        edge_index_G = data.edge_index_G
        edge_index_A = data.edge_index_A

        for i in range(self.num_convs):
            data.h_bnd, data.h_ang = self.bnd_ang_convs[i](data.h_bnd, edge_index_A, data.h_ang)
            data.h_atm, data.h_bnd = self.atm_bnd_convs[i](data.h_atm, edge_index_G, data.h_bnd)

        return data


class Decoder(nn.Module):
    def __init__(self, node_dim, out_dim):
        super().__init__()
        self.node_dim = node_dim
        self.out_dim = out_dim
        self.decoder = MLP([node_dim, node_dim, out_dim], act=nn.SiLU())

    def forward(self, data):
        return self.decoder(data.h_atm)


class ALIGNN(nn.Module):
    """ALIGNN model.
    Can optionally encode dihedral angles.
    """

    def __init__(self, encoder, processor, decoder):
        super().__init__()
        self.encoder = encoder
        self.processor = processor
        self.decoder = decoder

    def forward(self, data):
        data = self.encoder(data)
        data = self.processor(data)
        return self.decoder(data)
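Encoder.embed_ang_with_dihedral already uses the device-agnostic pattern that matters under DataParallel: it allocates h_ang with device=x_ang.device instead of a fixed device, so the buffer follows its input onto whichever GPU a replica runs on. A minimal CPU-runnable sketch of the same masked fill, using x_ang.new_zeros (which inherits both device and dtype). gaussian_basis is a hypothetical stand-in for the project's ..basis.gaussian, whose implementation isn't shown, and, like the original, the sketch assumes dim is divisible by 4 and mask_dih_ang is a boolean tensor:

```python
import torch


def gaussian_basis(x, start, end, num_basis):
    # Hypothetical stand-in for ..basis.gaussian: evenly spaced centers in
    # [start, end], width equal to the center spacing. Created on x's device.
    centers = torch.linspace(start, end, num_basis, device=x.device, dtype=x.dtype)
    width = (end - start) / (num_basis - 1)
    return torch.exp(-(((x.unsqueeze(-1) - centers) / width) ** 2))


def embed_angles(x_ang, mask_dih_ang, dim):
    cos_ang, sin_ang = torch.cos(x_ang), torch.sin(x_ang)
    # new_zeros inherits device and dtype from x_ang, so no device is hard-coded.
    h_ang = x_ang.new_zeros(len(x_ang), dim)
    # Bond angles fill the first half of the columns.
    h_ang[~mask_dih_ang, : dim // 2] = gaussian_basis(cos_ang[~mask_dih_ang], -1.0, 1.0, dim // 2)
    # Dihedral angles fill the second half with cos and sin expansions.
    h_cos = gaussian_basis(cos_ang[mask_dih_ang], -1.0, 1.0, dim // 4)
    h_sin = gaussian_basis(sin_ang[mask_dih_ang], -1.0, 1.0, dim // 4)
    h_ang[mask_dih_ang, dim // 2 :] = torch.cat([h_cos, h_sin], dim=-1)
    return h_ang
```

A tensor created with a hard-coded device (for example torch.zeros(..., device='cuda')) would land on the current CUDA device rather than next to the inputs, which is exactly the kind of mismatch the error message reports.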

Training file

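from tqdm.notebook import trange
from datetime import datetime
import glob
import sys
import os

import torch
from torch import nn
# PyG >= 2.0; older releases expose DataLoader as torch_geometric.data.DataLoader
from torch_geometric.loader import DataLoader

# test_non_intepretable() comes from another project file that is not shown here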

def train(loader,model,parameters,PIN_MEMORY=False):
    model.train()
    total_loss = 0.0
    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()
    #model = model.to(parameters['device'])
    optimizer = torch.optim.AdamW(model.module.processor.parameters(), lr=parameters['LEARN_RATE'])

    #model = model.to(parameters['device'])
    loss_fn = torch.nn.MSELoss()
    for i,data in enumerate(loader, 0):
        optimizer.zero_grad(set_to_none=True)
        #data = data.to(parameters['device'], non_blocking=PIN_MEMORY)
        data = data.cuda()
        #encoding = model.encoder(data)
        #proc = model.processor(encoding.module)
        #atom_contrib, bond_contrib, angle_contrib = model.decoder(proc.module)
        atom_contrib, bond_contrib, angle_contrib = model(data)

        all_sum = atom_contrib.sum() + bond_contrib.sum() + angle_contrib.sum()

        loss = loss_fn(all_sum, data.y[0][0])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def run_training(data,parameters,model):
    follow_batch = ['x_atm', 'x_bnd', 'x_ang'] if hasattr(data['training'][0], 'x_ang') else ['x_atm']
    loader_train = DataLoader(data['training'], batch_size=parameters['BATCH_SIZE'], shuffle=True, follow_batch=follow_batch)
    loader_valid = DataLoader(data['validation'], batch_size=parameters['BATCH_SIZE'], shuffle=False)

    L_train, L_valid = [], []
    min_loss_train = 1.0E30
    min_loss_valid = 1.0E30

    stats_file = open(os.path.join(parameters['model_dir'],'loss.data'),'w')
    stats_file.write('Training_loss     Validation loss\n')
    stats_file.close()
    for ep in range(parameters['num_epochs']):
        stats_file = open(os.path.join(parameters['model_dir'], 'loss.data'), 'a')
        print('Epoch ',ep,' of ',parameters['num_epochs'])
        sys.stdout.flush()
        loss_train = train(loader_train, model, parameters)
        L_train.append(loss_train)
        loss_valid = test_non_intepretable(loader_valid, model, parameters)
        L_valid.append(loss_valid)
        stats_file.write(str(loss_train) + '     ' + str(loss_valid) + '\n')
        if loss_train < min_loss_train:
            min_loss_train = loss_train
            if loss_valid < min_loss_valid:
                min_loss_valid = loss_valid
                if parameters['remove_old_model']:
                    model_name = glob.glob(os.path.join(parameters['model_dir'], 'model_*'))
                    if len(model_name) > 0:
                        os.remove(model_name[0])
                now = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                print('Min train loss: ', min_loss_train, ' min valid loss: ', min_loss_valid, ' time: ', now)
                torch.save(model.state_dict(), os.path.join(parameters['model_dir'], 'model_' + str(now)))
        stats_file.close()
        if loss_train < parameters['train_tolerance'] and loss_valid < parameters['train_tolerance']:
            print('Validation and training losses satisfy set tolerance...exiting training loop...')
            break
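Separately from the device error, train() wraps the model in nn.DataParallel and constructs a new AdamW optimizer on every call, i.e. once per epoch, which discards AdamW's running moment estimates each epoch. A minimal sketch, assuming the wrapping and optimizer construction move out of train() so they happen once. The helper names and the toy nn.Linear model are hypothetical, and the sketch optimizes all parameters, whereas the original passes only model.module.processor.parameters():

```python
import torch
from torch import nn


def make_model_and_optimizer(model, lr):
    """Wrap the model and build the optimizer ONCE, before the epoch loop."""
    if torch.cuda.device_count() > 1:
        # DataParallel expects the module's parameters on device_ids[0] (cuda:0).
        model = nn.DataParallel(model.cuda())
    elif torch.cuda.is_available():
        model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer


def train_one_epoch(model, optimizer, batches, loss_fn):
    """One pass over `batches`, reusing the same model and optimizer every epoch."""
    model.train()
    # The wrapper itself (not a replica) still exposes its module's parameters.
    device = next(model.parameters()).device
    total_loss = 0.0
    for x, y in batches:
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(x.to(device)), y.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(batches)
```

In run_training this would mean calling make_model_and_optimizer once before the epoch loop and passing the same pair into every epoch. If the wrapped model is checkpointed, its state_dict keys carry a "module." prefix; saving model.module.state_dict() avoids that.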

There are many other files, but I think these are the relevant ones for this specific problem; I'm happy to include more code if needed. My data is stored as graphs and batched by a DataLoader, and a batch looks like this (when batching two graphs at a time):

Graph_DataBatch(atoms=[2], edge_index_G=[2, 89966], edge_index_A=[2, 1479258], x_atm=[5184, 5], x_atm_batch=[5184], x_atm_ptr=[3], x_bnd=[89966], x_bnd_batch=[89966], x_bnd_ptr=[3], x_ang=[1479258], x_ang_batch=[1479258], x_ang_ptr=[3], mask_dih_ang=[1479258], atm_amounts=[6], bnd_amounts=[6], ang_amounts=[6], y=[179932, 1])
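For context on the batch above: because x_atm is listed in follow_batch, the batch carries x_atm_ptr, whose three entries are the cumulative node offsets that mark where each of the two graphs starts and ends along dimension 0. A minimal sketch with made-up sizes of why the split point matters: cutting a batched tensor at the ptr offsets keeps each graph intact, whereas an even split along dim 0, which is how DataParallel divides a bare tensor argument, can cut a graph in half. Index tensors such as edge_index_G would additionally need their offsets subtracted; this sketch only covers feature tensors, and split_by_ptr is a hypothetical helper:

```python
import torch


def split_by_ptr(x, ptr):
    """Split a batched feature tensor into per-graph pieces.

    ptr holds cumulative offsets [0, n_0, n_0 + n_1, ...], as in PyG's *_ptr fields.
    """
    return torch.tensor_split(x, ptr[1:-1].tolist(), dim=0)


x_atm = torch.arange(10.0).unsqueeze(1)  # 10 nodes from two graphs (made-up sizes)
x_atm_ptr = torch.tensor([0, 4, 10])  # graph 0: nodes 0-3, graph 1: nodes 4-9

per_graph = split_by_ptr(x_atm, x_atm_ptr)
print([p.shape[0] for p in per_graph])  # [4, 6]: each graph intact

even = torch.chunk(x_atm, 2, dim=0)
print([c.shape[0] for c in even])  # [5, 5]: node 4 of graph 1 lands with graph 0
```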

The problematic line is in the Encoder during its forward function:

def forward(self, data):
    # Embed atoms
    data.h_atm = self.embed_atm(data.x_atm)

    # Embed bonds
    data.h_bnd = self.embed_bnd(data.x_bnd)

    # Embed angles
    data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)

    return data

Specifically, the failing line is data.h_atm = self.embed_atm(data.x_atm). To briefly explain what I'm trying to do: I load a set of graphs into the DataLoader, which batches them, and the batches are then fed into the model for training:

def train(loader,model,parameters,PIN_MEMORY=False):
    model.train()
    total_loss = 0.0
    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()
    #model = model.to(parameters['device'])
    optimizer = torch.optim.AdamW(model.module.processor.parameters(), lr=parameters['LEARN_RATE'])

    #model = model.to(parameters['device'])
    loss_fn = torch.nn.MSELoss()
    for i,data in enumerate(loader, 0):
        optimizer.zero_grad(set_to_none=True)
        #data = data.to(parameters['device'], non_blocking=PIN_MEMORY)
        data = data.cuda()
        #encoding = model.encoder(data)
        #proc = model.processor(encoding.module)
        #atom_contrib, bond_contrib, angle_contrib = model.decoder(proc.module)
        atom_contrib, bond_contrib, angle_contrib = model(data)

My understanding is that I have sent my batched data to the GPU, my model parameters are on the GPU, and DataParallel should take care of splitting my data up and sending the pieces to each GPU automatically.

My question can be broken into a few parts: (1) Is this understanding correct? (2) Does my code actually seem to be doing this? (3) Does this error have anything to do with that, and if not, what is this error trying to tell me? I don't expect anyone to fix my code for me, but I would like to understand why this error is happening, because I think I'm misunderstanding the underlying logic of how DataParallel takes my data and sends it to the GPUs. I'm happy to provide any details you might need to better understand this problem.
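On question (1), a minimal sketch of what nn.DataParallel does with its inputs before calling the replicas. Its scatter step, the internal helper torch.nn.parallel.scatter_gather.scatter_kwargs (internal, so details may vary between PyTorch versions), chunks tensors, including tensors nested in tuples, lists, and dicts, along dim 0, but any other object is handed unchanged to every device. GraphBatchLike is a hypothetical stand-in for the PyG batch; since it is not one of those container types, no CUDA work happens and the sketch runs on a CPU-only machine:

```python
import torch
from torch.nn.parallel.scatter_gather import scatter_kwargs


class GraphBatchLike:
    """Hypothetical stand-in for a PyG batch: an arbitrary object holding tensors."""

    def __init__(self, x_atm):
        self.x_atm = x_atm


batch = GraphBatchLike(torch.zeros(6, 5))

# The same call DataParallel makes to split positional and keyword inputs
# across target devices 0 and 1.
inputs, kwargs = scatter_kwargs((batch,), {}, [0, 1], dim=0)

print(len(inputs))  # 2: one argument tuple per device
print(inputs[0][0] is batch, inputs[1][0] is batch)  # True True: nothing was split or copied
```

Both replicas receive the very same object, whose tensors were moved to cuda:0 by data.cuda() in the training loop. That matches the device='cuda:0' printout and the cuda:1 vs cuda:0 mismatch in the error: the replica on cuda:1 has its weights on cuda:1 but is given inputs on cuda:0.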

I have tried to better understand the line that breaks, data.h_atm = self.embed_atm(data.x_atm), by printing where data.x_atm actually is inside the forward function, which should run after DataParallel has partitioned the data. I get this for all tensors:

tensor([[1., 0., 0., 0., 0.], [1., 0., 0., 0., 0.], [1., 0., 0., 0., 0.], ..., [0., 0., 1., 0., 0.], [0., 0., 1., 0., 0.], [0., 0., 1., 0., 0.]], device='cuda:0')

I think this is telling me that all of my data is on GPU 0, even though the model is on both GPU 0 and GPU 1 (when running on 2 GPUs); I've confirmed the latter with nvidia-smi, which shows both GPUs at about half memory consumption. I have also tried various combinations of calling X.to('cuda') or X.cuda(), where X is my graph data, but nothing makes any difference to the tensor printout.
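One way to pin down which tensors are on the wrong device, sketched under assumptions: the hypothetical helpers below take any mapping of field names to values (with PyG 2.x, data.to_dict() should provide one for a batch) and report or check the device of every tensor field, so a mismatch is caught with field names attached rather than deep inside index_select:

```python
import torch


def tensor_devices(fields):
    """Return {field name: device} for every tensor-valued entry of a mapping."""
    return {name: value.device for name, value in fields.items() if torch.is_tensor(value)}


def assert_on_device(fields, expected):
    """Raise if any tensor field is not on `expected`.

    Pass an explicit index such as 'cuda:1': torch.device('cuda') has no index
    and does not compare equal to a tensor's torch.device('cuda:0').
    """
    if isinstance(expected, str):
        expected = torch.device(expected)
    wrong = {name: dev for name, dev in tensor_devices(fields).items() if dev != expected}
    if wrong:
        raise RuntimeError(f"tensors not on {expected}: {wrong}")
```

Inside Encoder.forward, a check such as assert_on_device(data.to_dict(), self.embed_atm[1].weight.device) would compare every batch tensor against the LayerNorm weight of the replica that is running; on the two-GPU setup described here it would be expected to fail on the cuda:1 replica.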


Solution

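  • If anyone is seeing this in the future: the issue was that the graph data was never sent to the correct GPU when the model's forward functions were called; data = data.cuda() inside the DataLoader loop only ever sent it to the default GPU (cuda:0). nn.DataParallel does not split a PyG batch object; it hands the same object to every replica, so the replica on cuda:1 received tensors on cuda:0. The solution was to call data = data.cuda() inside the model's forward function: DataParallel runs each replica's forward with that replica's GPU set as the current CUDA device, so .cuda() there sends the data to the correct GPU. Each replica then processes the whole batch rather than a share of it.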
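A minimal sketch of the fix, written so it also runs without GPUs: the forward pass moves the batch to the device of the replica it is running in. It reads that device from a specific weight tensor rather than from next(self.parameters()), because in recent PyTorch versions DataParallel replicas do not register their parameters, so self.parameters() can be empty inside a replica's forward. GraphBatch and DeviceFollowingModel are hypothetical stand-ins; GraphBatch.to returns a new object rather than modifying the shared one, since every replica receives the same batch object. If the goal is to actually divide each batch across GPUs, PyG also provides torch_geometric.nn.DataParallel, used together with DataListLoader, which distributes a list of graphs across the devices.

```python
from dataclasses import dataclass, replace

import torch
from torch import nn


@dataclass
class GraphBatch:
    """Hypothetical stand-in for a PyG batch with two tensor fields."""

    x_atm: torch.Tensor
    edge_index_G: torch.Tensor

    def to(self, device):
        # Build a moved copy instead of mutating self: under nn.DataParallel
        # every replica holds a reference to this same object.
        return replace(self, x_atm=self.x_atm.to(device), edge_index_G=self.edge_index_G.to(device))


class DeviceFollowingModel(nn.Module):
    def __init__(self, num_species, dim):
        super().__init__()
        self.embed_atm = nn.Linear(num_species, dim)

    def forward(self, data):
        # The replica's weights live on the replica's GPU (or on the CPU when no
        # GPU is used), so moving the batch there avoids the cross-device error.
        device = self.embed_atm.weight.device
        data = data.to(device)
        return self.embed_atm(data.x_atm)
```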