Tags: python, machine-learning, math, computational-geometry

How to plot a separator line between two data classes?


I have a simple exercise that I am not sure how to do. I have the following data sets:

male100

    Year    Time
0   1896    12.00
1   1900    11.00
2   1904    11.00
3   1906    11.20
4   1908    10.80
5   1912    10.80
6   1920    10.80
7   1924    10.60
8   1928    10.80
9   1932    10.30
10  1936    10.30
11  1948    10.30
12  1952    10.40
13  1956    10.50
14  1960    10.20
15  1964    10.00
16  1968    9.95
17  1972    10.14
18  1976    10.06
19  1980    10.25
20  1984    9.99
21  1988    9.92
22  1992    9.96
23  1996    9.84
24  2000    9.87
25  2004    9.85
26  2008    9.69

and the second one:

female100

    Year    Time
0   1928    12.20
1   1932    11.90
2   1936    11.50
3   1948    11.90
4   1952    11.50
5   1956    11.50
6   1960    11.00
7   1964    11.40
8   1968    11.00
9   1972    11.07
10  1976    11.08
11  1980    11.06
12  1984    10.97
13  1988    10.54
14  1992    10.82
15  1996    10.94
16  2000    11.12
17  2004    10.93
18  2008    10.78

I have the following code:

import matplotlib.pyplot as plt

# male100 and female100 are DataFrames with columns Year and Time
y = -0.014*male100['Year'] + 38

plt.plot(male100['Year'], y, 'b-')  # candidate separator line (blue)
ax = plt.gca()  # gca stands for 'get current axes'
ax = male100.plot(x=0, y=1, kind='scatter', color='g', label="Mens 100m", ax=ax)
female100.plot(x=0, y=1, kind='scatter', color='r', label="Womens 100m", ax=ax)
plt.show()

Which produces this result:

[Image: scatter plot of the men's (green) and women's (red) times with the candidate line]

I need to plot a line that goes between the two groups: all of the green points should lie below it and all of the red points above it. How do I do that?

I've tried adjusting the slope and intercept of y, but to no avail. I also tried fitting a linear regression to male100, female100, and the two merged (concatenated by rows), but couldn't get a line that separates the groups.

Any help would be appreciated!
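The requirement can be written as a simple check: every men's point must lie strictly below the line and every women's point strictly above it. A minimal sketch in plain Python, using the data from the two tables as (Year, Time) tuples; the `violations` helper is a hypothetical name written for illustration:

```python
# Data from the two tables, as (Year, Time in seconds) tuples
male = [
    (1896, 12.00), (1900, 11.00), (1904, 11.00), (1906, 11.20), (1908, 10.80),
    (1912, 10.80), (1920, 10.80), (1924, 10.60), (1928, 10.80), (1932, 10.30),
    (1936, 10.30), (1948, 10.30), (1952, 10.40), (1956, 10.50), (1960, 10.20),
    (1964, 10.00), (1968, 9.95), (1972, 10.14), (1976, 10.06), (1980, 10.25),
    (1984, 9.99), (1988, 9.92), (1992, 9.96), (1996, 9.84), (2000, 9.87),
    (2004, 9.85), (2008, 9.69),
]
female = [
    (1928, 12.20), (1932, 11.90), (1936, 11.50), (1948, 11.90), (1952, 11.50),
    (1956, 11.50), (1960, 11.00), (1964, 11.40), (1968, 11.00), (1972, 11.07),
    (1976, 11.08), (1980, 11.06), (1984, 10.97), (1988, 10.54), (1992, 10.82),
    (1996, 10.94), (2000, 11.12), (2004, 10.93), (2008, 10.78),
]


def violations(slope, intercept, below, above):
    """Return the points that are on the wrong side of y = slope*x + intercept.

    Points in `below` must lie strictly under the line and points in
    `above` strictly over it; everything else is reported.
    """
    wrong = [(x, y) for x, y in below if y >= slope * x + intercept]
    wrong += [(x, y) for x, y in above if y <= slope * x + intercept]
    return wrong


# The candidate line from the question: y = -0.014*Year + 38
print(violations(-0.014, 38, male, female))  # [(1896, 12.0)]
```

For y = -0.014*Year + 38, only the 1896 men's time is on the wrong side (12.00 s, versus about 11.46 s on the line), so any working line has to clear that early point while still staying under all of the women's times.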


Solution

  • One solution is a linear support vector machine (SVM). A linear SVM finds the two parallel margin lines that separate the classes with the widest possible gap; the points lying on those margins are the support vectors, and the line halfway between the two margins is the separator you want. Note that this only works when the two sets of points are linearly separable.
    You can use the following code to see the result:

    Data Entry

    male = [
    (1896  ,  12.00),
    (1900  ,  11.00),
    (1904  ,  11.00),
    (1906  ,  11.20),
    (1908  ,  10.80),
    (1912  ,  10.80),
    (1920  ,  10.80),
    (1924  ,  10.60),
    (1928  ,  10.80),
    (1932  ,  10.30),
    (1936  ,  10.30),
    (1948  ,  10.30),
    (1952  ,  10.40),
    (1956  ,  10.50),
    (1960  ,  10.20),
    (1964  ,  10.00),
    (1968  ,  9.95),
    (1972  ,  10.14),
    (1976  ,  10.06),
    (1980  ,  10.25),
    (1984  ,  9.99),
    (1988  ,  9.92),
    (1992  ,  9.96),
    (1996  ,  9.84),
    (2000  ,  9.87),
    (2004  ,  9.85),
    (2008  ,  9.69)
    ]
    female = [
    (1928,    12.20),
    (1932,    11.90),
    (1936,    11.50),
    (1948,    11.90),
    (1952,    11.50),
    (1956,    11.50),
    (1960,    11.00),
    (1964,    11.40),
    (1968,    11.00),
    (1972,    11.07),
    (1976,    11.08),
    (1980,    11.06),
    (1984,    10.97),
    (1988,    10.54),
    (1992,    10.82),
    (1996,    10.94),
    (2000,    11.12),
    (2004,    10.93),
    (2008,    10.78)
    ]
    

    Main Code

    Note that the value of C matters here: with C=1 you don't get the desired separating line, which is why the code below uses C=1000.
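For reference, this is the standard soft-margin problem that a linear SVC solves, with the two classes encoded internally as y_i = -1 and y_i = +1 and one slack variable xi_i per point:

```latex
\min_{w,\,b,\,\xi}\ \tfrac{1}{2}\lVert w\rVert^2 + C\sum_{i=1}^{n}\xi_i
\quad\text{subject to}\quad
y_i\,(w^\top x_i + b) \ge 1 - \xi_i,\qquad \xi_i \ge 0
```

A point with xi_i > 0 sits inside the margin (and on the wrong side of the line if xi_i > 1), and C is the price paid for each unit of slack. A large C makes violations expensive, so on separable data the solution approaches the hard-margin separator; a small C lets the optimizer accept misplaced points in exchange for a wider margin.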

    from sklearn import svm
    import numpy as np
    import matplotlib.pyplot as plt

    X = np.array(male + female)                        # (Year, Time) pairs
    Y = np.array([0] * len(male) + [1] * len(female))  # 0 = male, 1 = female

    # fit the model; a large C approximates a hard margin
    clf = svm.SVC(kernel='linear', C=1000)  # C is important here
    clf.fit(X, Y)

    # get the separating hyperplane w[0]*x + w[1]*y + b = 0, solved for y
    w = clf.coef_[0]
    a = -w[0] / w[1]
    xx = np.linspace(1890, 2010)
    yy = a * xx - clf.intercept_[0] / w[1]

    plt.figure(figsize=(8, 4))
    plt.plot(xx, yy, "k-")  # this is the separator line

    plt.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired,
                edgecolors="k")
    plt.xlim((1890, 2010))
    plt.ylim((9, 13))
    plt.show()
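To make the "two margins" idea visible, here is a sketch that fits the same model on the same data and also draws the margin lines. The fitted boundary is w[0]*Year + w[1]*Time + b = 0 with w = clf.coef_[0] and b = clf.intercept_[0]; solving for Time gives slope -w[0]/w[1] and intercept -b/w[1]. The margins are where the same expression equals -1 and +1, so the separator is exactly their average. The names margin_neg and margin_pos and the circled support vectors are additions for illustration, not part of the original answer:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm

# Same data as in the answer: (Year, Time in seconds)
male = [
    (1896, 12.00), (1900, 11.00), (1904, 11.00), (1906, 11.20), (1908, 10.80),
    (1912, 10.80), (1920, 10.80), (1924, 10.60), (1928, 10.80), (1932, 10.30),
    (1936, 10.30), (1948, 10.30), (1952, 10.40), (1956, 10.50), (1960, 10.20),
    (1964, 10.00), (1968, 9.95), (1972, 10.14), (1976, 10.06), (1980, 10.25),
    (1984, 9.99), (1988, 9.92), (1992, 9.96), (1996, 9.84), (2000, 9.87),
    (2004, 9.85), (2008, 9.69),
]
female = [
    (1928, 12.20), (1932, 11.90), (1936, 11.50), (1948, 11.90), (1952, 11.50),
    (1956, 11.50), (1960, 11.00), (1964, 11.40), (1968, 11.00), (1972, 11.07),
    (1976, 11.08), (1980, 11.06), (1984, 10.97), (1988, 10.54), (1992, 10.82),
    (1996, 10.94), (2000, 11.12), (2004, 10.93), (2008, 10.78),
]

X = np.array(male + female)
Y = np.array([0] * len(male) + [1] * len(female))

clf = svm.SVC(kernel="linear", C=1000)  # large C: close to a hard margin
clf.fit(X, Y)

w = clf.coef_[0]
b = clf.intercept_[0]
slope = -w[0] / w[1]

xx = np.linspace(1890, 2010, 200)
yy = slope * xx - b / w[1]                  # w.x + b = 0  (separator)
margin_neg = slope * xx - (b + 1) / w[1]    # w.x + b = -1
margin_pos = slope * xx - (b - 1) / w[1]    # w.x + b = +1

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(xx, yy, "k-", label="separator")
ax.plot(xx, margin_neg, "k--", linewidth=0.8, label="margins")
ax.plot(xx, margin_pos, "k--", linewidth=0.8)
ax.scatter(X[:, 0], X[:, 1], c=Y, cmap="Paired", edgecolors="k", zorder=10)

# Circle the support vectors; they lie on the margin lines
# (or inside them, if the soft margin allowed any violations)
sv = clf.support_vectors_
ax.scatter(sv[:, 0], sv[:, 1], s=150, facecolors="none", edgecolors="k")

ax.set_xlim(1890, 2010)
ax.set_ylim(9, 13)
ax.set_xlabel("Year")
ax.set_ylabel("Time (s)")
ax.legend()
plt.show()
```

Drawing the margins this way also gives a quick visual check on C: if some points fall between the dashed lines, the fit used the soft margin rather than a hard one.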