Tags: python, tensorflow, keras, tensorboard

How to visualize a Keras model that has custom model sub-classes with TensorBoard?


I have a model that is composed of several sub-models that inherit from tf.keras.Model. These sub-models are all more or less simple sets of keras.Sequential models composing keras.layers such as keras.layers.Conv2D, keras.layers.BatchNormalization, etc., and the call function passes the data through the different Sequential models (sometimes adding extra steps, such as adding the original input to the output of a Sequential model, a la a ResidualBlock sub-model).

The reason my main model is composed of sub-models is that the main model is complex, and structuring it this way allows me to change the architecture easily (such as the number of layers in Sub-model A). In addition, parts of the sub-models instantiate certain layers (such as keras.layers.Reshape) in the `call` function, because the arguments used to configure the `Reshape` depend on the input to `call`.
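For illustration, a minimal sketch of what one of these sub-models looks like (the ResidualBlock name matches the description above; the layers and sizes are made up):

import tensorflow as tf
from tensorflow import keras

class ResidualBlock(keras.Model):
    def __init__(self, filters):
        super(ResidualBlock, self).__init__()
        # a Sequential stack of the usual conv / batch-norm layers
        self.body = keras.Sequential([
            keras.layers.Conv2D(filters, 3, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.Conv2D(filters, 3, padding='same'),
            keras.layers.BatchNormalization(),
        ])

    def call(self, inputs):
        # add the original input to the Sequential stack's output
        # (assumes `inputs` already has `filters` channels)
        return tf.nn.relu(inputs + self.body(inputs))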

The model compiles successfully, and I have passed random data through it (I have not yet trained it), but I want to visualize it.

I tried the following:

from time import time
from tensorflow.keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir='./logs/{}'.format(time()))
tensorboard.set_model(model)

but I get a warning:

WARNING:tensorflow:Model failed to serialize as JSON. Ignoring...
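Presumably this is because a subclassed model has no static graph of layers that Keras can serialize - as far as I can tell, serializing it directly fails the same way:

model.to_json()  # also raises NotImplementedError for subclassed models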

Nor can I save it with

model.save('path_to_file.h5')

because I get a NotImplementedError.

After doing some research, I see that the recommended way to save custom models is to save and load only the weights.
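i.e. something along these lines (the paths are placeholders, and MyModel stands in for my actual model class):

# save only the weights of the subclassed model
model.save_weights('./checkpoints/my_model')

# later: rebuild the same architecture, then restore the weights
new_model = MyModel()                            # placeholder model class
new_model.build(input_shape=(None, 64, 64, 3))   # example input shape
new_model.load_weights('./checkpoints/my_model')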

How can I visualize my model with TensorBoard? Do I need to implement a serializer? Is there a guide for this?


Solution

  • As far as the Keras API's TensorBoard callback goes, you're likely out of luck, as it wasn't designed to operate on nested models with custom functionality. The good news is that the linked source code isn't hard to understand - in fact, it was much more intuitive for me than the official guides - so you should be able to write your own TensorBoard class to meet your needs.

    Any 'workarounds' in the form of nested callbacks are going to be a lot buggier and harder to manage than a custom class - so while the latter may involve more work initially, it should pay off in the long run.

    Lastly, the TensorBoard API has limited customizability - e.g. you can't select specific layers to log, or choose which metrics to omit. To this end, I wrote my own class - see the excerpt below; it doesn't support nested models, but it can easily be expanded to do so.

    # Excerpt from the custom class (TF1-style Keras). It assumes `import os`
    # and `import tensorflow as tf` at the top of the file, plus instance
    # attributes set in the constructor: self.model, self.sess, self.val_data,
    # self.layer_names, self.base_logdir, self.write_graph, self.verbose,
    # and self.log_num.

    def run(self, log_num=None):
        # feed the validation data into the model's input / target /
        # sample-weight tensors and evaluate the merged summary op
        tensors = (self.model.inputs +
                   self.model.targets +
                   self.model.sample_weights)
        assert len(tensors) == len(self.val_data)

        feed_dict = dict(zip(tensors, self.val_data))
        summary = self.sess.run([self.merged], feed_dict=feed_dict)

        log_num = log_num or self.log_num
        self.writer.add_summary(summary[0], log_num)
        self.log_num += 1

        if self.verbose:
            print("MyTensorBoard saved logs at step %s" % log_num)
    
    def _init_logger(self):
        # register summaries for every layer whose name contains one of the
        # substrings in self.layer_names, then merge them into a single op
        for layer in self.model.layers:
            if any([(spec_name in layer.name) for spec_name in self.layer_names]):
                grads = self._get_grads(layer)
                if grads is not None:
                    # one histogram per weight tensor's gradient
                    for weight, grad in zip(layer.trainable_weights, grads):
                        tf.summary.histogram(
                            weight.name.replace(':', '_') + '_grad', grad)
                if hasattr(layer, 'output'):
                    self._log_outputs(layer)

                for weight in layer.weights:
                    mapped_weight_name = weight.name.replace(':', '_')
                    tf.summary.histogram(mapped_weight_name, weight)

                    # _to_img_format (defined elsewhere in the class) reshapes
                    # a weight tensor into image format, or returns None
                    w_img = self._to_img_format(weight)
                    if w_img is not None:
                        tf.summary.image(mapped_weight_name, w_img)
        self.merged = tf.summary.merge_all()
        self._init_writer()
        print("MyTensorBoard initialized")
    
    
    def _init_writer(self):
        # create the first unused TB_<n> directory under base_logdir
        tb_num = 0
        while any([('TB_' + str(tb_num) in fname) for fname in
                   os.listdir(self.base_logdir)]):
            tb_num += 1
        self.logdir = os.path.join(self.base_logdir, 'TB_%s' % tb_num)
        os.mkdir(self.logdir)
        print("New TB logdir created at %s" % self.logdir)

        # optionally write the session graph alongside the summaries
        if self.write_graph:
            self.writer = tf.summary.FileWriter(self.logdir, self.sess.graph)
        else:
            self.writer = tf.summary.FileWriter(self.logdir)
    
    def _get_grads(self, layer):
        # gradients of the total loss w.r.t. all of the layer's trainable
        # weights; returns None for layers with nothing trainable
        if not layer.trainable_weights:
            return None
        grads = self.model.optimizer.get_gradients(
                    self.model.total_loss, layer.trainable_weights)

        # unwrap IndexedSlices (e.g. from embedding lookups) to plain tensors
        is_indexed_slices = lambda g: type(g).__name__ == 'IndexedSlices'
        return [grad.values if is_indexed_slices(grad)
                else grad for grad in grads]
    
    def _log_outputs(self, layer):
        if isinstance(layer.output, list):
            for i, output in enumerate(layer.output):
                tf.summary.histogram('{}_out_{}'.format(layer.name, i), output)
        else:
            tf.summary.histogram('{}_out'.format(layer.name), layer.output)
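
    A rough usage sketch: the constructor isn't part of the excerpt, so the MyTensorBoard signature below is assumed - it just needs to store the attributes the methods above reference (model, sess, val_data, layer_names, base_logdir, write_graph, verbose, log_num):

    import tensorflow.keras.backend as K

    tb = MyTensorBoard(model=model,                 # compiled Keras model
                       sess=K.get_session(),        # TF1-style Keras session
                       val_data=val_data,           # matches inputs + targets + sample_weights
                       layer_names=['conv', 'bn'],  # log layers whose names contain these
                       base_logdir='./logs',
                       write_graph=True,
                       verbose=1)
    tb._init_logger()  # builds the summary ops and the FileWriter

    for epoch in range(num_epochs):
        model.fit(x_train, y_train, epochs=1, verbose=0)
        tb.run()  # writes one round of histograms / images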