Tags: sockets, tensorflow, deep-learning, distributed-computing, gradient-descent

Where do the weights get updated in this code?


I want to train a model on a distributed system. I found code on GitHub for distributed training in which the worker nodes send gradients to the parameter server and the parameter server sends the averaged gradients back to the workers. But in the client/worker-side code, I couldn't understand where the received gradients update the weights and biases.

Here is the client/worker-side code: it first receives values from the parameter server, then computes the loss and gradients and sends the gradients back to the server.

from __future__ import division
from __future__ import print_function

import numpy as np
import sys
import pickle as pickle
import socket

from datetime import datetime
import time

import tensorflow as tf

import cifar10

TCP_IP = 'some IP'
TCP_PORT = 5014

port = 0
port_main = 0
s = 0

FLAGS = tf.app.flags.FLAGS


tf.app.flags.DEFINE_string('train_dir', '/home/ubuntu/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 5000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)


def safe_recv(size, server_socket):
    # Keep calling recv() until exactly `size` bytes have been accumulated,
    # since a single recv() may return fewer bytes than requested.
    data = bytearray()
    while len(data) < size:
        try:
            chunk = server_socket.recv(size - len(data))
            data.extend(chunk)
        except socket.error:
            print("Error")
    return bytes(data)


def train():
    """Train CIFAR-10 for a number of steps."""

    g1 = tf.Graph()
    with g1.as_default():
        global_step = tf.Variable(-1, name='global_step',
                                  trainable=False, dtype=tf.int32)
        increment_global_step_op = tf.assign(global_step, global_step+1)

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)
        grads = cifar10.train_part1(loss, global_step)

        only_gradients = [g for g, _ in grads]

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir,
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],
            config=tf.ConfigProto(
                # log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess:
                log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            global port
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect((TCP_IP, port_main))
            recv_size = safe_recv(17, s)
            recv_size = pickle.loads(recv_size)
            recv_data = safe_recv(recv_size, s)
            var_vals = pickle.loads(recv_data)
            s.close()
            feed_dict = {}
            i = 0
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                feed_dict[v] = var_vals[i]
                i = i+1
            print("Received variable values from ps")
            # Opening the socket and connecting to server
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect((TCP_IP, port))
            while not mon_sess.should_stop():
                gradients, step_val = mon_sess.run(
                    [only_gradients, increment_global_step_op], feed_dict=feed_dict)
                # sending the gradients
                send_data = pickle.dumps(gradients, pickle.HIGHEST_PROTOCOL)
                to_send_size = len(send_data)
                send_size = pickle.dumps(to_send_size, pickle.HIGHEST_PROTOCOL)
                s.sendall(send_size)
                s.sendall(send_data)
                # receiving the variable values
                recv_size = safe_recv(17, s)
                recv_size = pickle.loads(recv_size)
                recv_data = safe_recv(recv_size, s)
                var_vals = pickle.loads(recv_data)

                feed_dict = {}
                i = 0
                for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                    feed_dict[v] = var_vals[i]
                    i = i+1
            s.close()


def main(argv=None):  # pylint: disable=unused-argument
    global port
    global port_main
    global s
    if(len(sys.argv) != 3):
        print("<port> <worker-id> required")
        sys.exit()
    port = int(sys.argv[1]) + int(sys.argv[2])
    port_main = int(sys.argv[1])
    print("Connecting to port ", port)
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    total_start_time = time.time()
    train()
    print("--- %s seconds ---" % (time.time() - total_start_time))


if __name__ == '__main__':
    tf.app.run()

EDIT:

Here is the train_part1() code:

def train_part1(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and compute the gradients of the total loss with
  respect to all trainable variables. The gradients are returned, not applied.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.
  Returns:
    grads: List of (gradient, variable) pairs from compute_gradients().
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  return grads
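
For comparison, in the standard single-machine cifar10.train() (from the TensorFlow CIFAR-10 example this code appears to be based on), the gradients computed above would be applied locally, roughly like the sketch below. No such apply_gradients call appears anywhere in the worker code, which is what confused me.

# Rough sketch of how these gradients would normally be applied locally,
# as in the original cifar10.train(); this call is absent from the worker code.
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
with tf.control_dependencies([apply_gradient_op]):
  train_op = tf.no_op(name='train')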

Solution

  • To me it seems that the line

    gradients, step_val = mon_sess.run(
        [only_gradients, increment_global_step_op], feed_dict=feed_dict)
    

    receives new values for the variables through feed_dict, assigns those values to the variables, and runs a training step during which it only computes and returns the gradients, which are later sent to the parameter server. I would expect cifar10.train_part1 (the function that returns only_gradients) to depend on the variable values and to define the update.

    Update: I looked into the code and changed my mind. I had to google around and found an answer that shed some light on what is happening.

    Gradients are not applied anywhere in this code. Instead, the gradients are sent to the parameter server; the parameter server averages them and applies them to its copy of the weights, then returns the updated weights to the local worker, and the received weights are used in place of the local weights during the session run through feed_dict. In other words, the local weights are never actually updated and do not matter at all. The key point is that feed_dict lets you override the value of any feedable tensor during a session run, and this code overrides the variables.
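
    To make that last point concrete, here is a minimal standalone TF 1.x sketch of the feed_dict mechanism (it is not part of the code in question; the names w and y are just for illustration). Feeding a variable overrides its value for that single run() call without ever assigning to it:

    import tensorflow as tf

    # Feeding a variable through feed_dict overrides the value it contributes
    # to the computation for this run() call only; the variable is not assigned.
    w = tf.Variable(1.0, name='w')
    y = w * 3.0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(y))                       # 3.0  -- uses the stored value of w
        print(sess.run(y, feed_dict={w: 10.0}))  # 30.0 -- w overridden for this run only
        print(sess.run(w))                       # 1.0  -- the variable was never updated

    This is exactly what the worker loop above does with the values it receives from the parameter server: it feeds them into every mon_sess.run call, so the graph's own variables are never assigned and their local values never matter.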