Tags: python, numpy, matplotlib, plot, logistic-regression

Plot an array of strings with NumPy and Matplotlib


I am trying to plot a logistic regression model using Matplotlib and NumPy.
Here is my code:

X = [[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40]]

Y = ['male', 'male', 'female', 'female', 'male']

I have tried the following, but it does not work as expected:

Y_label = []
for x in range(0,len(Y)):
    if Y[x] == 'male': 
        Y_label.append('1')
    else : Y_label.append('0')


fit = np.polyfit(X,Y_label,1)
fit_fn = np.poly1d(fit) 
# fit_fn is now a function which takes in x and returns an estimate for y

plt.plot(X,Y_label, 'yo', X, fit_fn(X), '--k')
plt.xlim(0, 5)
plt.ylim(0, 12)
plt.show()

When I run this code, I get the following error:

Traceback (most recent call last):
  File "/home/logistic_regression.py", line 27, in <module>
    fit = np.polyfit(X,Y_label,1)
  File "/usr/lib/python2.7/dist-packages/numpy/lib/polynomial.py", line 543, in polyfit
    y = NX.asarray(y) + 0.0
TypeError: unsupported operand type(s) for +: 'numpy.ndarray' and 'float'

How can I solve this?
Thanks in advance.


Solution

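  • I have replaced your polyfit call with LogisticRegression from sklearn, because np.polyfit fits a polynomial in a single variable and rejects a two-dimensional X. Since X has three features, we need a 3D plot. Points are colored green where the prediction is correct and red where it is wrong. Note also that the TypeError in your traceback comes from Y_label holding the strings '1' and '0'; the code below uses the integers 1 and 0 instead.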

    I would also recommend using sklearn's LabelEncoder to build Y_label instead of encoding the labels by hand.

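    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
    from sklearn.linear_model import LogisticRegression
    
    X = np.array([[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40]])
    Y = ['male', 'male', 'female', 'female', 'male']
    
    # Encode the labels as integers (not the strings '1'/'0'): male -> 1, female -> 0
    Y_label = np.array([1 if y == 'male' else 0 for y in Y])
    
    # Fit a logistic regression on all three features
    reg = LogisticRegression().fit(X, Y_label)
    
    # Boolean mask: True where the model's prediction matches the true label
    crt_pred = Y_label == reg.predict(X)
    
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    # Correct predictions in green, wrong predictions in red
    ax.scatter3D(X[crt_pred, 0], X[crt_pred, 1], X[crt_pred, 2], s=50, c='g')
    ax.scatter3D(X[~crt_pred, 0], X[~crt_pred, 1], X[~crt_pred, 2], s=50, c='r')
    
    plt.show()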
    


    For more details, see this link.
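np.polyfit fails on the question's data for two separate reasons: the labels are strings, which NumPy cannot add to a float (the `y = NX.asarray(y) + 0.0` line in the traceback), and `X` is two-dimensional, while polyfit only accepts a 1-D `x`. The sketch below reproduces both problems in isolation. It assumes a reasonably recent NumPy; `polyfit_error` is a hypothetical helper introduced only to capture the `TypeError`, and the exact message text may differ between NumPy versions.

```python
import numpy as np

X = np.array([[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40]])
Y_str = ['1', '1', '0', '0', '1']  # labels as strings, as in the question
Y_int = [1, 1, 0, 0, 1]            # the same labels as integers


def polyfit_error(x, y):
    """Hypothetical helper: return the TypeError message np.polyfit raises, or None."""
    try:
        np.polyfit(x, y, 1)
    except TypeError as exc:
        return str(exc)
    return None


# Problem 1: string labels cannot be converted to floats inside polyfit
print(polyfit_error(X[:, 0], Y_str))

# Problem 2: even with integer labels, polyfit only accepts a 1-D x
print(polyfit_error(X, Y_int))

# A single feature with integer labels runs (a straight-line fit of height vs. label)
print(polyfit_error(X[:, 0], Y_int))
```

The last call shows that polyfit can be made to run on one feature, but a straight line through 0/1 labels is not a logistic regression, which is why the answer switches to LogisticRegression.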
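The LabelEncoder suggestion could look like the following minimal sketch, assuming scikit-learn is installed. LabelEncoder sorts the class names, so 'female' becomes 0 and 'male' becomes 1, which happens to match the manual encoding in the question.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

Y = ['male', 'male', 'female', 'female', 'male']

encoder = LabelEncoder()
# fit_transform learns the sorted set of classes and maps each label to its index
Y_label = encoder.fit_transform(Y)

print(encoder.classes_)  # ['female' 'male']
print(Y_label)           # [1 1 0 0 1]

# inverse_transform maps integer predictions back to the original strings
print(encoder.inverse_transform(np.array([0, 1])))  # ['female' 'male']
```

Because the mapping follows alphabetical order rather than a choice you make, keep the fitted encoder around and use inverse_transform to turn the model's 0/1 predictions back into 'female'/'male'.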