Tags: python, tensorflow, keras, cleverhans

Can't use utils_keras.Sequential: the attack still says it's not a CleverHans model


I'm trying to run the Saliency Map Method attack using cleverhans.

My model needs to be a Keras Sequential model, so I searched and found cleverhans.utils_keras, where I understood Sequential to use KerasModelWrapper. But for some reason I still get an error saying the model should be a cleverhans model. Here's the stack trace:

TypeError                                 Traceback (most recent call last)
in
      2 # https://github.com/tensorflow/cleverhans/blob/master/cleverhans/utils_keras.py
      3
----> 4 jsma = SaliencyMapMethod(model, sess=sess)
      5 jsma_params = {'theta': 10.0, 'gamma': 0.15,
      6                'clip_min': 0., 'clip_max': 1.,

c:\users\jeredriq\appdata\local\programs\python\python35\lib\site-packages\cleverhans\attacks\__init__.py in __init__(self, model, sess, dtypestr, **kwargs)
    911         """
    912
--> 913         super(SaliencyMapMethod, self).__init__(model, sess, dtypestr, **kwargs)
    914
    915         self.feedable_kwargs = ('y_target',)

c:\users\jeredriq\appdata\local\programs\python\python35\lib\site-packages\cleverhans\attacks\__init__.py in __init__(self, model, sess, dtypestr, **kwargs)
     55
     56         if not isinstance(model, Model):
---> 57             raise TypeError("The model argument should be an instance of"
     58                             " the cleverhans.model.Model class.")
     59

TypeError: The model argument should be an instance of the cleverhans.model.Model class.

And here's my code


import numpy as np
from keras import backend
import tensorflow as tf
from keras.callbacks import ModelCheckpoint
from matplotlib import gridspec
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from keras.datasets import mnist
from keras.layers import Dense, Dropout
from keras.layers import Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.utils import np_utils
from cleverhans.attacks import FastGradientMethod
from cleverhans.attacks import BasicIterativeMethod
from cleverhans.attacks import SaliencyMapMethod
from cleverhans.attacks import DeepFool

from cleverhans.utils_keras import Sequential


sess =  backend.get_session()
x = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
y = tf.placeholder(tf.float32, shape=(None, 10))
# Managing Mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train/=255
X_test/=255
y_train_cat = np_utils.to_categorical(y_train)
y_test_cat = np_utils.to_categorical(y_test)
num_classes = y_test_cat.shape[1]

### Defining Model ###

model = Sequential()      #  <-----  I use Sequential from CleverHans

model.add(Conv2D(32, (5, 5), input_shape=(28,28,1), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

history = model.fit(X_train, y_train_cat, epochs=10, batch_size=1024, verbose=1, validation_split=0.7)


### And the problem part ###

jsma = SaliencyMapMethod(model, sess=sess)  # <---- Where I get the exception

jsma_params = {'theta': 10.0, 'gamma': 0.15,
               'clip_min': 0., 'clip_max': 1.,
               'y_target': None}

sample_size = 20
one_hot_target = np.zeros((sample_size, 10), dtype=np.float32)
one_hot_target[:, 1] = 1
jsma_params['y_target'] = one_hot_target

X_test_small = X_test[0:sample_size,:]
y_test_small = y_test[0:sample_size]

adv_x = jsma.generate_np(X_test_small, **jsma_params)

I've asked the same question on GitHub too.


Solution

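  • The Sequential imported from cleverhans.utils_keras is still Keras' own Sequential class; the module simply imports it from Keras. The model is therefore not an instance of cleverhans.model.Model, which is what the isinstance check in the attack's constructor requires. A Keras model can be wrapped to provide this interface with the KerasModelWrapper class, which is also defined in cleverhans.utils_keras.

    Replace

    jsma = SaliencyMapMethod(model, sess=sess)

    with

    from cleverhans.utils_keras import KerasModelWrapper

    jsma = SaliencyMapMethod(KerasModelWrapper(model), sess=sess)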
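
To see why wrapping helps, here is a minimal, library-free sketch of the type check from the traceback. Every class in it is a hypothetical stand-in written for illustration, not the real cleverhans or Keras code. The real KerasModelWrapper does more than this, but the reason the attack accepts the wrapped model should be the same: the wrapper subclasses the Model base class and the plain Keras model does not.

```python
# Stand-in classes only: none of these are the real cleverhans or Keras classes.

class Model:
    """Stand-in for cleverhans.model.Model, the base class attacks check for."""


class KerasSequential:
    """Stand-in for keras.models.Sequential; it does not inherit from Model."""

    def predict(self, x):
        # Dummy prediction: one value per input sample.
        return [0.0 for _ in x]


class KerasModelWrapper(Model):
    """Stand-in wrapper: subclasses Model and delegates to the Keras model."""

    def __init__(self, model):
        self.model = model

    def predict(self, x):
        return self.model.predict(x)


class Attack:
    """Stand-in for an attack; mirrors the isinstance check in the traceback."""

    def __init__(self, model):
        if not isinstance(model, Model):
            raise TypeError("The model argument should be an instance of"
                            " the cleverhans.model.Model class.")
        self.model = model


keras_model = KerasSequential()

# Passing the raw Keras-style model fails the type check.
try:
    Attack(keras_model)
    raw_accepted = True
except TypeError:
    raw_accepted = False

# Passing the wrapped model satisfies the check.
wrapped_attack = Attack(KerasModelWrapper(keras_model))

print("raw model accepted:", raw_accepted)
print("wrapped model is a Model:", isinstance(wrapped_attack.model, Model))
```

Importing Sequential from a different module does not change the class of the object, so only the wrapper (or a hand-written subclass of the cleverhans Model class) gets past the check.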