Tags: python, pandas, numpy, scipy, numpy-einsum

Numpy einsum behaving badly. What to look out for?


What has typically gone wrong when numpy einsum throws this error?

Traceback (most recent call last):
  File "rmse_iter.py", line 30, in <module>
    rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)
TypeError: invalid data type for einsum

The numpy array diffs is produced by subtracting two blocks of columns taken from a pandas dataframe, and as far as I can tell it contains only numbers of type np.float32: no strings, nan, +/-inf, or any other such funny business. So what should I be looking for? Under what circumstances does einsum typically fail this way?

This is how I'm loading and processing the dataframe:

df = pd.read_pickle(fn)
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df.dropna(inplace=True)
a = df.values
diffs = a[:,2:27] - a[:,27:]
rmse_out = np.sqrt(np.einsum('ij,ij->i',diffs,diffs)/3.0)

Please excuse the open-endedness of the question. Thanks to Divakar for introducing me to einsum wizardry.

Edit:

Here is my attempt at including actual data in tabular form:

        rna     cnv     1_a     2_a     3_a     4_a     5_a     6_a     7_a     8_a     9_a     10_a    11_a    12_a    13_a    14_a    15_a    16_a    17_a    18_a    19_a    20_a    21_a    22_a    23_a    24_a    25_a    1_b     2_b     3_b     4_b     5_b     6_b     7_b     8_b     9_b     10_b    11_b    12_b    13_b    14_b    15_b    16_b    17_b    18_b    19_b    20_b    21_b    22_b    23_b    24_b    25_b
5641095 AP1G1   CCL8    3.588543653488159       10.119391441345215      32.92853546142578       6.307891368865967       -32.6164665222168       -34.94172286987305      -4.913632869720459
      -0.1798282265663147     -0.5144565105438232     12.70481014251709       -37.560791015625        39.83904266357422       32.92853546142578       -0.9303828477859497     -32.6164665222168       -8.661237716674805      31.074113845825195      -0.1798282265663147     -0.5144565105438232     -4.566867828369141      -2.5914463996887207     10.119391441345215      -12.007019996643066     6.307891368865967       -21.65423583984375      -8.217794418334961      2.9316258430480957      27.942243576049805      11.107816696166992      -7.4105706214904785     -1.1366562843322754     17.06450653076172       -7.277851581573486      7.186253547668457       -37.862789154052734     2.21020770072937        -14.829334259033203     5.599830627441406       27.80745506286621       -5.512645244598389      -1.1366562843322754     17.06450653076172       -20.73367691040039      -8.826581001281738      -10.555018424987793     -8.217794418334961
      -6.360044956207275      -1.9607794284820557     6.345422267913818       13.062686920166016
5641105 AP1G1   CCND2   2.3494300842285156      10.119391441345215      27.10674476623535       3.8083128929138184      -70.73456573486328      -39.372581481933594     -8.208958625793457
      -0.1798282265663147     1.082576036453247       12.70481014251709       -63.872154235839844     39.83904266357422       27.10674476623535       0.01608092524111271     -70.73456573486328      -8.661237716674805      43.937278747558594      -0.1798282265663147     1.082576036453247       -3.672504425048828      -3.3072872161865234     10.119391441345215      -8.377813339233398      3.8083128929138184      -26.24537467956543      -10.137262344360352     2.9316258430480957      15.313714027404785      7.0047502517700195      -12.949808120727539     -2.3481321334838867     12.740055084228516      -3.4322025775909424     8.920576095581055       -62.727718353271484     0.2877853512763977      -19.20431137084961      11.22409725189209       27.80745506286621       -1.9983365535736084     -2.3481321334838867     12.740055084228516      -33.702674865722656     -8.826581001281738      -18.610857009887695     -10.137262344360352
     -6.804142475128174      -0.43901631236076355    18.789241790771484      15.554900169372559
5641113 AP1G1   CCNH    4.718714237213135       1230632818573312.0      27.10674476623535       4.7800703048706055      -70.73456573486328      -47.087345123291016     -6.196646690368652
      -1.9009416103363037     474487485104128.0       25.461158752441406      -90.02267456054688      39.83904266357422       27.10674476623535       0.7240228652954102      -70.73456573486328      -14.690686225891113     53.84657669067383       -1.9009416103363037     474487485104128.0       -4.566867828369141      -555133515595776.0      1230632818573312.0      -328591573254144.0      4.7800703048706055      -1088045541490688.0     -10.137262344360352     2.9316258430480957      19.262754440307617      11.107816696166992      -12.949808120727539     -2.3481321334838867     17.06450653076172       -7.277851581573486      17.50507164001465       -45.33726501464844      0.9687032103538513      -33.4061164855957       8564995327524864.0      38.147640228271484      -3.5528361797332764     -2.3481321334838867     17.06450653076172       -33.702674865722656     -8.826581001281738      -27.176956176757812     -10.137262344360352
     -6.431360721588135      -0.43901631236076355    3244183414374400.0      15.554900169372559

Solution

  • It turns out that a = df.values returns an array of dtype object when the df contains strings, which my original df apparently did (the rna and cnv columns in the table above). The strings are not coerced to np.nan, and slicing out the numeric columns afterwards does not change the dtype. This is why my attempt to typecast and then slice the array created from df.values failed: the items just stayed as "object".

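    To fix this, I simply selected the numeric columns from the original df and sent them to a matrix (as_matrix() has since been removed from pandas; .to_numpy() or .values is the current equivalent):

    a = df[df.columns[2:]].as_matrix()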
    

    Then I made sure to update the indexes in the diff operation, as the column indexes moved back by two:

    diffs = a[:,:25] - a[:,25:]
    

    The takeaway: when einsum is behaving badly, check the dtype of your array. It should be float32 or float64; strings mixed into the source data will leave it as "object" instead.
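
For anyone who wants to check this on their own data, here is a minimal sketch of the diagnosis and the fix. The small dataframe is a made-up stand-in for the real one (two string columns followed by float32 columns), and select_dtypes plus to_numpy is just one way of picking out the numeric columns on a current pandas, not the exact call used above. The division by the number of columns is likewise an assumption about what the RMSE should average over.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the real data: two string columns, then float32 columns.
df = pd.DataFrame({
    "rna": ["AP1G1", "AP1G1"],
    "cnv": ["CCL8", "CCND2"],
    "1_a": np.array([3.5, 2.25], dtype=np.float32),
    "2_a": np.array([10.0, 10.0], dtype=np.float32),
    "1_b": np.array([0.5, 0.25], dtype=np.float32),
    "2_b": np.array([6.0, 7.0], dtype=np.float32),
})

# Diagnosis: a frame that mixes strings and floats comes back as dtype object,
# and slicing out the numeric columns does not change that.
a_obj = df.values
diffs_obj = a_obj[:, 2:4] - a_obj[:, 4:]
print(a_obj.dtype, diffs_obj.dtype)

# Fix: select the numeric columns first, then convert to a float array.
numeric = df.select_dtypes(include=[np.number])
a = numeric.to_numpy(dtype=np.float64)
diffs = a[:, :2] - a[:, 2:]

# Row-wise sum of squared differences, then RMSE over the columns.
sq_sums = np.einsum('ij,ij->i', diffs, diffs)
rmse_out = np.sqrt(sq_sums / diffs.shape[1])
print(a.dtype, rmse_out)
```

Printing diffs.dtype just before the einsum call is usually enough to spot the problem: anything other than a float dtype there is worth chasing back to the dataframe.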