I am trying to benchmark a matmul operation and save it as a model in .pb format.
class GEMMBenchmark(tf.keras.Model):
    """Keras model wrapping a single matrix multiplication of two fixed matrices.

    Args:
        m: Number of rows of the left operand ``A``.
        n: Number of columns of the right operand ``B``.
        k: Shared inner dimension (columns of ``A``, rows of ``B``).
    """

    def __init__(self, m, n, k):
        super().__init__()
        # Both operands are random-normal and frozen: they are benchmark
        # data, not weights to be trained.
        self.A = tf.Variable(tf.random.normal((m, k)), trainable=False)
        self.B = tf.Variable(tf.random.normal((k, n)), trainable=False)

    # NOTE(review): Keras subclasses normally implement `call`, not
    # `__call__`; overriding `__call__` is presumably why saving this model
    # fails -- confirm against the Keras subclassing docs.
    def __call__(self):
        """Return the product ``A @ B``; takes no inputs."""
        return tf.matmul(self.A, self.B)
# Smallest possible instance: m = n = k = 1, so A and B are both 1x1.
GEMM = GEMMBenchmark(1, 1, 1)
# Export the model under the path "GEMM"; this is the call reported to fail.
GEMM.save("GEMM")
I get this error when I try to save the model, which does not require any inputs. It seems like the forward pass is not defined.
Any workaround or alternative implementations? Thanks!
I am not sure what you are really trying to achieve, but your model is too incomplete to be saved. Here is one way to add the extra pieces needed to make the save function work:
class GEMMBenchmark(tf.keras.Model):
    """Keras model wrapping a single matrix multiplication of two fixed matrices.

    Args:
        m: Number of rows of the left operand ``A``.
        n: Number of columns of the right operand ``B``.
        k: Shared inner dimension (columns of ``A``, rows of ``B``).
    """

    def __init__(self, m, n, k):
        super().__init__()
        # Both operands are random-normal and frozen: they are benchmark
        # data, not weights to be trained.
        self.A = tf.Variable(tf.random.normal((m, k)), trainable=False)
        self.B = tf.Variable(tf.random.normal((k, n)), trainable=False)

    def call(self, inputs=None):
        """Return the product ``A @ B``.

        Args:
            inputs: Ignored. The body never reads it; it is accepted only so
                the method can be invoked with or without an argument.

        Returns:
            The result of ``tf.matmul(self.A, self.B)``.
        """
        return tf.matmul(self.A, self.B)
# Smallest possible instance: m = n = k = 1, so A and B are both 1x1.
GEMM = GEMMBenchmark(1, 1, 1)
# Compile and run a one-sample fit on dummy data before saving.
# NOTE(review): presumably this is what traces the forward pass so that
# save() has something to export -- confirm; the values [1], [1] are
# placeholders and the model's output does not depend on them.
GEMM.compile(loss=tf.keras.losses.MeanSquaredError())
GEMM.fit([1], [1])
# Export the model under the path "GEMM".
GEMM.save("GEMM")