PyTorch: Running Inference on multiple GPUs


I have a model that accepts two inputs. I want to run inference on multiple GPUs where one input is fixed while the other changes. Say I use n GPUs, each holding a copy of the model: the first GPU processes the input pair (a_1, b), the second processes (a_2, b), and so on. All outputs are saved as files, so I don’t need to join the outputs. How can I do this with DDP or otherwise?


Solution

  • I have figured out how to do this using torch.multiprocessing.Queue:

    import torch
    import torch.multiprocessing as mp
    from absl import app, flags
    from torchvision.models import AlexNet
    
    FLAGS = flags.FLAGS
    
    flags.DEFINE_integer("num_processes", 2, "Number of subprocesses to use")
    
    
    def infer(rank, queue):
        """Run inference in one subprocess; `rank` selects the GPU this process uses."""
        device = torch.device(f"cuda:{rank}")
        model = AlexNet().to(device)
        model.eval()  # inference mode: disables dropout
        with torch.no_grad():  # no gradients needed, saves memory
            while True:
                a, b = queue.get()
                if a is None:  # sentinel value: no more work
                    break
                x = (a + b).to(device)
                output = model(x)  # write `output` to a file here if needed
                del a, b, x  # drop references so the memory can be reclaimed
                print(f"Inference on process {rank}")
    
    
    def main(argv):
        queue = mp.Queue()
        processes = []
        for rank in range(FLAGS.num_processes):
            p = mp.Process(target=infer, args=(rank, queue))
            p.start()
            processes.append(p)
        for _ in range(10):
            a_1 = torch.randn(1, 3, 224, 224)
            a_2 = torch.randn(1, 3, 224, 224)
            b = torch.randn(1, 3, 224, 224)
            queue.put((a_1, b))
            queue.put((a_2, b))
        for _ in range(FLAGS.num_processes):
            queue.put((None, None))  # sentinel value to signal subprocesses to exit
        for p in processes:
            p.join()  # wait for all subprocesses to finish
    
    
    if __name__ == "__main__":
        app.run(main)
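
If all the a_i are known up front, a static split is a possible alternative to the queue: each process takes every n-th input and writes one file per input. The sketch below shows only the index bookkeeping, in plain Python. The helper names `shard_indices` and `output_path` and the file naming scheme are my own assumptions for illustration; they are not part of PyTorch or of the code above.

```python
from pathlib import Path


def shard_indices(num_items, world_size, rank):
    """Return the input indices that process `rank` handles (round-robin split)."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    return list(range(rank, num_items, world_size))


def output_path(out_dir, index):
    """Hypothetical naming scheme: one output file per input index."""
    return Path(out_dir) / f"output_{index:05d}.pt"


if __name__ == "__main__":
    # With 10 inputs and 3 processes, every index goes to exactly one rank.
    for rank in range(3):
        print(rank, shard_indices(10, 3, rank))
```

Inside `infer`, each process could then loop over `shard_indices(len(inputs), num_processes, rank)` and save each result with something like `torch.save(output.cpu(), output_path(out_dir, i))`. The trade-off is that a static split does not rebalance work when some inputs take longer than others, which the shared queue handles automatically.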