I recently came across CrabNet and am using it in a catalysis project. It has been a useful addition, but I've run into a problem.
I want to train CrabNet on different subsets of my dataset, filtered by reaction conditions, to see whether that gives better results. The same composition sometimes appears under different reaction conditions, and those entries have different target values. The problem is that when I create two CrabNet instances in the same script, some state seems to be shared between them.
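For context, here is a minimal sketch of the kind of per-subset training I mean. The `temperature` column and the helper names `split_by_condition` and `fit_per_subset` are hypothetical. It assumes the data uses a `formula`/`target` column layout (which I believe matches CrabNet's example data) and only calls the `CrabNet(mat_prop=...)` constructor and `fit` method used in the example in this post. On PyTorch 2.0, the loop in `fit_per_subset` would hit the error described here as soon as it trains its second model.

```python
"""Train one CrabNet model per reaction-condition subset (sketch)."""
import pandas as pd


def split_by_condition(df: pd.DataFrame, condition_col: str) -> dict:
    """Group rows by a reaction-condition column.

    Each subset keeps only the "formula" and "target" columns
    (assumed to be the layout CrabNet expects).
    """
    subsets = {}
    for value, group in df.groupby(condition_col):
        subsets[value] = group[["formula", "target"]].reset_index(drop=True)
    return subsets


def fit_per_subset(subsets: dict) -> dict:
    """Fit a separate CrabNet model on each subset (requires crabnet)."""
    # Imported lazily so split_by_condition works without crabnet installed.
    from crabnet.crabnet_ import CrabNet

    models = {}
    for value, sub_df in subsets.items():
        # mat_prop is just a label here; one model per condition value.
        cb = CrabNet(mat_prop=f"subset_{value}")
        cb.fit(sub_df)
        models[value] = cb
    return models


if __name__ == "__main__":
    # Toy data; "temperature" is a hypothetical reaction-condition column.
    df = pd.DataFrame(
        {
            "formula": ["Pt", "Pt", "Pd", "Ni"],
            "temperature": [300, 400, 300, 400],
            "target": [1.0, 1.5, 0.8, 0.6],
        }
    )
    subsets = split_by_condition(df, "temperature")
    for value, sub_df in subsets.items():
        print(value, len(sub_df))
```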
Here is a minimal reproducible example:
```python
"""Basic usage of CrabNet regression on elasticity dataset."""
from crabnet.utils.data import get_data
from crabnet.data.materials_data import elasticity, example_materials_property
from crabnet.crabnet_ import CrabNet

# Load two small (dummy) datasets
train_df, val_df = get_data(elasticity, "train.csv", dummy=True)
train_df_2, val_df_2 = get_data(example_materials_property, "train.csv", dummy=True)

# First model: trains and predicts without problems
cb = CrabNet(mat_prop="elasticity")
cb.fit(train_df)
val_pred, val_sigma = cb.predict(val_df, return_uncertainty=True)

# Second model: the fit call below raises the error
cbn = CrabNet(mat_prop="example_materials_property")
cbn.fit(train_df_2)
val_pred_2, val_sigma_2 = cbn.predict(val_df_2, return_uncertainty=True)
```
This is just the Basic Usage example from CrabNet's docs with each step repeated, using a different dataset for the second model just in case.
Running it fails with this error:

```
File ".conda\lib\site-packages\torch\optim\optimizer.py", line 271, in wrapper
    for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
AttributeError: 'SWA' object has no attribute '_optimizer_step_pre_hooks'
```

The full traceback points to the second model's `cbn.fit(train_df_2)` call as the trigger.
How can I instantiate and train more than one model in the same script?
I've been in touch with the CrabNet developers on GitHub, and they have located the problem: some of the PyTorch methods CrabNet relies on changed in PyTorch 2.0 (the traceback above comes from `torch.optim`'s optimizer code), so CrabNet does not work properly with that version.
There is an issue open to fix it, but it doesn't look like it will be worked on soon, so for now the only fix is to downgrade PyTorch to a version earlier than 2.0.
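Until that issue is fixed, a small guard can stop a training script at startup instead of failing partway through the second `fit`. This is only a sketch: the function name `is_torch_pre_2` is hypothetical, it reads the installed version with the standard-library `importlib.metadata.version("torch")`, and it compares only the major version number. The downgrade itself is an ordinary pip requirement such as `pip install "torch<2.0"`.

```python
"""Refuse to train CrabNet on PyTorch >= 2.0 (sketch)."""
from importlib.metadata import PackageNotFoundError, version


def is_torch_pre_2(torch_version: str) -> bool:
    """Return True if a PyTorch version string is older than 2.0.

    Local suffixes such as "1.13.1+cu117" are handled by dropping
    everything after "+" and looking only at the major number.
    """
    major = torch_version.split("+", 1)[0].split(".", 1)[0]
    return int(major) < 2


if __name__ == "__main__":
    try:
        installed = version("torch")
    except PackageNotFoundError:
        raise SystemExit("PyTorch is not installed")
    if not is_torch_pre_2(installed):
        # Current workaround: downgrade, e.g.  pip install "torch<2.0"
        raise SystemExit(f"torch {installed} found; CrabNet needs torch < 2.0 for now")
    print(f"torch {installed} should work with CrabNet")
```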