
TensorFlow/TFLearn: Cannot feed value of shape


I want to reproduce the TensorFlow "Boston housing prices" example in TFLearn, but I get a shape error.

Here is my code:

    import tflearn
    from tflearn.data_utils import load_csv

    data, target = load_csv('boston_train.csv', has_header=True)
    input_ = tflearn.input_data(shape=[None, 9])
    linear = tflearn.fully_connected(input_, 9)
    regression = tflearn.regression(linear, optimizer='sgd', loss='mean_square', learning_rate=0.01)
    m = tflearn.DNN(regression)
    m.fit(data, target, n_epoch=10, batch_size=10, show_metric=True)

I get the following error:

    ValueError: Cannot feed value of shape (10,) for Tensor 'TargetsData/Y:0', which has shape '(?, 9)'

The CSV file has 9 feature columns and one label column.
What should I do?
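The mismatch can be reproduced with plain NumPy shapes. Judging by the error message, TFLearn's regression layer builds its 'TargetsData/Y' placeholder with the same shape as the layer passed to it, here (?, 9), while load_csv returns the labels as a flat list, so each batch of 10 labels arrives with shape (10,). The sketch below uses made-up label values and a simplified, hypothetical fits() helper that mirrors the usual rule for feeding a placeholder (same rank, every known dimension equal, None standing for the batch dimension '?').

```python
import numpy as np

# Made-up stand-in for the label column returned by load_csv:
# a flat list with one price per row (values are illustrative only).
target = [24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9]

batch = np.asarray(target)
print(batch.shape)  # (10,): what the question's code fed to 'TargetsData/Y:0'


def fits(value_shape, placeholder_shape):
    """Hypothetical, simplified check: can an array of value_shape be fed to a
    placeholder of placeholder_shape? None plays the role of the '?' dimension."""
    if len(value_shape) != len(placeholder_shape):
        return False  # rank mismatch, e.g. (10,) vs (?, 9)
    return all(p is None or p == v for v, p in zip(value_shape, placeholder_shape))


# Last layer with 9 units -> targets placeholder of shape (?, 9): rejected.
print(fits(batch.shape, (None, 9)))  # False

# Fix: a single output unit and the labels reshaped into a column.
column = np.reshape(batch, (-1, 1))
print(column.shape)                   # (10, 1)
print(fits(column.shape, (None, 1)))  # True
```

Both halves of the fix are needed: reshaping alone still fails against (?, 9), and a one-unit output layer alone still receives labels of shape (10,).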


Solution

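Thanks for the answer! I got the problem solved. Two changes were needed: reshape the label list into a column of shape (N, 1), and add a final fully_connected layer with a single unit, so that the network output, and with it the targets placeholder, has shape (?, 1) instead of (?, 9). Here is the code: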

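    import numpy as np
    import tflearn
    from tflearn.data_utils import load_csv
    from numpy import genfromtxt

    # load_csv returns the label column as a flat list, i.e. shape (N,)
    data, target = load_csv('boston_train.csv', has_header=True)
    # Reshape the labels into a column so they match the network output (?, 1)
    target = np.reshape(target, (-1, 1))

    net = tflearn.input_data(shape=[None, 9])  # 9 input features
    net = tflearn.fully_connected(net, 9)
    net = tflearn.fully_connected(net, 1)      # single output: the predicted price
    net = tflearn.regression(net, optimizer='sgd', loss='mean_square', learning_rate=0.01)
    model = tflearn.DNN(net)
    model.fit(data, target, n_epoch=10, batch_size=10, show_metric=True)

    # Load the rows to predict on, skipping the header line
    test_data = genfromtxt('boston_predict.csv', delimiter=',', skip_header=1)
    # Ensure a 2-D (rows, 9) array even if the file holds a single row
    test_data = np.reshape(test_data, (-1, 9))

    pred = model.predict(test_data)
    print(pred)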
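The np.reshape(test_data, (-1, 9)) line matters mainly when the prediction file holds a single data row: genfromtxt then returns a 1-D array of shape (9,) rather than (1, 9). The sketch below shows this with an in-memory CSV whose header and values are hypothetical. It then runs a NumPy forward pass through two linear layers (9 to 9 to 1), using random weights as a stand-in for trained ones (an assumption for illustration; TFLearn's fully_connected uses a linear activation by default), to show that predictions come out with shape (rows, 1).

```python
import io

import numpy as np

# Hypothetical prediction file: a header plus a single row of 9 features.
csv_text = "f1,f2,f3,f4,f5,f6,f7,f8,f9\n0.1,18.0,2.3,0,0.5,6.5,65.2,4.1,1\n"

raw = np.genfromtxt(io.StringIO(csv_text), delimiter=',', skip_header=1)
print(raw.shape)  # (9,): a single row collapses to a 1-D array

test_data = np.reshape(raw, (-1, 9))
print(test_data.shape)  # (1, 9): one sample with 9 features

# Random weights standing in for trained ones (illustrative assumption).
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(9, 9)), np.zeros(9)
W2, b2 = rng.normal(size=(9, 1)), np.zeros(1)

# Two fully connected layers with linear activation, as in the fixed network.
hidden = test_data @ W1 + b1  # (1, 9)
pred = hidden @ W2 + b2       # (1, 1)
print(pred.shape)  # one prediction per row, same layout as the (N, 1) targets
```

With two or more data rows, genfromtxt already returns a 2-D array and the reshape leaves it unchanged.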