I am trying to maximize the activation of a single filter in a pretrained model (VGG16), using the mean of that filter's output (`tf.reduce_mean`) as the score. I keep getting the error "No gradients provided for any variable".
I would really appreciate any help. Thanks!
Here you can see the code:
# Question code: reproduces "No gradients provided for any variable".
import numpy as np
import tensorflow as tf
from tensorflow import keras
np.random.seed(1)
# Random starting image of shape [1, 32, 32, 3] with very small values.
image_f = np.random.normal(size=[1, 32, 32, 3], scale=0.01).astype(np.float32)
# Squash into (0, 1); because the noise is tiny, pixels start close to 0.5.
img = tf.nn.sigmoid(image_f)
tf.compat.v1.keras.backend.set_image_data_format('channels_last')
# VGG16 with ImageNet weights, without the classifier head (include_top=False).
model = keras.applications.VGG16(weights="imagenet", include_top=False)
optimizer = tf.keras.optimizers.Adam(epsilon=1e-08, learning_rate=0.05)
# Sub-model whose output is the activation map of layer "block3_conv1".
layer_weight =keras.Model(inputs=model.inputs, outputs=model.get_layer(name="block3_conv1").output)
for i in range(5):
    # NOTE(review): a new tf.Variable is created on every iteration; creating it
    # once, before the loop, is sufficient.
    img = tf.Variable(img)
    # BUG: the forward pass runs here, outside the loss callable, so the tape
    # that optimizer.minimize records around compute_activation never sees the
    # operations connecting img to filter_activation (channel index 5 of the
    # last, channels_last axis).
    filter_activation = layer_weight(img)[:,:,:,5]
    def compute_activation():
        # Negated so that minimizing the loss maximizes the mean activation.
        score = -1 * tf.reduce_mean(filter_activation)
        print(score)
        return score
    # Fails here: inside the recorded callable, score only reads the
    # precomputed filter_activation, so the gradient w.r.t. img is None.
    optimizer.minimize(compute_activation, [img])
    print(img)
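The problem is that your variable `img` is not part of the computation inside your loss function. `filter_activation` is computed before `compute_activation` is defined. When `optimizer.minimize` calls `compute_activation` under its gradient tape, the only operation recorded is the mean of an already-computed tensor. No recorded path leads back to `img`, so every gradient is `None`. The fix is to run the forward pass `layer_weight(img)` inside the loss function, and to create the `tf.Variable` only once, before the loop. I modified your code according to the documentation: https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer.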
import numpy as np
import tensorflow as tf
from tensorflow import keras

np.random.seed(1)

# Random starting image of shape [1, 32, 32, 3], squashed into (0, 1).
noise = np.random.normal(size=[1, 32, 32, 3], scale=0.01).astype(np.float32)
initial_image = tf.nn.sigmoid(noise)

tf.compat.v1.keras.backend.set_image_data_format('channels_last')
model = keras.applications.VGG16(weights="imagenet", include_top=False)
optimizer = tf.keras.optimizers.Adam(epsilon=1e-08, learning_rate=0.05)

# Sub-model exposing the activation map of one intermediate layer.
target_output = model.get_layer(name="block3_conv1").output
layer_weight = keras.Model(inputs=model.inputs, outputs=target_output)

# The image itself is the trainable variable; it only needs to be created once.
img = tf.Variable(initial_image)


def compute_activation():
    """Return the loss: the negated mean activation of channel 5.

    The forward pass through ``layer_weight`` happens inside this callable, so
    the gradient tape recorded by ``optimizer.minimize`` connects the returned
    score to ``img``. Negating the mean makes minimization maximize the
    filter's activation.
    """
    channel_activation = layer_weight(img)[..., 5]
    score = -tf.reduce_mean(channel_activation)
    print(score)
    return score


for _ in range(5):
    optimizer.minimize(compute_activation, [img])
    print(img)