
PyTorch Lightning places model inputs and the model on different devices


I'm using PyTorch Lightning 2.4.0. In the following code snippet, lmm is a class that inherits from nn.Module and wraps a Hugging Face model and processor.

import torch
import pytorch_lightning as pl


class ICVModel(pl.LightningModule):
    def __init__(self, lmm, icv_encoder: torch.nn.Module) -> None:
        super().__init__()
        self.lmm = lmm
        self.lmm.requires_grad_(False)  # freeze the wrapped LMM
        self.icv_encoder = icv_encoder
        self.eos_token = self.lmm.processor.tokenizer.eos_token

    def forward(self, ice_texts, query_texts, answers, images):
        query_answer = [
            query + answer + self.eos_token
            for query, answer in zip(query_texts, answers)
        ]
        # `setting` is a project-level config object
        query_images = [img[-setting.num_image_in_query :] for img in images]
        query_inputs = self.lmm.process_input(query_answer, query_images)
        query_outputs = self.lmm.model(
            **query_inputs,
            labels=query_inputs["input_ids"],
        )

However, a device mismatch error is raised at

query_outputs = self.lmm.model(
    **query_inputs,
    labels=query_inputs["input_ids"],
)

I printed inputs.pixel_values.device, self.device, and self.lmm.device outside of lmm.model.forward, and got:

rank[0]: cpu cuda:0 cuda:0
rank[1]: cpu cuda:1 cuda:1

Inside the forward pass of Idefics (self.lmm.model), printing inputs.pixel_values.device and self.device gave:

rank[0]: cuda:0 cuda:0
rank[1]: cuda:0 cuda:1

I also tried moving pixel_values to the correct device manually, but it was still moved to the wrong device later in the forward pass. A sketch of that attempt is shown below.
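This is roughly what the manual move looked like (a sketch; the exact keys in query_inputs depend on the processor output, and if process_input returns a BatchFeature you could also call query_inputs.to(self.device) directly):

query_inputs = self.lmm.process_input(query_answer, query_images)
# Move every tensor in the processor output onto this LightningModule's device.
query_inputs = {
    k: v.to(self.device) if torch.is_tensor(v) else v
    for k, v in query_inputs.items()
}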

Error message:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)
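
Since the error points at a convolution weight, one way to see which submodule first receives tensors on the wrong device is a forward pre-hook on the Conv2d layers (a debugging sketch, not part of my original code):

def log_conv_devices(module, args, kwargs):
    # Print the device of the conv weight and of every incoming tensor.
    input_devices = [
        t.device for t in list(args) + list(kwargs.values()) if torch.is_tensor(t)
    ]
    print(module.__class__.__name__, "weight:", module.weight.device, "inputs:", input_devices)

for m in self.lmm.model.modules():
    if isinstance(m, torch.nn.Conv2d):
        m.register_forward_pre_hook(log_conv_devices, with_kwargs=True)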

Solution

  • I've solved this problem.

    The key to the problem was something I did not show in the question, because at the time I did not realize that the bitsandbytes and accelerate libraries automatically register a pre-forward hook.

    They register an AlignDevicesHook (I believe) on each module's forward method, which conflicts with PyTorch Lightning's device management. When I removed bitsandbytes, everything worked fine. A sketch of how to check for these hooks follows below.
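
    For reference, this is a sketch of how one can check whether accelerate attached such hooks and strip them (it uses accelerate's public hooks module; removing the hooks may not be enough if the weights were offloaded, which is why I ultimately just dropped bitsandbytes):

    from accelerate.hooks import AlignDevicesHook, remove_hook_from_module

    # List submodules that accelerate wrapped with a device-alignment hook
    # (accelerate stores the hook on the module as `_hf_hook`).
    hooked = [
        name
        for name, m in lmm.model.named_modules()
        if isinstance(getattr(m, "_hf_hook", None), AlignDevicesHook)
    ]
    print(hooked)

    # Strip the hooks so PyTorch Lightning keeps control of device placement.
    remove_hook_from_module(lmm.model, recurse=True)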