python tensorflow machine-learning deep-learning unet-neural-network

model.parameters() alternative for TransUNet from the transunet Python library


I am trying to implement TransUNet for the BreakHis dataset. I was creating the optimizer like this:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

My model is:

from transunet import TransUNet
model = TransUNet(image_size=224, pretrain=True)

But the parameters() method does not work with TransUNet.

This is the library I am using: https://github.com/awsaf49/TransUNet-tf

I tried named_parameters(), __dict__['parameters'] and state_dict(), but none of them work.


Solution

  • The TransUNet model from the transunet library is implemented in TensorFlow, while .parameters() is a PyTorch (torch.nn.Module) method, so it does not exist on this model.

    In TensorFlow, pass the model's trainable_variables to the optimizer instead:

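    import tensorflow as tf
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)  # same lr as the PyTorch version in the question
    # grads: gradients of the loss w.r.t. model.trainable_variables, e.g. computed with tf.GradientTape
    optimizer.apply_gradients(zip(grads, model.trainable_variables))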
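A minimal end-to-end sketch, assuming TensorFlow 2.x with tf.keras. To stay self-contained it swaps in a tiny hypothetical Keras model (build_stand_in_model) for TransUNet(image_size=224, pretrain=True), which downloads pretrained weights. Since the answer says TransUNet is a TensorFlow model, the same attributes should apply to it, but that is not verified against the library here. The sketch shows Keras counterparts of what the question tried (parameters(), named_parameters(), state_dict()) and where grads comes from, using tf.GradientTape. The loss (BinaryCrossentropy), input shape and random data are assumptions, not BreakHis specifics.

```python
import numpy as np
import tensorflow as tf


def build_stand_in_model() -> tf.keras.Model:
    """Hypothetical tiny segmentation-style Keras model standing in for
    TransUNet(image_size=224, pretrain=True)."""
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu")(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    outputs = tf.keras.layers.Conv2D(1, 1, activation="sigmoid")(x)
    return tf.keras.Model(inputs, outputs)


model = build_stand_in_model()

# PyTorch model.parameters()       -> Keras model.trainable_variables
params = model.trainable_variables
# PyTorch model.named_parameters() -> pair each variable with its name
named_params = [(v.name, v) for v in model.trainable_variables]
# PyTorch model.state_dict()       -> model.get_weights(): NumPy arrays that also
# include non-trainable weights such as BatchNorm moving statistics
state = model.get_weights()

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_fn = tf.keras.losses.BinaryCrossentropy()


def train_step(images, masks):
    """One optimization step; this is where `grads` in the answer comes from."""
    with tf.GradientTape() as tape:
        preds = model(images, training=True)
        loss = loss_fn(masks, preds)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return float(loss)


# Random dummy batch (not real data): 4 RGB images and binary masks.
rng = np.random.default_rng(0)
images = rng.random((4, 32, 32, 3), dtype=np.float32)
masks = (rng.random((4, 32, 32, 1)) > 0.5).astype(np.float32)

# Snapshot the weights, train a few steps, then check that they moved.
before = [v.numpy().copy() for v in model.trainable_variables]
losses = [train_step(images, masks) for _ in range(3)]
changed = any(
    not np.allclose(b, v.numpy()) for b, v in zip(before, model.trainable_variables)
)
print(len(params), len(state), losses, changed)
```

If a custom loop is not needed, the usual Keras route is model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss=...) followed by model.fit(...); Keras then collects the trainable variables itself, so no .parameters() equivalent is required. Wrapping train_step in @tf.function usually speeds it up; it is left eager here for simplicity.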