deep-learning artificial-intelligence cntk

Making predictions with a CNTK checkpoint


Recently I have been trying out a model implemented in CNTK, but I can't find a way to run predictions on a new picture with the trained model. The trained model was saved as a checkpoint:

trainer.save_checkpoint(os.path.join(output_model_folder, "model_{}".format(best_epoch)))

This produced a set of checkpoint files.

I then loaded the checkpoint like this:

model = ct.load_model('../data/models/VGG13_majority/model_94')

This ran successfully. Then I tried:

model.eval(image_data)

but it raised an error.

Update:

This time I tried the following:

model = ct.load_model('../data/models/VGG13_majority/model_94')
model.eval({model.arguments[0]: [final_image]})

This raised a different error.


Solution

  • Pass C.Function.eval() a dictionary that maps each input variable to its data. Passing the data directly only works when the function has exactly one input.

    Assuming the model has only one input_variable, it looks something like this:

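    import cntk as C
    model = C.load_model('../data/models/VGG13_majority/model_94')
    # image_data: float32 data with a leading batch axis, matching model.arguments[0].shape
    model.eval({model.arguments[0]: image_data})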
    

    That said, I noticed that you saved the model through the checkpoint. Doing so also saved the "ground_truth" input_variable that feeds the loss function.

    Next time, I would recommend saving the model directly with model.save(). The files written by save_checkpoint() are meant to be loaded with restore_from_checkpoint():

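    import cntk as C
    from cntk.layers import Dense
    
    # toy network: one scalar input, ten sigmoid outputs
    model = Dense(10, activation=C.sigmoid)(C.input_variable(1))
    # the loss brings in a second input_variable for the labels
    loss = C.binary_cross_entropy(model, C.input_variable(10))
    
    trainer = C.Trainer(model, (loss,), [C.adam(model.parameters, 0.9, 0.9)])
    trainer.save_checkpoint("hello")
    model.save("hello.model")  # save the model directly (file name is illustrative)
    
    # for prediction, load the directly saved model
    loaded_model = C.load_model("hello.model")
    
    # to recover the model from the checkpoint, restore it into the trainer
    trainer.restore_from_checkpoint("hello")
    original_model = trainer.model
    # list the inputs the restored model expects
    for i in trainer.model.arguments:
        print(i)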
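However the model is loaded, the data handed to eval() has to match its input variable: CNTK reads the first axis of a NumPy array as the batch axis, image inputs to convolutional layers are laid out channel-first (C, H, W), and C.input_variable defaults to float32. A minimal NumPy sketch of that preparation follows. `to_batch` is a hypothetical helper (not part of CNTK), and the (1, 64, 64) shape is only an example; read the real one from model.arguments[0].shape.

```python
import numpy as np


def to_batch(image, input_shape, channels_last=False):
    """Wrap one image as a batch of size 1 for Function.eval().

    Hypothetical helper, not part of CNTK. input_shape is the shape of the
    model's input variable, e.g. model.arguments[0].shape.
    """
    # Convert to float32, the default dtype of C.input_variable.
    arr = np.asarray(image, dtype=np.float32)
    if channels_last:
        # H x W x C (common in image libraries) -> C x H x W (channel-first)
        arr = np.transpose(arr, (2, 0, 1))
    expected = tuple(input_shape)
    if arr.size != int(np.prod(expected)):
        raise ValueError("image has %d values but the input expects shape %s"
                         % (arr.size, expected))
    # Prepend the batch axis: one sample becomes shape (1,) + input_shape.
    return arr.reshape((1,) + expected)


# Example with a hypothetical 64x64 grayscale image.
image = np.random.randint(0, 256, size=(64, 64))
batch = to_batch(image, (1, 64, 64))
print(batch.shape, batch.dtype)  # (1, 1, 64, 64) float32
```

The result would then be passed as model.eval({model.arguments[0]: batch}). The size check only catches a wrong number of values, not a wrong layout, so set channels_last=True for H x W x C images rather than relying on the reshape. If model.arguments lists more than one input (for example the label input saved through the checkpoint), eval() may expect data for that input as well.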