Tensorflow: How to set the learning rate in log scale and some Tensorflow questions


I am a deep learning and TensorFlow beginner, and I am trying to implement the algorithm in this paper using TensorFlow. The paper uses MatConvNet + MATLAB, and I am curious whether TensorFlow has equivalent functions to achieve the same thing. The paper says:

The network parameters were initialized using the Xavier method [14]. We used the regression loss across four wavelet subbands under l2 penalty and the proposed network was trained by using the stochastic gradient descent (SGD). The regularization parameter (λ) was 0.0001 and the momentum was 0.9. The learning rate was set from 10^-1 to 10^-4 which was reduced in log scale at each epoch.

The paper uses the wavelet transform (WT) and residual learning (the residual image is WT(HR) - WT(HR'), and the HR' images are used for training). The Xavier method suggests initializing the variables from a normal distribution with

stddev = sqrt(2 / (filter_size * filter_size * num_filters))
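To make the residual target WT(HR) - WT(HR') concrete, here is a minimal NumPy sketch. The quote does not say which wavelet the paper uses, so this assumes a single-level orthonormal 2-D Haar transform, which yields exactly four subbands (one approximation and three detail subbands) stacked as channels. `haar_dwt2` and `residual_target` are hypothetical helper names.

```python
import numpy as np

def haar_dwt2(img):
    """ASSUMED wavelet: one-level orthonormal 2-D Haar transform.

    img: (H, W) array with even H and W.
    Returns an (H/2, W/2, 4) array: approximation subband first, then three detail subbands.
    """
    a = img[0::2, 0::2]  # top-left pixel of every 2x2 block
    b = img[0::2, 1::2]  # top-right
    c = img[1::2, 0::2]  # bottom-left
    d = img[1::2, 1::2]  # bottom-right
    approx = (a + b + c + d) / 2.0
    detail_1 = (a - b + c - d) / 2.0
    detail_2 = (a + b - c - d) / 2.0
    detail_3 = (a - b - c + d) / 2.0
    return np.stack([approx, detail_1, detail_2, detail_3], axis=-1)

def residual_target(hr, hr_prime):
    """Residual the network would regress: WT(HR) - WT(HR')."""
    return haar_dwt2(hr) - haar_dwt2(hr_prime)
```

Because the transform is linear, the residual of the subbands equals the subbands of the pixel residual HR - HR'.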

Q1: How should I initialize the variables? Is the code below correct?

weights = tf.Variable(tf.random_normal([img_size, img_size, 1, num_filters], stddev=stddev))
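Two details may help with Q1. A convolution kernel in TensorFlow has shape `[filter_size, filter_size, in_channels, num_filters]`, so its spatial size is the filter size, not the image size. Also, a normal distribution with stddev `sqrt(2 / (filter_size * filter_size * num_filters))` (a factor of 2 over the fan-out) is what is usually called He initialization; the original Glorot/Xavier normal initializer uses `sqrt(2 / (fan_in + fan_out))`, and TF 1.x provides it as `tf.glorot_normal_initializer()`. The NumPy sketch below implements the formula quoted above; the helper names are hypothetical.

```python
import math
import numpy as np

def conv_init_stddev(filter_size, num_filters):
    """stddev = sqrt(2 / (filter_size * filter_size * num_filters)), the formula quoted above."""
    return math.sqrt(2.0 / (filter_size * filter_size * num_filters))

def init_conv_kernel(filter_size, in_channels, num_filters, seed=0):
    """Sample a kernel in TensorFlow's [height, width, in_channels, out_channels] layout."""
    rng = np.random.default_rng(seed)
    stddev = conv_init_stddev(filter_size, num_filters)
    shape = (filter_size, filter_size, in_channels, num_filters)
    return rng.normal(0.0, stddev, size=shape).astype(np.float32)

# TF 1.x equivalent of the sampling step (filter_size, num_filters assumed to be defined):
#   weights = tf.Variable(tf.random_normal([filter_size, filter_size, 1, num_filters],
#                                          stddev=conv_init_stddev(filter_size, num_filters)))
kernel = init_conv_kernel(3, 1, 64)
print(kernel.shape)  # (3, 3, 1, 64)
```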

The paper does not explain in detail how to construct the loss function. I am also unable to find an equivalent TensorFlow function for reducing the learning rate in log scale (only `tf.train.exponential_decay`). I understand that `MomentumOptimizer` is equivalent to stochastic gradient descent with momentum.
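For reference, the update documented for `tf.train.MomentumOptimizer` (without Nesterov) is `accumulation = momentum * accumulation + gradient; variable -= learning_rate * accumulation`. The NumPy sketch below reproduces one such step for a single parameter array. As an assumption, it applies the regularization parameter λ as weight decay by adding `λ * w` to the gradient; that term is exactly the gradient of `λ * tf.nn.l2_loss(w)`, since `l2_loss` is `sum(w ** 2) / 2`. `momentum_step` is a hypothetical helper name.

```python
import numpy as np

def momentum_step(w, grad, velocity, lr, momentum=0.9, weight_decay=1e-4):
    """One SGD-with-momentum step; weight decay (ASSUMED way of applying lambda) is added to the gradient."""
    g = grad + weight_decay * w            # gradient of loss + weight_decay * sum(w ** 2) / 2
    velocity = momentum * velocity + g     # accumulation = momentum * accumulation + gradient
    w = w - lr * velocity                  # variable -= learning_rate * accumulation
    return w, velocity

w = np.array([1.0])
v = np.zeros_like(w)
w, v = momentum_step(w, np.array([0.5]), v, lr=0.1, weight_decay=0.0)
print(w, v)  # [0.95] [0.5]
```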

Q2: Is it possible to set the learning rate in log scale?
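Reducing the learning rate "in log scale" from 10^-1 to 10^-4 means the per-epoch rates are evenly spaced in log10, so every epoch multiplies the rate by the same constant factor. That is an exponential (geometric) decay, which `tf.train.exponential_decay` with `staircase=True` can express. The quote does not give the number of epochs, so `num_epochs=20` below is purely an assumption, as are the helper names.

```python
import numpy as np

def log_scale_lrs(lr_start=1e-1, lr_end=1e-4, num_epochs=20):
    """One learning rate per epoch, evenly spaced in log10 from lr_start down to lr_end."""
    return np.logspace(np.log10(lr_start), np.log10(lr_end), num=num_epochs)

def per_epoch_decay_rate(lr_start=1e-1, lr_end=1e-4, num_epochs=20):
    """Constant factor r such that lr_start * r ** (num_epochs - 1) == lr_end."""
    return (lr_end / lr_start) ** (1.0 / (num_epochs - 1))

lrs = log_scale_lrs()
rate = per_epoch_decay_rate()

# The same schedule in TF 1.x (steps_per_epoch is a hypothetical name for
# number_of_training_patches // batch_size):
#   lr = tf.train.exponential_decay(0.1, global_step, decay_steps=steps_per_epoch,
#                                   decay_rate=rate, staircase=True)
# Alternatively, feed lrs[epoch] to a scalar tf.placeholder at every step.
```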

Q3: How to create the loss function described above?
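The quoted setup points to a regression loss: a squared error between the predicted and target residuals over all four subbands, plus λ times an L2 penalty on the weights. Softmax cross-entropy is a classification loss and does not fit this setup. The NumPy sketch below assumes the four subbands are stacked on the last (channel) axis and uses the definition of `tf.nn.l2_loss` (`sum(w ** 2) / 2`) for the penalty. The quote does not say whether the squared error is averaged or summed, so the mean is an assumption; the helper names are hypothetical.

```python
import numpy as np

def l2_loss(t):
    """Same definition as tf.nn.l2_loss: sum(t ** 2) / 2."""
    return 0.5 * float(np.sum(np.square(t)))

def regression_loss(pred, target, weights, lam=1e-4):
    """Mean squared error over all subbands plus lam * sum of l2_loss over the weight tensors."""
    data_term = float(np.mean(np.square(pred - target)))
    reg_term = sum(l2_loss(w) for w in weights)
    return data_term + lam * reg_term

# TF 1.x equivalent (pred, labels, weights as in the code below):
#   loss = (tf.reduce_mean(tf.square(pred - labels))
#           + 1e-4 * tf.add_n([tf.nn.l2_loss(w) for w in weights]))
```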

I followed this website to write the code below. Assume the model() function returns the network described in the paper and λ = 0.0001:

inputs = tf.placeholder(tf.float32, shape=[None, patch_size, patch_size, num_channels])
labels = tf.placeholder(tf.float32, [None, patch_size, patch_size, num_channels])

# get the model output and the weights of each conv layer
# (model() is assumed to build the network on top of `inputs`)
pred, weights = model(inputs)

# define loss function: regression (squared error) loss, since the targets are residual subbands
loss = tf.reduce_mean(tf.square(pred - labels))

# L2 penalty on the conv weights; must be initialized before accumulating
regularizers = 0.0
for weight in weights:
    regularizers += tf.nn.l2_loss(weight)

loss = loss + 0.0001 * regularizers

global_step = tf.Variable(0, trainable=False)
momentum = 0.9
learning_rate = tf.train.exponential_decay(???) # Not sure if we can have custom learning rate for log scale
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(loss, global_step=global_step)

NOTE: As I am a deep learning/TensorFlow beginner, I copy-pasted code from here and there, so please feel free to correct it ;)


Solution

  • Other answers are very detailed and helpful. Here is a code example that feeds the learning rate through a placeholder so it can be decayed in log scale (halved after every epoch). HTH.

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)
    
    
    # data simulation: noiseless linear regression y = x . w
    N = 10000
    D = 10
    x = np.random.rand(N, D)
    w = np.random.rand(D, 1)
    y = np.dot(x, w)
    
    print(y.shape)
    
    # modeling
    batch_size = 100
    steps_per_epoch = N // batch_size
    num_epochs = 10
    tni = tf.truncated_normal_initializer()
    X = tf.placeholder(tf.float32, [batch_size, D])
    Y = tf.placeholder(tf.float32, [batch_size, 1])
    W = tf.get_variable("w", shape=[D, 1], initializer=tni)
    B = tf.zeros([1])
    
    # scalar placeholder: a new learning rate can be fed at every step
    lr = tf.placeholder(tf.float32, shape=[])
    
    pred = tf.add(tf.matmul(X, W), B)
    print(pred.shape)
    mse = tf.losses.mean_squared_error(Y, pred)  # already a scalar mean over the batch
    opt = tf.train.MomentumOptimizer(lr, 0.9)
    
    train_op = opt.minimize(mse)
    
    learning_rate = 0.0001
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            acc_err = 0.0
            for step in range(steps_per_epoch):
                # consecutive, non-overlapping mini-batches
                start = step * batch_size
                f = {X: x[start:start + batch_size], Y: y[start:start + batch_size], lr: learning_rate}
                _, err = sess.run([train_op, mse], feed_dict=f)
                acc_err += err
            print("Epoch {} completed. LR = {}, average error = {}".format(
                epoch, learning_rate, acc_err / steps_per_epoch))
            # epoch done: halve the learning rate (a constant step in log scale)
            learning_rate /= 2
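Halving the learning rate after every epoch is itself a log-scale decay: log10 of the rate drops by log10(2) ≈ 0.301 per epoch. Covering the paper's range from 10^-1 to 10^-4 this way would take about log2(1000) ≈ 9.97 epochs; for a different number of epochs a different constant factor would be needed. A small plain-Python check (helper names are hypothetical):

```python
import math

def halving_schedule(lr_start, num_epochs):
    """Learning rate used in each epoch when it is halved after every epoch."""
    return [lr_start / 2 ** epoch for epoch in range(num_epochs)]

def epochs_to_reach(lr_start, lr_end, factor=2.0):
    """Number of per-epoch divisions by `factor` needed to go from lr_start down to lr_end."""
    return math.log(lr_start / lr_end) / math.log(factor)

sched = halving_schedule(0.1, 4)
print(sched)  # [0.1, 0.05, 0.025, 0.0125]
log_steps = [math.log10(a) - math.log10(b) for a, b in zip(sched, sched[1:])]  # each ≈ 0.301
print(epochs_to_reach(0.1, 1e-4))  # ≈ 9.97
```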