Tags: python, list, loops, zip

Looping through multiple lists to calculate probability scores


I have 4 lists similar to the ones below. They contain the label predictions from 4 different classifiers: each list holds one row per instance, with one entry per label. Because this is a multilabel classification problem, each instance can receive up to 5 labels (1 means the label is present, 0 means it is not).

a = [[0, 1, 1, 1, 1],  [1, 0, 0, 1, 1],  [0, 1, 0, 0, 1]]
b = [[0, 0, 0, 1, 1],  [0, 0, 0, 1, 0],  [1, 0, 1, 0, 1]]
c = [[1, 0, 1, 1, 1],  [0, 0, 0, 1, 1],  [0, 0, 1, 0, 1]]
d = [[1, 0, 1, 1, 1],  [0, 0, 0, 1, 0],  [0, 1, 1, 0, 1]]

For each instance, I want to calculate the probability of each label being assigned, based on the predictions of the 4 classifiers. To achieve this, I tried the code below:

probs = []
for p in zip(a, b, c, d):
    for s in p:
        x = (sum(s)) / 4
        probs.append(x)

print(probs)

With this code I get the output below, which is not what I need:

[1.0, 0.5, 1.0, 1.0, 0.75, 0.25, 0.5, 0.25, 0.5, 0.75, 0.5, 0.75]
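The flat list comes from how `zip` groups the data: `zip(a, b, c, d)` yields one tuple per instance containing each classifier's whole row, so `sum(s)` adds up one classifier's five labels rather than the four votes for a single label. The illustrative snippet below, using the question's data, prints that grouping for the first instance and then shows how transposing it with `zip(*...)` groups the votes per label instead.

```python
a = [[0, 1, 1, 1, 1], [1, 0, 0, 1, 1], [0, 1, 0, 0, 1]]
b = [[0, 0, 0, 1, 1], [0, 0, 0, 1, 0], [1, 0, 1, 0, 1]]
c = [[1, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 0, 1, 0, 1]]
d = [[1, 0, 1, 1, 1], [0, 0, 0, 1, 0], [0, 1, 1, 0, 1]]

# First item of zip(a, b, c, d): instance 0, one full row per classifier
first = next(zip(a, b, c, d))
print(first)
# ([0, 1, 1, 1, 1], [0, 0, 0, 1, 1], [1, 0, 1, 1, 1], [1, 0, 1, 1, 1])

# Transposing the rows gives the four classifiers' votes for each label
print(list(zip(*first)))
# [(0, 0, 1, 1), (1, 0, 0, 0), (1, 0, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1)]
```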

My goal output is something like:

probs = [[0.5, 0.25, 0.75, 1.0, 1.0],
         [0.25, 1.0, 1.0, 1.0, 0.5],
         [0.25, 0.5, 0.75, 1.0, 1.0]]

Note that if all the classifiers assign the same value to a label for a given instance, the probability should be 1.0 regardless of whether that value is 0 or 1 (present or not present). I know this is not complicated to do, but I just can't figure out how. Thanks!


Solution

  • Try:

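    a = [[0, 1, 1, 1, 1], [1, 0, 0, 1, 1], [0, 1, 0, 0, 1]]
    b = [[0, 0, 0, 1, 1], [0, 0, 0, 1, 0], [1, 0, 1, 0, 1]]
    c = [[1, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 0, 1, 0, 1]]
    d = [[1, 0, 1, 1, 1], [0, 0, 0, 1, 0], [0, 1, 1, 0, 1]]


    out = []
    # zip(a, b, c, d) yields the 4 classifiers' rows for one instance
    for p in zip(a, b, c, d):
        out.append([])
        # zip(*p) transposes them: r holds the 4 votes for one label
        for r in zip(*p):
            v = sum(r) / len(r)  # fraction of classifiers that predicted 1
            # every classifier predicted 0, so they agree: report 1.0
            out[-1].append(1.0 if v == 0.0 else v)

    print(out)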
    

    Prints:

    [
        [0.5, 0.25, 0.75, 1.0, 1.0],
        [0.25, 1.0, 1.0, 1.0, 0.5],
        [0.25, 0.5, 0.75, 1.0, 1.0],
    ]
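The same logic can also be written as a function that accepts any number of classifiers instead of hard-coding four. This is only a sketch: the name `label_probabilities` is hypothetical, not part of the answer above, and it keeps the question's rule that a label every classifier marked 0 is reported as 1.0.

```python
def label_probabilities(*classifiers):
    """Per-instance, per-label share of classifiers that predicted 1.

    Hypothetical helper. Follows the question's rule: if every
    classifier predicted 0 for a label, report 1.0 (unanimous agreement).
    """
    n = len(classifiers)
    return [
        # votes: one label's predictions, one entry per classifier
        [1.0 if sum(votes) == 0 else sum(votes) / n for votes in zip(*rows)]
        # rows: one instance's prediction rows, one per classifier
        for rows in zip(*classifiers)
    ]


a = [[0, 1, 1, 1, 1], [1, 0, 0, 1, 1], [0, 1, 0, 0, 1]]
b = [[0, 0, 0, 1, 1], [0, 0, 0, 1, 0], [1, 0, 1, 0, 1]]
c = [[1, 0, 1, 1, 1], [0, 0, 0, 1, 1], [0, 0, 1, 0, 1]]
d = [[1, 0, 1, 1, 1], [0, 0, 0, 1, 0], [0, 1, 1, 0, 1]]

print(label_probabilities(a, b, c, d))
# [[0.5, 0.25, 0.75, 1.0, 1.0], [0.25, 1.0, 1.0, 1.0, 0.5], [0.25, 0.5, 0.75, 1.0, 1.0]]
```

Calling it with three lists, e.g. `label_probabilities(a, b, c)`, divides by 3 instead of 4. Testing `sum(votes) == 0` on the integer vote count avoids comparing floats, although for these inputs the answer's `v == 0.0` check behaves the same way.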