pytorch

Best practice to pass PyTorch device name to model


Currently, I have separated train.py from model.py in my deep learning project.

The datasets are sent to the CUDA device inside the epoch for loop, like below.

train.py

...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)

However, my model (in model.py) also needs access to the device variable for a skip-connection-like step, to create a new copy of the hidden units (for the residual connection):

model.py

import copy

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device)  # <--
        (some ops for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)

In this case, I am currently passing device as a parameter; however, I believe this is not best practice.

  1. Where is the best place to send the dataset to CUDA (i.e., what is best practice)?
  2. When multiple scripts need to access device, how should I handle this? (As a parameter? A global variable?)

Solution

  • You can add a new attribute to MyNet to store the device info and use it in the skip_conn initialization.

    class MyNet(nn.Module):
        def __init__(self, in_feats, hid_feats, out_feats, device):  # <-- accept the device
            super(MyNet, self).__init__()
            self.conv1 = GCNConv(in_feats, hid_feats)
            self.device = device  # <-- keep it as an attribute
            self.to(self.device)  # <-- move the model's parameters to the device here
            ...

        def forward(self, data):
            x, edge_index = data.x, data.edge_index
            x1 = copy.copy(x.float())
            x = self.conv1(x, edge_index)
            skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)  # <-- allocate on the stored device
            (some ops for x1 -> skip_conn)
            x = torch.cat((x, skip_conn), 1)

    Notice that in this example, MyNet is responsible for all the device logic, including the .to(device) call. This way, all model-related device management is encapsulated in the model class itself.
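
    For completeness, here is a minimal usage sketch (not part of the original answer) of how train.py might construct the modified model. The feature sizes and train_loader below are placeholders standing in for the question's own values:

    # train.py -- hypothetical usage of the modified MyNet
    import torch

    from model import MyNet  # assuming MyNet lives in model.py

    in_feats, hid_feats, out_feats = 16, 64, 32  # placeholder sizes

    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

    # MyNet receives the device and moves itself there in __init__,
    # so a separate model.to(device) call is no longer needed.
    model = MyNet(in_feats, hid_feats, out_feats, device)

    for batch_data in train_loader:  # train_loader as defined in the question
        s0 = batch_data[0].to(device)
        s1 = batch_data[1].to(device)
        pred = model(s0, s1)

    An alternative that avoids storing the device on the model at all is to allocate skip_conn with device=x.device inside forward; the new tensor then always follows whatever device the input already lives on.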