The errors I get are:
ZipFile requires mode 'r', 'w', 'x', or 'a'
or
ZipFile.__init__() got multiple values for argument 'mode'
or
TypeError: 'ZipFile' object is not callable
My Code:
@tf.function
def train_step(batch):
# Record all of our operations
with tf.GradientTape() as tape:
# Get anchor and positive/negative image
X = batch[:2]
# Get label
y = batch[2]
# Forward pass
yhat = siamese_model(X, training=True)
# Calculate loss
loss = binary_cross_loss(y, yhat)
print(loss)
# Calculate gradients
grad = tape.gradient(loss, [siamese_model.trainable_variables])
# Calculate updated weights and apply to siamese model
opt.apply_gradients(zipfile.ZipFile(grad, [siamese_model.trainable_variables], mode
= "w"))
# Return loss
return loss
def train(data, EPOCHS):
# Loop through epochs
for epoch in range(1, EPOCHS+1):
print('\n Epoch {}/{}'.format(epoch, EPOCHS))
progbar = tf.keras.utils.Progbar(len(data))
# Creating a metric object
r = Recall()
p = Precision()
# Loop through each batch
for idx, batch in enumerate(data):
# Run train step here
loss = train_step(batch)
yhat = siamese_model.predict(batch[:2])
r.update_state(batch[2], yhat)
p.update_state(batch[2], yhat)
progbar.update(idx+1)
print(loss.numpy(), r.result().numpy(), p.result().numpy())
# Save checkpoints
if epoch % 10 == 0:
checkpoint.save(file_prefix=checkpoint_prefix)
EPOCHS = 50
train(train_data, EPOCHS)
Errors show that the problem is coming from the train_step(batch) function when train invokes it.
I have also tried without mode = "w"
, using just 'zip' when the brackets open and without the brackets for siamese_model.trainable_variables
and I have tried
import zipfile
or import zip file from ZipFile
but it did not make a difference.
It seems like you're confusing the built-in zip function with the zipfile module. The first is to combine multiple iterables and the second creates and reads compressed archives.
You probably want to use this:
# Calculate updated weights and apply to siamese model
opt.apply_gradients(zip(grad, [siamese_model.trainable_variables]))