Tags: pytorch, google-cloud-dataflow, apache-beam

How to use RunInference with Beam and a custom pytorch class/model?


I am using Dataflow on GCP with the latest version, apache-beam[gcp]==2.44.0. My ML model is a custom model class built with PyTorch. The model needs to be loaded in the following way:

import torch

# model_architecture and CustomModel are my own modules/classes;
# path is a pathlib.Path pointing at the model directory
input_model = path / "model.ckpt"
checkpoint = torch.load(input_model, map_location=torch.device("cpu"))
model_name = checkpoint["parameters"]["model_name"]
n_classes = len(checkpoint["parameters"]["class_names"])
backbone_fn = model_architecture.get_backbone(model_name)
backbone = backbone_fn(num_classes=n_classes)

model = CustomModel.load_from_checkpoint(
    checkpoint_path=input_model, backbone=backbone
)

I am trying to follow the recent RunInference documentation:

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

with pipeline as p:
    (
        p
        | "ReadInputData" >> beam.Create(value_to_predict)
        | "RunInferenceTorch" >> RunInference(torch_model_handler)
        | beam.Map(print)
    )

I tried to use PytorchModelHandlerTensor in a way that works with my custom model, but it doesn't seem to work.

from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

torch_model_handler = PytorchModelHandlerTensor(
    state_dict_path=None,
    model_class=CustomModel.load_from_checkpoint,
    model_params={
        "backbone": backbone,
        "checkpoint_path": input_model,
    },
)

Could it be that my custom model doesn't meet the requirements for PytorchModelHandlerTensor? Or did I miss something obvious?

Edit: I found in apache_beam/ml/inference/pytorch_inference.py how it is done:

model.load_state_dict(state_dict)
model.to(device)
model.eval()

which is not compatible with my model's class.

I will try the other approach, for custom models from unsupported frameworks (like spaCy); a rough sketch of what that would look like is below.
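For reference, that approach would subclass Beam's generic ModelHandler and implement load_model and run_inference. A minimal sketch, reusing CustomModel, input_model, and backbone from above (the stacking step in run_inference is an assumption about my inputs being same-shaped tensors, not something Beam requires):

from apache_beam.ml.inference.base import ModelHandler, PredictionResult

class CheckpointModelHandler(ModelHandler):
    """Hypothetical handler that loads the model via load_from_checkpoint
    instead of a state_dict."""

    def load_model(self):
        model = CustomModel.load_from_checkpoint(
            checkpoint_path=input_model, backbone=backbone
        )
        model.eval()
        return model

    def run_inference(self, batch, model, inference_args=None):
        # batch is a list of input tensors of the same shape
        with torch.no_grad():
            predictions = model(torch.stack(batch))
        return [PredictionResult(x, y) for x, y in zip(batch, predictions)]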

Solution (with the right format):

I need to load the model first, save its state_dict, and then it works:

# Save the weights of the already-loaded model as a state_dict
torch.save(model.state_dict(), './model_torch.pt')

torch_model_handler = PytorchModelHandlerTensor(
    state_dict_path='./model_torch.pt',
    model_class=CustomModel,
    model_params={
        "backbone": backbone,
        "criterion": criterion,
        "class_names": class_names,
        "expected_input_size": expected_input_size,
    },
)
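This works because, from my reading of pytorch_inference.py, the handler first instantiates model_class(**model_params) and only then applies the state_dict, so CustomModel's __init__ has to accept exactly those keyword arguments. Conceptually (a simplified sketch, not Beam's actual code):

# Roughly what the handler does under the hood
model = CustomModel(
    backbone=backbone,
    criterion=criterion,
    class_names=class_names,
    expected_input_size=expected_input_size,
)  # i.e. model_class(**model_params)
model.load_state_dict(torch.load('./model_torch.pt'))
model.eval()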

Solution

  • Right now, Beam's PyTorch model handlers only accept a state_dict representation; it looks like you're using a checkpoint. Checkpoints aren't currently accepted because there can be problems loading checkpointed models with different versions of PyTorch, and they tend to be a less stable representation than a class + state_dict. state_dicts are also PyTorch's recommended format for saved models: https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended

    To convert your model from a checkpoint to a state_dict, you should be able to do something like:

    # This assumes the .ckpt was saved with torch.save(model), so that
    # torch.load returns the model object rather than a raw dict
    model = torch.load("path/to/checkpoint.ckpt")
    torch.save(model.state_dict(), "path/to/state_dict.pth")
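
    If torch.load returns a raw checkpoint dict instead (as with the Lightning-style checkpoint in your question), rebuild the model first and then save its state_dict, which is the same pattern as your solution above:

    # Sketch for checkpoints that load as a dict rather than a model object
    model = CustomModel.load_from_checkpoint(
        checkpoint_path="path/to/checkpoint.ckpt", backbone=backbone
    )
    torch.save(model.state_dict(), "path/to/state_dict.pth")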