Tags: python, svm, gradient-descent

Compare performance of the SVM with and without stochastic gradient descent


I'd like to compare the performance of the SVM classifier with and without stochastic gradient descent. In sklearn I've only found the SGDClassifier (which I can put into a pipeline). Doesn't sklearn provide an implementation of a non-stochastic (batch) gradient-descent classifier? Do I have to implement both classifiers myself in order to conduct this analysis?


Solution

  • When SVMs and SGD can't be combined

    SVMs are often used in combination with the kernel trick, which enables classification of non-linearly separable data. This answer explains why you wouldn't use stochastic gradient descent to solve a kernelised SVM: https://stats.stackexchange.com/questions/215524/is-gradient-descent-possible-for-kernelized-svms-if-so-why-do-people-use-quadr
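    As an aside, if you do need a non-linear decision boundary, a common workaround is to approximate the kernel with an explicit feature map and run SGD on the mapped features. A minimal sketch using sklearn's Nystroem transformer (illustrative only; the dataset and parameters here are arbitrary and not part of the comparison below):

    from sklearn.datasets import make_circles
    from sklearn.kernel_approximation import Nystroem
    from sklearn.linear_model import SGDClassifier
    from sklearn.pipeline import make_pipeline
    
    # Toy data that is not linearly separable
    X_c, y_c = make_circles(n_samples=1000, noise=0.1, factor=0.3, random_state=0)
    
    # Approximate the RBF kernel with an explicit feature map, then fit a
    # linear SVM objective (hinge loss) with SGD on the mapped features
    clf = make_pipeline(
        Nystroem(kernel='rbf', gamma=1.0, n_components=100, random_state=0),
        SGDClassifier(loss='hinge', max_iter=1000, tol=1e-3, random_state=0)
    )
    clf.fit(X_c, y_c)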

    Linear SVMs

    If we stick to linear SVMs, we can run an experiment using sklearn, which provides wrappers over libsvm (SVC) and liblinear (LinearSVC) and also offers the SGDClassifier. I recommend reading the documentation of libsvm and liblinear to understand what is happening under the hood. The sketch below shows how the three estimators relate.
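
    SGDClassifier with hinge loss optimises the same linear soft-margin SVM objective as LinearSVC; its alpha parameter corresponds roughly to 1/(C · n_samples). A minimal sketch of the correspondence (the value of n here is just the dataset size used below):

    from sklearn.svm import SVC, LinearSVC
    from sklearn.linear_model import SGDClassifier
    
    n, C = 50000, 1.0  # training-set size and SVM regularisation parameter
    
    svc = SVC(kernel='linear', C=C)          # libsvm: exact solver, dual problem
    linear_svc = LinearSVC(C=C, dual=False)  # liblinear: exact solver, primal problem
    # SGD on the hinge loss targets the same objective; alpha ~ 1 / (C * n)
    sgd = SGDClassifier(loss='hinge', penalty='l2', alpha=1.0 / (C * n))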

    Comparison on example dataset

    Below is a comparison of computational performance and accuracy on a randomly generated dataset (which may not be representative of your problem). You should adapt the experiment to fit your requirements.

    import time
    import numpy as np
    import matplotlib.pyplot as plt
    
    from sklearn.svm import SVC, LinearSVC
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import train_test_split
    
    # Randomly generated dataset
    # Linear function + noise
    np.random.seed(0)
    X = np.random.normal(size=(50000, 10))
    coefs = np.random.normal(size=10)
    epsilon = np.random.normal(size=50000)
    y = (X @ coefs + epsilon) > 0
    
    # Classifiers to compare. 'max_n' caps the training-set size for
    # solvers whose training time grows faster than quadratically with n
    algos = {
        'LibSVM': {
            'model': SVC(),
            'max_n': 4000,
            'time': [],
            'error': []
        },
        'LibLinear': {
            'model': LinearSVC(dual=False),
            'max_n': np.inf,
            'time': [],
            'error': []
        },
        'SGD': {
            'model': SGDClassifier(max_iter=1000, tol=1e-3),
            'max_n': np.inf,
            'time': [],
            'error': []
        }
    }
    
    # Training-set sizes to evaluate
    splits = list(range(100, 1000, 100)) + \
             list(range(1500, 5000, 500)) + \
             list(range(6000, 50000, 1000))
    for i in splits:
        # Train on i samples, test on the remainder
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            train_size=i,
                                                            random_state=0)
        for k, v in algos.items():
            if i < v['max_n']:
                model = v['model']
                t0 = time.time()
                model.fit(X_train, y_train)
                t1 = time.time()
                v['time'].append(t1 - t0)
                preds = model.predict(X_test)
                e = (preds != y_test).sum() / len(y_test)
                v['error'].append(e)
    

    Plotting the training times, we see that the traditional libsvm solver becomes infeasible for large n, while the liblinear and SGD implementations scale well computationally.

    plt.figure()
    for k, v in algos.items():
        plt.plot(splits[:len(v['time'])], v['time'], label='{} time'.format(k))
    plt.legend()
    plt.semilogx()
    plt.title('Time comparison')
    plt.xlabel('Number of training examples')
    plt.ylabel('Training time (s)')
    plt.show()
    

    [Figure: 'Time comparison' — training time vs. number of training examples (log-scale x-axis)]

    Plotting the error, we see that SGD does worse than LibSVM at the same training-set size, but with a large training set this becomes a minor point. The liblinear algorithm performs best on this dataset:

    plt.figure()
    for k, v in algos.items():
        plt.plot(splits[:len(v['error'])], v['error'], label='{} error'.format(k))
    plt.legend()
    plt.semilogx()
    plt.title('Error comparison')
    plt.xlabel('Number of training examples')
    plt.ylabel('Error')
    plt.show()
    

    [Figure: 'Error comparison' — test error vs. number of training examples (log-scale x-axis)]
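
    One last note, since you mention putting the SGDClassifier into a pipeline: SGD is sensitive to feature scaling, so it is usually worth standardising inside the pipeline. A minimal sketch, reusing the X_train/X_test split from the experiment above (on this synthetic standard-normal data the scaler changes little, but it matters on real features):

    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.linear_model import SGDClassifier
    
    # Fitting the scaler inside the pipeline avoids leaking test-set statistics
    pipe = make_pipeline(StandardScaler(),
                         SGDClassifier(loss='hinge', max_iter=1000, tol=1e-3))
    pipe.fit(X_train, y_train)
    print(pipe.score(X_test, y_test))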