I am using Dataflow on GCP with the latest version, apache-beam[gcp]==2.44.0.
I have a custom model class built with PyTorch for my ML model. The model needs to be loaded in the following way:
input_model = path / "model.ckpt"
checkpoint = torch.load(input_model, map_location=torch.device("cpu"))
model_name = checkpoint["parameters"]["model_name"]
n_classes = len(checkpoint["parameters"]["class_names"])
backbone_fn = model_architecture.get_backbone(model_name)
backbone = backbone_fn(num_classes=n_classes)
model = CustomModel.load_from_checkpoint(
    checkpoint_path=input_model, backbone=backbone
)
I am trying to follow the recent RunInference documentation:
with pipeline as p:
    (
        p
        | "ReadInputData" >> beam.Create(value_to_predict)
        | "RunInferenceTorch" >> RunInference(torch_model_handler)
        | beam.Map(print)
    )
I tried to use PytorchModelHandlerTensor in a way that works with my custom model, but it doesn't seem to work:
torch_model_handler = PytorchModelHandlerTensor(
    state_dict_path=None,
    model_class=CustomModel.load_from_checkpoint,
    model_params={
        "backbone": backbone,
        "checkpoint_path": input_model,
    },
)
Could it be that my custom model doesn't meet the requirements for PytorchModelHandlerTensor? Or did I miss something obvious?
Edit: I found in apache_beam/ml/inference/pytorch_inference.py how the loading is done:
model.load_state_dict(state_dict)
model.to(device)
model.eval()
which is not compatible with my model's class.
I will try the approach for custom models from unsupported frameworks, like the spaCy example.
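To illustrate the incompatibility: the handler builds the model by calling model_class(**model_params) and then loads the state dict into the resulting instance, so model_class must be a plain constructor, not a loader like load_from_checkpoint. A minimal sketch of that loading sequence, using a toy torch.nn.Module as a stand-in for CustomModel:

```python
import torch


class ToyModel(torch.nn.Module):
    """Stand-in for CustomModel: an ordinary nn.Module constructor."""

    def __init__(self, n_features=4):
        super().__init__()
        self.linear = torch.nn.Linear(n_features, 2)

    def forward(self, x):
        return self.linear(x)


# Save a state dict, which is the representation the handler expects.
torch.save(ToyModel().state_dict(), "toy_state_dict.pt")

# Roughly the sequence PytorchModelHandlerTensor performs internally:
model_class, model_params = ToyModel, {"n_features": 4}
model = model_class(**model_params)  # must be constructible this way
state_dict = torch.load("toy_state_dict.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.eval()
```

Passing CustomModel.load_from_checkpoint as model_class fails this pattern, because the handler would call it with model_params as constructor arguments and then still try load_state_dict on the result.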
Solution (properly formatted):
I need to load the model first, save its state_dict, and then it works:
torch.save(model.state_dict(), './model_torch.pt')

torch_model_handler = PytorchModelHandlerTensor(
    state_dict_path='./model_torch.pt',
    model_class=CustomModel,
    model_params={
        "backbone": backbone,
        "criterion": criterion,
        "class_names": class_names,
        "expected_input_size": expected_input_size,
    },
)
Right now, Beam's PyTorch model handlers only accept a state_dict representation, and it looks like you're using a checkpoint. Checkpoints aren't currently accepted because there can be problems loading checkpointed models across different versions of PyTorch, and they tend to be a less stable representation than a class + state_dict. state_dicts are also PyTorch's recommended format for saved models: https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended
To convert your model from a checkpoint to a state_dict, you should be able to do something like:
model = torch.load("path/to/checkpoint.ckpt")
torch.save(model.state_dict(), "path/to/state_dict.pth")
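Note that this snippet assumes torch.load returns a model object. For a Lightning-style checkpoint like the one in the question, torch.load returns a dict with no .state_dict() method; assuming the weights are nested under a "state_dict" key (as Lightning checkpoints do), you can extract and re-save them directly. A sketch with a toy module simulating such a checkpoint:

```python
import torch


class ToyModel(torch.nn.Module):
    """Toy stand-in for the custom model."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)


# Simulate a checkpoint that is a dict (as in the question), not a pickled model.
ckpt = {
    "state_dict": ToyModel().state_dict(),
    "parameters": {"model_name": "toy"},
}
torch.save(ckpt, "checkpoint.ckpt")

# torch.load on such a checkpoint yields a dict, so .state_dict() would fail;
# pull the nested state_dict out and save it on its own instead.
checkpoint = torch.load("checkpoint.ckpt", map_location=torch.device("cpu"))
torch.save(checkpoint["state_dict"], "state_dict.pth")

# Verify the saved state_dict loads cleanly into a fresh model instance.
model = ToyModel()
model.load_state_dict(torch.load("state_dict.pth", map_location=torch.device("cpu")))
```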