Tags: python, machine-learning, tensorflow, deep-learning, tf-slim

TensorFlow: What is the purpose of endpoints for data parallelism when training across multiple machines?


In the TensorFlow-slim source code, an end_points dictionary is used when the loss function is created:

def clone_fn(batch_queue):
  """Allows data parallelism by creating multiple clones of network_fn."""
  images, labels = batch_queue.dequeue()
  logits, end_points = network_fn(images)

  #############################
  # Specify the loss function #
  #############################
  if 'AuxLogits' in end_points:
    slim.losses.softmax_cross_entropy(
        end_points['AuxLogits'], labels,
        label_smoothing=FLAGS.label_smoothing, weight=0.4, scope='aux_loss')
  slim.losses.softmax_cross_entropy(
      logits, labels, label_smoothing=FLAGS.label_smoothing, weight=1.0)
  return end_points

Source: https://github.com/tensorflow/models/blob/master/slim/train_image_classifier.py#L471-L477
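For context, clone_fn is not called directly; the linked script hands it to the repository's deployment helpers, which build one copy of the model per device/replica while sharing the variables between them. A rough sketch of that wiring, paraphrased from the same repository (the exact flag names and DeploymentConfig parameters here are recalled from memory, not quoted):

from deployment import model_deploy  # helper module in tensorflow/models/slim

# Describes how many clones to build and where the shared variables live.
deploy_config = model_deploy.DeploymentConfig(
    num_clones=FLAGS.num_clones,        # model copies per worker
    clone_on_cpu=FLAGS.clone_on_cpu,
    replica_id=FLAGS.task,              # which worker this process is
    num_replicas=FLAGS.worker_replicas,
    num_ps_tasks=FLAGS.num_ps_tasks)

# Each clone runs clone_fn on its own device with a batch from the shared
# queue; the per-clone losses and gradients are combined later when the
# train op is built.
clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])

Because the clones share variables, the gradients from all clones are combined at each optimizer step, so the copies stay in sync throughout training rather than being merged only at the end.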

My understanding is that multiple identical copies of the network are trained on separate machines, and the variables and parameters are averaged at the end to merge them into one network (is this correct?). But I don't quite get what the purpose of end_points is in this case, since I thought network_fn only needed to produce logits for predictions. What is the use of end_points?


Solution

  • The end_points in this case just track the different named outputs of the model. The 'AuxLogits' entry, for example, holds the logits of the auxiliary classifier head that some architectures (such as Inception v3) attach to an intermediate layer as an extra regularizing loss, which is why the snippet adds a second, down-weighted cross-entropy loss whenever that endpoint is present.
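To make that concrete, here is a minimal sketch of what a network_fn can look like, assuming TF 1.x with tf.contrib.slim; the toy architecture and the layer names (other than 'AuxLogits', which the snippet above checks for) are purely illustrative:

import tensorflow as tf

slim = tf.contrib.slim

def toy_network_fn(images, num_classes=10):
  """Hypothetical network_fn: returns logits plus an end_points dict
  mapping layer names to their output tensors."""
  end_points = {}
  net = slim.conv2d(images, 32, [3, 3], scope='conv1')
  end_points['conv1'] = net
  net = slim.flatten(net)
  # Auxiliary classifier head: some architectures (e.g. Inception v3)
  # attach one to an intermediate layer purely as a training aid.
  end_points['AuxLogits'] = slim.fully_connected(
      net, num_classes, activation_fn=None, scope='aux_logits')
  # Main classifier head; these logits are what prediction actually uses.
  logits = slim.fully_connected(
      net, num_classes, activation_fn=None, scope='logits')
  end_points['Logits'] = logits
  return logits, end_points

So the dictionary costs nothing extra to build, and it lets generic training code like clone_fn above work with many different architectures: it simply checks whether an endpoint such as 'AuxLogits' exists and attaches an additional loss to it if so.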