Tags: python, parsing, csv, text-parsing, python-itertools

Disturbing odd behavior/bug in Python itertools groupby?


I am using itertools.groupby to parse a short tab-delimited text file. The file has several columns, and all I want to do is group all the entries that have a particular value x in a particular column. The code below does this for a column called name2, looking for the value in variable x, using csv.DictReader and itertools.groupby. In the table, there are 8 rows that match this criterion, so 8 entries should be returned. Instead, groupby returns two sets of entries, one with a single entry and another with 7, which seems like the wrong behavior. When I do the matching manually on the same data (second half of the code below), I get the right result:

import itertools, operator, csv
col_name = "name2"
x = "ENSMUSG00000002459"
print "looking for entries with value %s in column %s" %(x, col_name)
print "groupby gets it wrong: "
data = csv.DictReader(open(f), delimiter="\t", fieldnames=fieldnames)
for name, entries in itertools.groupby(data, key=operator.itemgetter(col_name)):
    if name == "ENSMUSG00000002459":
        wrong_result = [e for e in entries]
        print "wrong result has %d entries" %(len(wrong_result))
print "manually grouping entries is correct: "
data = csv.DictReader(open(f), delimiter="\t", fieldnames=fieldnames)
correct_result = []
for row in data:
    if row[col_name] == "ENSMUSG00000002459":
        correct_result.append(row)
print "correct result has %d entries" %(len(correct_result))

The output I get is:

looking for entries with value ENSMUSG00000002459 in column name2
groupby gets it wrong: 
wrong result has 7 entries
wrong result has 1 entries
manually grouping entries is correct: 
correct result has 8 entries

What is going on here? If groupby is really grouping, it seems like I should get only one set of entries per x, but instead it returns two. I cannot figure this out. EDIT: Ah, got it: the input should be sorted.


Solution

  • itertools.groupby only groups consecutive items that share the same key, so you're going to want to change your code to force the data to be in key order:

    data = csv.DictReader(open(f), delimiter="\t", fieldnames=fieldnames)
    sorted_data = sorted(data, key=operator.itemgetter(col_name))
    for name, entries in itertools.groupby(sorted_data, key=operator.itemgetter(col_name)):
        pass # whatever
    

    The main use of groupby, though, is when the datasets are large and the data is already in key order. When you would have to sort anyway, using a defaultdict is more efficient, since it groups the rows in a single pass without sorting:

    from collections import defaultdict
    name_entries = defaultdict(list)
    for row in data:
        name_entries[row[col_name]].append(row)
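
To make the difference concrete, here is a minimal Python 3 sketch that runs all three approaches on the same rows. The original file is not shown, so SAMPLE, the "id" column, and the row values r1..r4 are made-up stand-ins (only the name2 column and the ENSMUSG00000002459 value come from the question); treat it as an illustration under those assumptions, not as the asker's actual data.

```python
import csv
import io
import itertools
import operator
from collections import defaultdict

# Hypothetical stand-in for the tab-delimited file (the real data is not shown).
# The matching rows are deliberately NOT adjacent.
SAMPLE = (
    "r1\tENSMUSG00000002459\n"
    "r2\tOTHER\n"
    "r3\tENSMUSG00000002459\n"
    "r4\tENSMUSG00000002459\n"
)
fieldnames = ["id", "name2"]  # "id" is an assumed column name
col_name = "name2"
key = operator.itemgetter(col_name)

# Read everything into a list so the rows can be iterated more than once.
rows = list(csv.DictReader(io.StringIO(SAMPLE), delimiter="\t", fieldnames=fieldnames))

# 1) Unsorted input: groupby starts a new group every time the key changes,
#    so the same key can show up in several separate runs.
unsorted_runs = [
    (name, len(list(entries)))
    for name, entries in itertools.groupby(rows, key=key)
]

# 2) Sorted input: equal keys are adjacent, so each key yields exactly one group.
sorted_groups = {
    name: list(entries)
    for name, entries in itertools.groupby(sorted(rows, key=key), key=key)
}

# 3) defaultdict: one pass over the rows, no sorting required.
name_entries = defaultdict(list)
for row in rows:
    name_entries[row[col_name]].append(row)

print(unsorted_runs)
print({name: len(entries) for name, entries in sorted_groups.items()})
print({name: len(entries) for name, entries in name_entries.items()})
```

With this sample, the unsorted groupby reports ENSMUSG00000002459 twice (a run of 1 and a run of 2), which mirrors the 7 + 1 split in the question, while the sorted groupby and the defaultdict both collect all 3 matching rows under one key.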