I am trying to implement TransUNet for the BreakHis dataset. I was creating the optimizer like this:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
My model is:
from transunet import TransUNet
model = TransUNet(image_size=224, pretrain=True)
But the parameters() function does not work with TransUNet.
This is the library I am using: https://github.com/awsaf49/TransUNet-tf
I tried using named_parameters(), __dict__['parameters'], and state_dict(), but none of them work.
The TransUNet model from the transunet library is implemented in TensorFlow, while .parameters() is a PyTorch method, so it is not available on this model.
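In Keras the optimizer is not handed a parameter list when it is constructed, so if you do not need a custom training loop, the simplest route is to let compile() and fit() manage the variables. The sketch below assumes that TransUNet(image_size=224, pretrain=True) returns an ordinary tf.keras.Model. Because the transunet package may not be installed, it uses a tiny hypothetical stand-in model and random placeholder data (not BreakHis). The loss choice, binary_crossentropy with a sigmoid output, is also an assumption for a binary segmentation mask.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for TransUNet(image_size=224, pretrain=True):
# a tiny Keras segmentation model so the sketch runs without the transunet package.
def build_stand_in_model(image_size=32):
    inputs = tf.keras.Input(shape=(image_size, image_size, 3))
    x = tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu")(inputs)
    outputs = tf.keras.layers.Conv2D(1, 1, activation="sigmoid")(x)
    return tf.keras.Model(inputs, outputs)

model = build_stand_in_model()

# No parameter list is passed to the optimizer: compile() attaches it to the
# model, and fit() updates model.trainable_variables during training.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss="binary_crossentropy",
)

# Random placeholder data (not BreakHis): 4 RGB images with binary masks.
x = np.random.rand(4, 32, 32, 3).astype("float32")
y = (np.random.rand(4, 32, 32, 1) > 0.5).astype("float32")

history = model.fit(x, y, batch_size=2, epochs=1, verbose=0)

# trainable_variables is the closest TensorFlow counterpart of model.parameters().
print(len(model.trainable_variables))
```

With the real model you would replace build_stand_in_model() with the TransUNet(...) call and feed your own image/mask arrays or a tf.data.Dataset.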
For TensorFlow, you can use something like this:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
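The fragment above assumes grads already exists. In a custom training loop those gradients usually come from tf.GradientTape, as in the minimal sketch below. It again uses a tiny hypothetical stand-in for the TransUNet model and random placeholder data, and the BinaryCrossentropy loss is an assumed choice for binary masks. The learning rate is set to 1e-4 to match the question's PyTorch setting; 0.01 from the fragment would also run.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the TransUNet model (a tiny Keras model),
# used only so the sketch runs without the transunet package.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu")(inputs)
outputs = tf.keras.layers.Conv2D(1, 1, activation="sigmoid")(x)
model = tf.keras.Model(inputs, outputs)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_fn = tf.keras.losses.BinaryCrossentropy()

def train_step(images, masks):
    # Record the forward pass so gradients can be computed afterwards.
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(masks, predictions)
    # Gradients with respect to trainable_variables, the TensorFlow
    # counterpart of PyTorch's model.parameters().
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Random placeholder batch (not BreakHis data): 2 images with binary masks.
images = np.random.rand(2, 32, 32, 3).astype("float32")
masks = (np.random.rand(2, 32, 32, 1) > 0.5).astype("float32")

# Snapshot the weights, run one step, and check that they were updated.
before = [v.numpy().copy() for v in model.trainable_variables]
loss = train_step(images, masks)
changed = any(
    not np.allclose(b, v.numpy()) for b, v in zip(before, model.trainable_variables)
)
print(float(loss), changed)
```

In practice you would call train_step once per batch inside an epoch loop; decorating it with @tf.function usually makes it faster once it works eagerly.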