Tags: tensorflow, google-compute-engine, cluster-computing, tpu

How would you set up a TensorFlow cluster of multiple TPUv2-8 (tpu-vm)?


I have two TPU VMs (v2-8) running on GCE with software version tpu-vm-tf-2.8.0. I would like to perform distributed deep learning with TensorFlow using both VMs, i.e. with a total of 2 × 8 = 16 cores.

For distributed learning on 8 cores I set the strategy as follows:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

What do I need to change in order to connect multiple VMs? I suspect this involves MultiWorkerMirroredStrategy, but I'm not sure how. Note that I did manage to make this work with PyTorch XLA.


Solution

  • tf.distribute.TPUStrategy works for both a single TPU device (v2-8, v3-8) and a TPU Pod slice (v2-32, v3-32, v2-64, ...). Note that v2-16 is not a valid Pod slice configuration, while v4-16 is.

    If you want to create a v2-32 TPU Pod slice with a TF2 image, use --version=tpu-vm-tf-2.9.1-pod and --accelerator-type=v2-32:

    gcloud alpha compute tpus tpu-vm create my-tpu-32 \
    --zone=europe-west4-a \
    --accelerator-type=v2-32 \
    --version=tpu-vm-tf-2.9.1-pod
    

    Note: my-tpu-32 will have 32 TPU cores. You will need to change 'local' to the TPU pod slice name (i.e. my-tpu-32).

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu-32')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    

    You won't need to pass tpu='my-tpu-32' above if the TPU_NAME environment variable is set to my-tpu-32. Training on a TPU Pod slice then looks like this (note the TPU_LOAD_LIBRARY=0 setting):

    TPU_NAME=my-tpu-32 TPU_LOAD_LIBRARY=0 python3 my_training.py
    
    

    The rest of the code can stay the same (though you may want to scale the batch size and learning rate to the larger replica count; some details here).
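As a minimal sketch of that scaling step: on a real Pod slice the replica count would come from strategy.num_replicas_in_sync; here it is hardcoded to 32 (the core count of a hypothetical v2-32 slice), and the per-replica batch size and base learning rate are illustrative values, not ones from the original post.

```python
# Sketch: scaling batch size and learning rate with the replica count.
# On a real Pod slice, replace the constant below with
# strategy.num_replicas_in_sync.
num_replicas = 32  # assumed: replica count of a v2-32 slice

# Keep the per-replica batch size you tuned on a single device,
# and let the global batch size grow with the slice.
per_replica_batch_size = 128  # assumed value
global_batch_size = per_replica_batch_size * num_replicas  # → 4096

# A common heuristic is the linear scaling rule: scale the learning
# rate by the same factor as the batch size.
base_learning_rate = 1e-3  # assumed: tuned for the per-replica batch
learning_rate = base_learning_rate * num_replicas

print(global_batch_size, learning_rate)
```

You would then pass global_batch_size to your tf.data pipeline (e.g. dataset.batch(global_batch_size)) and learning_rate to the optimizer built inside strategy.scope().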