I want to generate attack samples via the following steps:
Find a pre-trained CNN classification model whose input is X and output is P(y|X), where y is the most likely class for X.
I want to input X' and get y_fool, where X' is close to X and y_fool is not equal to y.
The steps for getting X' are given in the following image: [image: the update rule for X']
How can I get the partial derivative described in the image?
Here is my code, but x.grad is None (the model is VGG16):
x = torch.autograd.Variable(image, requires_grad=True)
output = model(image)
prob = nn.functional.softmax(output[0], dim=0)
prob.backward(torch.ones(prob.size()))
print(x.grad)
How should I modify my code? Any help would be greatly appreciated.
Here, the point is to backpropagate a "false" example through the network: in other words, you need to maximize one particular coordinate of your output which does not correspond to the actual label of x. (Incidentally, your snippet prints None because you call model(image) on the raw image instead of model(x), so x never takes part in the computation graph and its .grad is never populated.)
Let's say, for example, that your model outputs N-dimensional vectors, that the label of x should be [1, 0, 0, ...], and that we will try to make the model actually predict [0, 1, 0, 0, ...] (so y_fool has its second coordinate set to 1 instead of the first one).
A quick note on the side: Variable is deprecated, so just set the requires_grad flag to True. You get:
import torch
import torch.nn as nn

# Build a leaf tensor that tracks gradients
# (torch.tensor(image, requires_grad=True) warns when image is already a tensor)
x = image.clone().detach().requires_grad_(True)
output = model(x)
# output has shape [1, N] for a single image, so take output[0] before the softmax.
# If the model is well trained, prob_vector[1] should be almost 0 at the beginning.
prob_vector = nn.functional.softmax(output[0], dim=0)
# We want to fool the model and maximize this coordinate instead of prob_vector[0]
fool_prob = prob_vector[1]
# fool_prob is a scalar tensor, so we can call backward on it directly
fool_prob.backward()
# and you should have your gradients:
print(x.grad)
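Once you have x.grad, one common way to build X' is a single FGSM-style step. This is only a minimal sketch under the assumption that the update in your image is FGSM-like; epsilon is a hypothetical step size you would tune yourself, and the clamp assumes pixel values in [0, 1]:

# Hypothetical single-step update, assuming an FGSM-like rule
epsilon = 0.01
with torch.no_grad():
    # Move x in the direction that increases fool_prob
    x_adv = x + epsilon * x.grad.sign()
    # Optionally keep the adversarial image in a valid pixel range (assumed [0, 1])
    x_adv = torch.clamp(x_adv, 0.0, 1.0)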
After that, if you want to use an optimizer in your loop to modify x, remember that PyTorch's optimizer.step method tries to minimize the loss, whereas you want to maximize it. So either you use a negative learning rate or you flip the sign before backpropagating:
# Maximizing a scalar is minimizing its opposite
(-fool_prob).backward()
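For completeness, here is a minimal sketch of such an optimizer loop. The number of steps, the learning rate, and the clamp to a [0, 1] pixel range are all assumptions, not part of the original answer:

import torch
import torch.nn as nn

# Hypothetical loop: repeatedly nudge x to raise prob_vector[1]
x = image.clone().detach().requires_grad_(True)
optimizer = torch.optim.SGD([x], lr=0.01)
for _ in range(50):
    optimizer.zero_grad()
    output = model(x)
    fool_prob = nn.functional.softmax(output[0], dim=0)[1]
    # Maximizing fool_prob is minimizing its opposite
    (-fool_prob).backward()
    optimizer.step()
    with torch.no_grad():
        x.clamp_(0.0, 1.0)  # keep x a valid image (assumed [0, 1] range)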