Python sqlite3 'executemany' not successfully updating my database


I'm trying to extract a column from my database, apply a transformation, and create a new column with the results.

I ultimately want to save the local variable 'new_proba' (which has a length of 740, the same as the number of rows in my table) as a new column called 'predict_proba_tplus1'. From reading this thread, I've learned that the UPDATE statement expects its parameters as a tuple.

I created an 'IDs' list which matches the indexID column in the database and then zipped it with 'new_proba' to create a list of tuples, which looks like '(0.56298709097028454, 0), (0.54392926856501334, 1),' etc.

The function below doesn't throw any error, but it only creates the column 'predict_proba_tplus1' and doesn't fill in any values, which I'd expect c.executemany() to do. I'm left with a new column filled with NULL values. It makes me think there's something wrong with the WHERE clause, i.e. it's not matching the indexID column to the IDs variable for some reason, despite the numbers being the same.

Any insight would be hugely appreciated.

import sqlite3
import numpy as np

# vect (vectorizer) and clf (classifier) are assumed to be defined elsewhere
def update_class_proba(path):
    conn = sqlite3.connect(path)
    c = conn.cursor()
    cursor = c.execute('SELECT text, indexID FROM reuters_test_X')
    all_rows = cursor.fetchall()
    X = vect.transform(x[0] for x in all_rows)
    new_proba = list(clf.predict_proba(X)[:,1])
    IDs = list(np.arange(0, 740, 1))
    # (probability, ID) pairs, in the same order as the two placeholders below
    new_proba_tuple = list(zip(new_proba,IDs))
    c.execute('ALTER TABLE reuters_test_X ADD COLUMN predict_proba_tplus1 REAL')
    c.executemany('UPDATE reuters_test_X SET predict_proba_tplus1=? WHERE indexID=?', new_proba_tuple)
    conn.commit()
    conn.close()

Solution

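  • Those values aren't plain Python float and int; they're NumPy scalars (numpy.float64 for the probabilities and a NumPy integer type for the IDs from np.arange), which sqlite3 doesn't reliably bind as ordinary SQLite numbers. If the IDs aren't bound as integers, the WHERE clause matches no rows, so the UPDATE runs without error but changes nothing.

    Convert your values to plain float and int like this:

    new_proba = [float(z) for z in clf.predict_proba(X)[:,1]]
    IDs = [int(zz) for zz in np.arange(0, 740, 1)]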
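
    To confirm that the update really matched rows, it can help to check the cursor's rowcount after executemany. The sketch below is only an illustration under assumptions: it uses an in-memory database with a hypothetical three-row stand-in for reuters_test_X, hard-coded probabilities instead of clf.predict_proba, and a helper name (update_probabilities) that isn't in the original code.

```python
import sqlite3


def update_probabilities(conn, probabilities, ids):
    """Write one probability per indexID and return the number of rows changed."""
    # Coerce every value to a built-in float/int before binding it.
    params = [(float(p), int(i)) for p, i in zip(probabilities, ids)]
    cur = conn.cursor()
    cur.executemany(
        "UPDATE reuters_test_X SET predict_proba_tplus1=? WHERE indexID=?",
        params,
    )
    conn.commit()
    # For executemany, rowcount is summed over all parameter sets,
    # so 0 here would mean the WHERE clause never matched a row.
    return cur.rowcount


# Hypothetical stand-in for the real 740-row table.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE reuters_test_X "
    "(indexID INTEGER, text TEXT, predict_proba_tplus1 REAL)"
)
conn.executemany(
    "INSERT INTO reuters_test_X (indexID, text) VALUES (?, ?)",
    [(0, "a"), (1, "b"), (2, "c")],
)

updated = update_probabilities(conn, [0.56, 0.54, 0.61], [0, 1, 2])
rows = conn.execute(
    "SELECT indexID, predict_proba_tplus1 FROM reuters_test_X ORDER BY indexID"
).fetchall()
print(updated, rows)
```

    If the returned count is 0 (or smaller than the number of IDs), the parameters are probably not comparing equal to the stored indexID values, which is worth checking before looking anywhere else.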