Tags: python, list, sorting, machine-learning, sampling

How do you sort lists of tuples based on the count of a specific value?


I am working on a NER problem (hence the BIO tagging) with a very small dataset, and I am manually splitting it into train, validation, and test data. To make the first of the two splits, I need to sort the lists of tuples into two bins based on the count of 'B' in data.

I am shuffling data, so the output varies, but it typically yields what I show below. data can be split so that bin_1 holds a total of exactly 10 instances of 'B', so the way 'B' is distributed through the lists of tuples is not what prevents the split.

How do I get the split that I am after? For this example, and the desired split, I want the total count of 'B' in bin_1 to be 10, but it's always over.

Assistance would be much appreciated.

Data:

data = [[('a', 'B'), ('b', 'I'), ('c', 'O'), ('d', 'B'), ('e', 'I'), ('f', 'O')],
        [('g', 'O'), ('h', 'O')],
        [('i', 'B'), ('j', 'I'), ('k', 'O')],
        [('l', 'B'), ('m', ''), ('n', 'B'), ('o', 'O')],
        [('p', 'O'), ('q', 'O'), ('r', 'O')],
        [('s', 'B'), ('t', 'O')],
        [('u', 'O'), ('v', 'B'), ('w', 'I'), ('x', 'O'), ('y', 'O')],
        [('z', 'B')],
        [('a', 'B'), ('b', 'I'), ('c', 'O')],
        [('d', 'O')],
        [('e', 'O'), ('f', 'O')],
        [('g', 'O'), ('h', 'B')],
        [('i', 'B'), ('j', 'I')],
        [('k', 'O')],
        [('l', 'O'), ('m', 'O'), ('n', 'O'), ('o', 'O')],
        [('p', 'O'), ('q', 'O'), ('r', 'O'), ('s', 'B'), ('t', 'O')],
        [('u', 'O'), ('v', 'B'), ('w', 'I'), ('x', 'O'), ('y', 'O'), ('z', 'B')]]

Current code:

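import random

split = 0.7
d = []        # per-list label counts, parallel to data
total_B = 0
bin_1 = []
bin_2 = []
counter = 0   # running count of 'B' placed in bin_1

random.shuffle(data)

# Count how often each label occurs in every list of tuples
for f in data:
    cnt = {}
    for _, label in f:
        if label in cnt:
            cnt[label] += 1
        else:
            cnt[label] = 1
    d.append(cnt)

# Total number of 'B' labels across all lists
for f in d:
    total_B += f.get('B', 0)

# Fill bin_1 while the running count has not passed the target, then fill bin_2
# (lists without any 'B' are skipped and land in neither bin)
for f, g in zip(d, data):
    if f.get('B') is not None:
        if counter <= round(total_B * split):
            counter += f.get('B')
            bin_1.append(g)
        else:
            bin_2.append(g)

print("Total count of 'B' in 'bin_1' should be:", round(total_B * split))
print("Total count of 'B' in 'bin_1' is:", sum(1 for sublist in bin_1 for tuple_item in sublist if tuple_item[1] == 'B'))
print("Total count of 'B' in 'bin_2' is:", sum(1 for sublist in bin_2 for tuple_item in sublist if tuple_item[1] == 'B'))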

Current output:

Total count of 'B' in 'bin_1' should be: 10
Total count of 'B' in 'bin_1' is: 11
Total count of 'B' in 'bin_2' is: 3
bin_1, bin_2
>>>
[[('a', 'B'), ('b', 'I'), ('c', 'O')],
  [('g', 'O'), ('h', 'B')],
  [('i', 'B'), ('j', 'I'), ('k', 'O')],
  [('u', 'O'), ('v', 'B'), ('w', 'I'), ('x', 'O'), ('y', 'O'), ('z', 'B')],
  [('s', 'B'), ('t', 'O')],
  [('l', 'B'), ('m', ''), ('n', 'B'), ('o', 'O')],
  [('a', 'B'), ('b', 'I'), ('c', 'O'), ('d', 'B'), ('e', 'I'), ('f', 'O')],
  [('i', 'B'), ('j', 'I')]],
 [[('u', 'O'), ('v', 'B'), ('w', 'I'), ('x', 'O'), ('y', 'O')],
  [('z', 'B')],
  [('p', 'O'), ('q', 'O'), ('r', 'O'), ('s', 'B'), ('t', 'O')]]

Desired output:

Total count of 'B' in 'bin_1' should be: 10
Total count of 'B' in 'bin_1' is: 10
Total count of 'B' in 'bin_2' is: 4

Solution

  • One possible solution is to compute how the 'B' tags are distributed across the indexes of your data. Assuming data has already been shuffled, you can use:

    def get_distribution(data):
        return {i: len([x for x in t if (x[1] == 'B')]) for i, t in enumerate(data) }
    

    For data in the order listed in the question, you get:

    distribution = get_distribution(data)
    print(distribution)
    #=> {0: 2, 1: 0, 2: 1, 3: 2, 4: 0, 5: 1, 6: 1, 7: 1, 8: 1, 9: 0, 10: 0, 11: 1, 12: 1, 13: 0, 14: 0, 15: 1, 16: 2}
    

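    Now iterate over distribution and fill your bins. A more complex algorithm is possible; this is the simplest one. Note that it adds a list whenever the running ratio is still below ratio, so it can overshoot the target when the last list added carries more than one 'B':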

    bin_1 = []
    bin_2 = []
    ratio = 0.7
    count = 0
    total = sum(distribution.values())
    
    for k, v in distribution.items():
        if count/total < ratio:
            bin_1.append(data[k])
            count += v
        else:
            bin_2.append(data[k])
    

    To check the result:

    print(bin_1)
    print(bin_2)
    distr_bin_1 = get_distribution(bin_1)
    distr_bin_2 = get_distribution(bin_2)
    print(distr_bin_1)
    print(distr_bin_2)
    count_bin_1 = sum(distr_bin_1.values())
    count_bin_2 = sum(distr_bin_2.values())
    print(count_bin_1/(count_bin_1 + count_bin_2)) # actual ratio
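
If the target count has to be hit exactly whenever the data allows it, a greedy fill is not enough, because it never revisits an earlier choice. One way to approach this is to treat the problem as a subset sum over the per-list 'B' counts. The following is a minimal sketch, not a definitive implementation: the names count_b and split_by_label_count and the seed parameter are hypothetical, and sending roughly the same ratio of the lists without any 'B' to bin_1 is an assumption, since the question does not say where those lists should go.

```python
import random


def count_b(seq, label="B"):
    """Number of tokens in one list of (token, tag) tuples tagged `label`."""
    return sum(1 for _, tag in seq if tag == label)


def split_by_label_count(data, ratio=0.7, label="B", seed=None):
    """Split `data` into two bins so that bin_1 holds round(total * ratio)
    `label` tags, or the closest reachable count if that is impossible."""
    order = list(range(len(data)))
    random.Random(seed).shuffle(order)  # shuffle indexes, leave data untouched
    counts = {i: count_b(data[i], label) for i in order}
    total = sum(counts.values())
    target = round(total * ratio)

    # reachable[s] = indexes of lists whose label counts add up to s
    reachable = {0: []}
    for i in order:
        c = counts[i]
        if c == 0:
            continue
        # Iterate over a snapshot so each list is used at most once per sum
        for s, chosen in list(reachable.items()):
            if s + c not in reachable:
                reachable[s + c] = chosen + [i]

    # Pick the reachable sum closest to the target (ties go to the lower sum)
    best = min(reachable, key=lambda s: (abs(s - target), s))
    in_bin_1 = set(reachable[best])

    # Assumption: lists without the label are divided by the same ratio
    zero = [i for i in order if counts[i] == 0]
    in_bin_1.update(zero[:round(len(zero) * ratio)])

    bin_1 = [data[i] for i in order if i in in_bin_1]
    bin_2 = [data[i] for i in order if i not in in_bin_1]
    return bin_1, bin_2
```

Calling split_by_label_count(data, ratio=0.7) on the data from the question should put 10 'B' tags in bin_1 and 4 in bin_2 for any shuffle order, because the search considers every reachable sum rather than stopping at the first list that crosses the threshold. The dictionary of reachable sums stays small here; for large datasets with high label counts it would grow with the total count.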