Tags: python, pandas, numpy, mlogit

Calculating multinomial logit model prediction probabilities


Please try to give a parameterized solution (in my real data there are more than three alternatives).

I have a dict with beta values:

{'B_X1': 2.0, 'B_X2': -3.0}

And this data frame:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789
   6.75    4.69    9.59    5.52    9.69    7.40
   7.46    4.94    3.01    1.78    1.38    4.68
   2.05    7.30    4.08    7.02    8.24    8.49
   5.60    7.88    8.11    5.98    4.60    1.39
   1.80    8.28    9.16    7.34    7.69    6.16
   3.73    6.93    8.93    2.58    3.48    6.04
   8.06    8.88    7.06    6.76    4.68    7.82
   5.00    7.29    5.86    3.92    5.67    4.10
   2.49    2.55    4.66    7.15    6.26    7.87
   1.50    3.35    5.70    9.86    4.83    1.17
   8.19    7.72    9.56    6.61    4.15    3.64
   2.43    9.54    9.15    4.41    9.18    7.85
   2.71    3.24    4.56    6.22    7.89    9.93
   5.96    4.34    5.26    8.63    9.81    9.40

123, 456, and 789 are the alternatives.
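
For reference, a runnable version of this setup (only the first three rows are included, as a sketch; the name betas is just for illustration):

    import pandas as pd

    betas = {'B_X1': 2.0, 'B_X2': -3.0}

    df = pd.DataFrame({
        'X1_123': [6.75, 7.46, 2.05],
        'X1_456': [4.69, 4.94, 7.30],
        'X1_789': [9.59, 3.01, 4.08],
        'X2_123': [5.52, 1.78, 7.02],
        'X2_456': [9.69, 1.38, 8.24],
        'X2_789': [7.40, 4.68, 8.49],
    })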

I want to calculate the prediction probabilities with the multinomial logit formula:

P_j = exp(V_j) / sum_k exp(V_k),  where V_j = B_X1 * X1_j + B_X2 * X2_j

and j, k run over the alternatives (123, 456, 789).
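
For example, for the first data row the utilities and probabilities work out as follows (a quick numeric check with the betas above):

    import numpy as np

    # V_j = B_X1 * X1_j + B_X2 * X2_j for j in {123, 456, 789}
    V = np.array([2.0 * 6.75 - 3.0 * 5.52,   # V_123 = -3.06
                  2.0 * 4.69 - 3.0 * 9.69,   # V_456 = -19.69
                  2.0 * 9.59 - 3.0 * 7.40])  # V_789 = -3.02
    P = np.exp(V) / np.exp(V).sum()
    P.round(3)   # approx. [0.490, 0.000, 0.510]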

Expected result:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
   7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
   2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
   2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024

The probabilities should sum to 1 in every row.


Expected result with a constant for each alternative, i.e. with betas {'B_X1': 2.0, 'B_X2': -3.0, 'B_123': 0.1, 'B_456': 0.2, 'B_789': 0.3} (the constant B_j is added to the utility V_j of alternative j):

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.440  0.000  0.560
   7.46    4.94    3.01    1.78    1.38    4.68  0.977  0.023  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.021  0.952  0.027
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.180  0.102  0.717
   2.49    2.55    4.66    7.15    6.26    7.87  0.034  0.604  0.363
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.034  0.034  0.932
   2.71    3.24    4.56    6.22    7.89    9.93  0.978  0.021  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.970  0.001  0.029

Solution

  • If I understand correctly:

    Turn columns into a MultiIndex

    df = df.set_axis(df.columns.str.split('_', expand=True), axis=1)
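
    The columns are now a two-level (variable, alternative) MultiIndex, roughly:

    df.columns
    # MultiIndex([('X1', '123'), ('X1', '456'), ('X1', '789'),
    #             ('X2', '123'), ('X2', '456'), ('X2', '789')], )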
    

    And define your B such that the keys match the prefixes in df

    B = {'X1': 2.0, 'X2': -3.0}
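
    If you start from the question's dict with the 'B_' prefixes, a small sketch to get this form (the name betas is assumed for that dict; str.removeprefix needs Python 3.9+):

    betas = {'B_X1': 2.0, 'B_X2': -3.0}
    B = {k.removeprefix('B_'): v for k, v in betas.items()}   # {'X1': 2.0, 'X2': -3.0}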
    

    Then

    import numpy as np
    import pandas as pd

    def f(b, x):
        # exp of the linear utility exp(sum_i b_i * x_i), computed per (row, alternative)
        return np.exp((b * x).sum(1))

    parts = f(B, df.stack()).unstack()

    preds = parts.div(parts.sum(1), axis=0)
    
    df.join(pd.concat({'P': preds}, axis=1).round(3)).pipe(
        lambda d: d.set_axis(list(map('_'.join, d.columns)), axis=1)
    )
    
        X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
    0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
    1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
    2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
    3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
    4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
    5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
    6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
    7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
    8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
    9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
    10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
    11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
    12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
    13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024
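
    As a quick optional check (not part of the original answer), each row of preds should sum to 1 before rounding:

    np.allclose(preds.sum(axis=1), 1.0)   # expected: True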
    

    Wrapped in one pretty function

    def f(df, b):
        # split 'X1_123' -> ('X1', '123'): columns become a (variable, alternative) MultiIndex
        d = df.set_axis(df.columns.str.split('_', expand=True), axis=1)
        # exp of the utility per (row, alternative)
        parts = np.exp(d.stack().mul(b).sum(1).unstack())
        # normalise across alternatives, label the probability columns with a 'P' prefix
        preds = pd.concat({'P': parts.div(parts.sum(1), axis=0)}, axis=1).round(3)
        d = d.join(preds)
        d.columns = list(map('_'.join, d.columns))
        return d
    
    f(df, B)
    
        X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
    0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
    1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
    2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
    3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
    4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
    5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
    6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
    7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
    8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
    9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
    10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
    11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
    12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
    13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024
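
    The question also asks for a version with a constant for each alternative. A sketch extending the same idea (the function name f_with_constants and the dict c are mine, not from the question; the constants are assumed to be passed as a dict keyed by alternative):

    import numpy as np
    import pandas as pd

    def f_with_constants(df, b, c):
        # b: coefficient per variable,  e.g. {'X1': 2.0, 'X2': -3.0}
        # c: constant per alternative,  e.g. {'123': 0.1, '456': 0.2, '789': 0.3}
        d = df.set_axis(df.columns.str.split('_', expand=True), axis=1)
        # utility V_j = sum_i b_i * Xi_j + c_j, then exponentiate
        v = d.stack().mul(b).sum(1).unstack().add(pd.Series(c))
        parts = np.exp(v)
        preds = pd.concat({'P': parts.div(parts.sum(1), axis=0)}, axis=1).round(3)
        d = d.join(preds)
        d.columns = list(map('_'.join, d.columns))
        return d

    f_with_constants(df, {'X1': 2.0, 'X2': -3.0}, {'123': 0.1, '456': 0.2, '789': 0.3})

    With the constants from the question this should reproduce the second expected table (e.g. 0.440 / 0.000 / 0.560 in the first row).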