Tags: python, tensorflow, google-compute-engine, google-cloud-tpu, bfloat16

Memory reduction Tensorflow TPU v2/v3 bfloat16


My model is too big to fit a batch size larger than 64 on the standard TPU v2 devices. The troubleshooting page mentions that upcoming TensorFlow versions will support bfloat16. Are the newly supported TF versions 1.9 to 1.12 capable of using bfloat16 now, and if so, is there a limited set of optimizers I can use? I did not find any further documentation on this, but I saw bfloat16 being used in the tensor2tensor model, so I guess there must be a way.

Furthermore, I read that TPU v3 supports bigger models as well, but that the model would need minimal changes, and I can't find any documentation on what needs to be changed.

I'm already using Adafactor and have tried to reduce the number of layers. If you have any further tips for reducing memory usage, that would be great too. My inputs are image matrices and word vectors (currently float32).


Solution

  • You can use bfloat16 with TPUs. There are two main things to do:

    1. Cast the input to bfloat16 in your input pipeline
    2. Build your network inside a bfloat16 scope and cast its outputs back to float32 for further calculations.
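    To see why step 1 saves memory, here is a minimal pure-Python sketch (standard library only, no TensorFlow) of what a float32 to bfloat16 cast does to a single value. bfloat16 keeps float32's sign bit and 8-bit exponent but only 7 of the 23 mantissa bits, so each element takes 2 bytes instead of 4 while keeping roughly the same range. The helper names and the (64, 224, 224, 3) image batch shape are hypothetical, round-to-nearest-even is assumed as the rounding mode, and NaN inputs are not handled.

```python
import math
import struct


def round_to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    bfloat16 is the upper 16 bits of a float32: 1 sign bit, the same 8
    exponent bits and 7 of the 23 mantissa bits. NaN is not handled here.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # Adding 0x7FFF plus the lowest kept bit makes exact ties round to even.
    upper = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF
    # Widening back to float32 just appends 16 zero bits (exact).
    return struct.unpack("<f", struct.pack("<I", upper << 16))[0]


def tensor_bytes(shape, bytes_per_element):
    """Storage needed for a dense tensor of the given shape."""
    return math.prod(shape) * bytes_per_element


if __name__ == "__main__":
    print(round_to_bfloat16(3.14159265))  # 3.140625: only ~3 significant digits
    print(round_to_bfloat16(1e38) > 9e37)  # True: float32-like range, unlike float16
    batch = (64, 224, 224, 3)  # hypothetical image batch (NHWC)
    print(tensor_bytes(batch, 4), tensor_bytes(batch, 2))  # 38535168 19267584
```

    On a TPU the cast itself is just the tf.cast in the input pipeline; the sketch only shows its numeric effect. Halving the storage of inputs (and, inside the bfloat16 scope, of activations) does not by itself guarantee a 2x larger batch, since variables and other buffers still take memory.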

    Here is a code snippet that illustrates the necessary changes:

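    import tensorflow as tf
    from absl import flags
    # TF 1.x contrib location of bfloat16_scope(); adjust for your TF version.
    from tensorflow.contrib.tpu.python.tpu import bfloat16
    import resnet_model  # ResNet definition from the TPU reference models

    FLAGS = flags.FLAGS
    LABEL_CLASSES = 1000  # ImageNet classes


    class ImageNetInput(object):
      """Abridged input pipeline: __init__ (which sets image_preprocessing_fn,
      is_training and use_bfloat16) and input_fn are omitted."""

      def dataset_parser(self, value):
        """Parse an ImageNet record from a serialized string Tensor."""
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature((), tf.string, ''),
            'image/class/label': tf.FixedLenFeature([], tf.int64, -1),
        }
        parsed = tf.parse_single_example(value, keys_to_features)
        image_bytes = tf.reshape(parsed['image/encoded'], shape=[])

        image = self.image_preprocessing_fn(
            image_bytes=image_bytes,
            is_training=self.is_training,
        )

        # Step 1: cast the input to bfloat16 inside the input pipeline.
        if self.use_bfloat16:
          image = tf.cast(image, tf.bfloat16)

        # Subtract one so that labels are in [0, 1000).
        label = tf.cast(
            tf.reshape(parsed['image/class/label'], shape=[]), tf.int32) - 1

        return image, label


    def resnet_model_fn(features, labels, mode, params):
      """The model_fn for ResNet to be used with TPUEstimator."""

      # This nested function allows us to avoid duplicating the logic which
      # builds the network, for different values of --precision.
      def build_network():
        network = resnet_model.resnet_v1(
            resnet_depth=FLAGS.resnet_depth,
            num_classes=LABEL_CLASSES,
            data_format=FLAGS.data_format)
        return network(
            inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

      # Step 2: build the network inside a bfloat16 scope, then cast the
      # logits back to float32 so the loss is computed at full precision.
      if FLAGS.precision == 'bfloat16':
        with bfloat16.bfloat16_scope():
          logits = build_network()
        logits = tf.cast(logits, tf.float32)
      elif FLAGS.precision == 'float32':
        logits = build_network()
      else:
        raise ValueError('Unknown precision: %s' % FLAGS.precision)

      # ... loss, optimizer and TPUEstimatorSpec are built from `logits` ...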
    

    The second step (building the network inside a bfloat16 scope) is also illustrated in this TPU model.