Tags: python, csv, tensorflow, batching

TensorFlow - batching issues


I'm quite new to TensorFlow, and I'm trying to train from my CSV files using batches.

Here's my code for reading the CSV file and making batches:

filename_queue = tf.train.string_input_producer(
    ['BCHARTS-BITSTAMPUSD.csv'], shuffle=False, name='filename_queue')

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[0.], [0.], [0.], [0.], [0.],[0.],[0.],[0.]]
xy = tf.decode_csv(value, record_defaults=record_defaults)

# collect batches of csv in
train_x_batch, train_y_batch = \
    tf.train.batch([xy[0:-1], xy[-1:]], batch_size=100)

and here's the training part:

# initialize
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)


# train my model
for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(2193 / batch_size)

    for i in range(total_batch):
        batch_xs, batch_ys = sess.run([train_x_batch, train_y_batch])
        feed_dict = {X: batch_xs, Y: batch_ys}
        c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)
        avg_cost += c / total_batch

    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

coord.request_stop()
coord.join(threads)

Here are my questions:

1.

My CSV file has 2193 records and my batch size is 100. So what I want is this: every 'epoch' should start with the first record, train 21 batches of 100 records each, and then one last batch of 93 records, for a total of 22 batches.

However, I found that every batch has size 100, even the last one. Moreover, from the second 'epoch' onward it does not start with the first record.

2.

How can I obtain the record count (in this case, 2193)? Should I hard-code it, or is there a smarter way to do it? I used tensor.get_shape().as_list(), but it doesn't work for batch_xs; it just returns an empty shape [].


Solution

  • We recently added a new API to TensorFlow called tf.contrib.data that makes it easier to solve problems like this. (The "queue runner"–based APIs make it difficult to write computations on exact epoch boundaries, because the epoch boundary gets lost.)

    Here's an example of how you'd use tf.contrib.data to rewrite your program:

    lines = tf.contrib.data.TextLineDataset("BCHARTS-BITSTAMPUSD.csv")
    
    def decode(line):
      record_defaults = [[0.], [0.], [0.], [0.], [0.],[0.],[0.],[0.]]
  xy = tf.decode_csv(line, record_defaults=record_defaults)
      return xy[0:-1], xy[-1:]
    
    decoded = lines.map(decode)
    
    batched = decoded.batch(100)
    
    iterator = batched.make_initializable_iterator()
    
    train_x_batch, train_y_batch = iterator.get_next()
    

    Then the training part can become a bit simpler:

    # initialize
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # train my model
    for epoch in range(training_epochs):
      total_cost = 0.0
      total_records = 0

      # Re-initialize the iterator for another epoch.
      sess.run(iterator.initializer)

      while True:

        # NOTE: It is inefficient to make a separate sess.run() call to get each batch
        # of input data and then feed it into a different sess.run() call. For better
        # performance, define your training graph to take train_x_batch and train_y_batch
        # directly as inputs.
        try:
          batch_xs, batch_ys = sess.run([train_x_batch, train_y_batch])
        except tf.errors.OutOfRangeError:
          break

        feed_dict = {X: batch_xs, Y: batch_ys}
        c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)
        total_cost += c
        total_records += batch_xs.shape[0]

      avg_cost = total_cost / total_records
    
      print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))
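
    As a sanity check on the batching behavior you wanted in question 1, here is a plain-Python sketch (no TensorFlow required) of how 2193 records split into batches of 100 when the final partial batch is kept rather than wrapped into the next epoch. The function name is just illustrative:

    ```python
    def batch_sizes(num_records, batch_size):
        # Sizes of the successive batches in one epoch: full batches
        # of `batch_size`, plus a smaller final batch for the remainder.
        full, rem = divmod(num_records, batch_size)
        return [batch_size] * full + ([rem] if rem else [])

    sizes = batch_sizes(2193, 100)  # 21 batches of 100 and a final batch of 93
    ```

    This matches what `Dataset.batch(100)` produces: it emits a last, smaller batch and then raises `OutOfRangeError`, instead of silently refilling from the start of the file the way the queue-based `tf.train.batch` does.
    
    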
    

    For more details about how to use the new API, see the "Importing Data" programmer's guide.
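
    As for question 2: `batch_xs.get_shape()` is empty because the shape of a fed value isn't known statically. To avoid hard-coding 2193, one simple option is to count the lines of the CSV file in ordinary Python before building the graph; a minimal sketch, assuming one record per line (the `has_header` parameter is illustrative):

    ```python
    def count_records(path, has_header=False):
        # Count data rows in a CSV file by counting its lines.
        with open(path) as f:
            n = sum(1 for _ in f)
        return n - 1 if has_header else n
    ```

    You would call this once, e.g. `num_records = count_records('BCHARTS-BITSTAMPUSD.csv')`, and derive the batch count from it instead of hard-coding the constant.
    
    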