I'm trying to extract a column from my database, apply a transformation, and create a new column with the results.
I ultimately want to save the local variable 'new_proba' (which has 740 entries, the same as the number of rows in my table) as a new column called 'predict_proba_tplus1'. From reading this thread, I've learned that the UPDATE statement expects its parameters as a tuple.
I created an 'IDs' list that matches the indexID column in the table and zipped it with 'new_proba' to create a list of tuples, which looks like '(0.56298709097028454, 0), (0.54392926856501334, 1), ...'.
The function below doesn't throw any error, but it only creates the column 'predict_proba_tplus1' and doesn't fill in any values, which I'd expect c.executemany() to do. I'm left with a new column full of NULL values. That makes me think something is wrong with the WHERE clause, i.e. it's not matching the indexID column to the IDs variable for some reason, even though the numbers are the same.
Any insight would be hugely appreciated.
```python
import sqlite3

import numpy as np

# vect and clf: the fitted vectorizer and classifier, defined elsewhere
def update_class_proba(path):
    conn = sqlite3.connect(path)
    c = conn.cursor()
    cursor = c.execute('SELECT text, indexID FROM reuters_test_X')
    all_rows = cursor.fetchall()
    X = vect.transform(x[0] for x in all_rows)
    new_proba = list(clf.predict_proba(X)[:,1])
    IDs = list(np.arange(0, 740, 1))
    new_proba_tuple = list(zip(new_proba,IDs))
    c.execute('ALTER TABLE reuters_test_X ADD COLUMN predict_proba_tplus1 REAL')
    c.executemany('UPDATE reuters_test_X SET predict_proba_tplus1=? WHERE indexID=?', new_proba_tuple)
    conn.commit()
    conn.close()
```
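Those values aren't plain Python `float` and `int`; they're NumPy scalars. `numpy.float64` happens to subclass `float`, so the probabilities are bound correctly, but the IDs from `np.arange` are `numpy.int64`, which does not subclass `int`. `sqlite3` doesn't bind them as INTEGERs (they end up as BLOBs), so `WHERE indexID=?` never matches a row. That's why there's no error and the column stays NULL.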
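A minimal way to see this for yourself, assuming NumPy is installed. The in-memory table `t` is a hypothetical stand-in for `reuters_test_X`; the same UPDATE is run once with NumPy scalars and once with plain Python scalars.

```python
import sqlite3

import numpy as np

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (indexID INTEGER, proba REAL)")
conn.executemany("INSERT INTO t (indexID) VALUES (?)", [(i,) for i in range(3)])

np_id = np.arange(0, 3, 1)[1]  # a NumPy integer scalar, like the entries of IDs
np_proba = np.float64(0.56)    # like the entries of new_proba

# numpy.float64 subclasses Python float; NumPy integer scalars do not subclass int
print(isinstance(np_proba, float), isinstance(np_id, int))

# Bound as-is, the NumPy integer is not stored as an SQLite INTEGER
print(conn.execute("SELECT typeof(?)", (np_id,)).fetchone()[0])

# ...so the WHERE clause matches no rows, and no error is raised
numpy_hits = conn.execute(
    "UPDATE t SET proba=? WHERE indexID=?", (np_proba, np_id)).rowcount

# Converting to plain Python scalars makes the WHERE clause match
plain_hits = conn.execute(
    "UPDATE t SET proba=? WHERE indexID=?", (float(np_proba), int(np_id))).rowcount
print(numpy_hits, plain_hits)
```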
Convert your values to plain `float` and `int` like this:
```python
new_proba = list(float(z) for z in clf.predict_proba(X)[:,1])
IDs = list(int(zz) for zz in np.arange(0, 740, 1))
```
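Going a step further, here is a sketch of the whole function that also takes the IDs from the `indexID` values the SELECT already returns, instead of assuming they run 0 to 739 in fetch order. The `score` parameter is a hypothetical stand-in for `clf.predict_proba(vect.transform(texts))[:, 1]`, and passing in a connection rather than a path is only there so it can be tried on an in-memory table with made-up rows.

```python
import sqlite3
from typing import Callable, Sequence


def update_class_proba(conn: sqlite3.Connection,
                       score: Callable[[Sequence[str]], Sequence[float]]) -> int:
    """Add predict_proba_tplus1 to reuters_test_X and fill it by indexID.

    `score` is a hypothetical stand-in for
    lambda texts: clf.predict_proba(vect.transform(texts))[:, 1]
    Returns the number of rows the UPDATE touched.
    """
    rows = conn.execute("SELECT text, indexID FROM reuters_test_X").fetchall()
    texts = [text for text, _ in rows]
    # Use the real indexID values rather than assuming 0..N-1 in fetch order
    ids = [int(index_id) for _, index_id in rows]
    # float() turns numpy.float64 (or any real number) into a plain Python float
    probas = [float(p) for p in score(texts)]

    conn.execute("ALTER TABLE reuters_test_X ADD COLUMN predict_proba_tplus1 REAL")
    # Tuple order must follow the placeholders: SET value first, WHERE id second
    cur = conn.executemany(
        "UPDATE reuters_test_X SET predict_proba_tplus1=? WHERE indexID=?",
        zip(probas, ids),
    )
    conn.commit()
    return cur.rowcount  # for executemany, summed over all parameter tuples


# Toy usage on an in-memory table (made-up rows, hypothetical scorer)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE reuters_test_X (text TEXT, indexID INTEGER)")
conn.executemany("INSERT INTO reuters_test_X VALUES (?, ?)",
                 [("oil prices rise", 0), ("grain exports fall", 1)])
updated = update_class_proba(conn, lambda texts: [len(t) / 100 for t in texts])
print(updated)
```

If you'd rather skip the comprehensions, `clf.predict_proba(X)[:, 1].tolist()` and `np.arange(740).tolist()` also return plain Python `float`s and `int`s.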