tensorflow keras lstm autoencoder seq2seq

How to set steps_per_epoch for variable input length in Keras fit_generator


I need to feed my input data to the model in such a way that sentences of the same length end up in the same batch (variable input length in an LSTM).

When using fit_generator we need to specify steps_per_epoch and validation_steps, but in my case I cannot get them simply as num_train_steps = len(Xtrain) // BATCH_SIZE, because the batch size varies with the number of sentences of each length. Where can I calculate this value so that I can pass it to fit_generator? I have steps_per_epoch in sentence_generator, but I don't know how to pass it to fit_generator.

Is there any way to return the length of each batch from sentence_generator?

This is the fit_generator call (I don't know how to obtain num_train_steps and pass it to fit_generator):

lstm_ae_model.fit_generator(train_gen, steps_per_epoch=num_train_steps, epochs=NUM_EPOCHS, validation_data=val_gen, validation_steps=num_val_steps)

My custom generator looks like this, in case it helps:

import itertools
import numpy as np

def sentence_generator(X, embeddings):
    # sort once so that groupby sees equal-length sentences as contiguous runs
    items = sorted(X.values(), key=len, reverse=True)
    while True:
        # loop once per epoch
        for length, dics in itertools.groupby(items, len):
            # dics is all the nested dictionaries with this length;
            # turn it into a list, because a groupby group can only be iterated once
            dics = list(dics)
            # integer dtype, so the array can be used to index embeddings
            sent_wids = np.zeros([len(dics), length], dtype=np.int64)
            for index_sentence, temp_sentence in enumerate(dics):
                keys_words = list(temp_sentence.keys())
                for index_word, word in enumerate(keys_words):
                    sent_wids[index_sentence, index_word] = lookup_word2id(word)
            # yield one batch per sentence length
            Xbatch = embeddings[sent_wids]
            yield Xbatch, Xbatch

Solution

  • You can first write a function that pre-computes steps_per_epoch by iterating over the dataset once, and then pass the result to fit_generator. With one batch per distinct sentence length, the number of steps per epoch equals the number of length groups. Something like:

    import itertools

    def compute_steps(X):
        # sort by length so that groupby sees each length as one contiguous run
        items = sorted(X.values(), key=len, reverse=True)
        # one step (batch) per distinct sentence length
        return sum(1 for _ in itertools.groupby(items, len))

    spe = compute_steps(...)
    gen = sentence_generator(...)
    model.fit_generator(gen, steps_per_epoch=spe)
    

    Do the same for the validation data and pass the result as validation_steps.
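To sanity-check that the pre-computed step count matches what the generator yields in one epoch, a small standard-library sketch can help. It is only an illustration under stated assumptions: the names length_batches, endless_batches and toy_sentences are hypothetical, and the sentences are plain lists of word ids rather than the nested dictionaries and embeddings from the question.

```python
import itertools


def length_batches(sentences):
    """Yield one batch (a list of sentences) per distinct sentence length."""
    items = sorted(sentences, key=len, reverse=True)
    for _length, group in itertools.groupby(items, key=len):
        yield list(group)


def compute_steps(sentences):
    """Batches per epoch = number of distinct sentence lengths."""
    return len({len(s) for s in sentences})


def endless_batches(sentences):
    """Repeat the epochs forever, the way a generator for fit_generator loops."""
    while True:
        yield from length_batches(sentences)


# Hypothetical toy data: each sentence is a list of word ids.
toy_sentences = [[1, 2, 3], [4, 5], [6, 7, 8], [9], [10, 11]]

steps = compute_steps(toy_sentences)
# Take exactly `steps` batches: this should be one full epoch.
first_epoch = list(itertools.islice(endless_batches(toy_sentences), steps))
batch_sizes = [len(batch) for batch in first_epoch]
batch_lengths = [len(batch[0]) for batch in first_epoch]
print(steps, batch_sizes, batch_lengths)
```

The batch sizes differ from step to step, which is why len(Xtrain) // BATCH_SIZE does not apply here: the number of steps depends only on how many distinct lengths occur, not on the number of sentences.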