I have a CSV file with about 50 million rows, and I'm trying to transform the data and write it to a new CSV file. Here's the code:
import csv
import itertools


def main():
    with open("input.csv", "r", newline="") as csvfile:
        rows = csv.DictReader(csvfile)
        # sort all rows by name so they can be grouped
        sorted_rows = sorted(rows, key=lambda row: row["name"])
        grouping = itertools.groupby(sorted_rows, lambda row: row["name"])
        with open("output.csv", "w", newline="") as final_csvfile:
            fieldnames = ["name", "number"]
            writer = csv.DictWriter(final_csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for group, items in grouping:
                # sum the "number" column within each name group
                total = sum(int(item["number"]) for item in items)
                writer.writerow({"name": group, "number": str(total)})


if __name__ == "__main__":
    main()
This works well on a modest number of rows, but when I run it on the actual CSV with 50 million rows, it becomes very slow and the program eventually gets killed.

The main problem is this line:

sorted_rows = sorted(rows, key=lambda row: row["name"])

because it loads all 50 million rows into memory (a list) so they can be sorted. As I understand it, the first thing sorted() does is materialize whatever iterable it is given into a list. So how do I go about this, please? Any pointers?
@python_user the problem with that approach is that it keeps adding entries to the dictionary, and before you know it the dictionary itself becomes really large and causes memory trouble.
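For reference, that suggestion boils down to accumulating running totals in a dictionary keyed by name, roughly like this (my own sketch of the idea, not @python_user's actual code; totals_in_memory is a hypothetical name):

import csv
from collections import defaultdict


def totals_in_memory(path):
    # Running totals keyed by name; the dict grows with the number of distinct names.
    totals = defaultdict(int)
    with open(path, newline="") as csvfile:
        for row in csv.DictReader(csvfile):
            totals[row["name"]] += int(row["number"])
    return totals

This stays small only if there are few distinct names; with many unique names, the dictionary itself becomes the memory problem.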
@Bharel mentioned external sorting in the comments; I looked into it and found a way.
I discovered that the UNIX sort command performs an external merge sort, so it can handle files far larger than memory. I wrote a small script to sort the huge CSV first and then fed the sorted file into the Python code from the question. That way, nothing too big ever has to be held in memory.
Here's the code:
sort.sh

echo "sorting CSV"
# keep the header as the first line; sort only the data rows by the name field
head -n 1 input.csv > sorted.csv
tail -n +2 input.csv | sort -t, -k1,1 >> sorted.csv
mv sorted.csv input.csv
echo "Done!"
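If you'd rather not maintain a separate shell script, the same pipeline can be driven from Python with subprocess. A minimal sketch, assuming GNU tail and sort are on PATH (external_sort is a hypothetical helper, not part of the script above):

import subprocess


def external_sort(path, sorted_path):
    # Copy the header line, then append the data rows sorted by the first field.
    with open(path) as src, open(sorted_path, "w") as out:
        out.write(src.readline())
    subprocess.run(
        f"tail -n +2 '{path}' | sort -t, -k1,1 >> '{sorted_path}'",
        shell=True,
        check=True,
    )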
After the sorting step, the sorted CSV is passed into the program below. (One caveat: a plain line-based sort assumes the name field contains no quoted, embedded commas or newlines.)
import csv
from itertools import groupby


def main():
    with open("input.csv", "r", newline="") as csvfile:
        rows = csv.DictReader(csvfile)
        # the rows are already sorted by name, so groupby can stream them directly
        grouping = groupby(rows, lambda row: row["name"])
        with open("output.csv", "w", newline="") as final_csvfile:
            fieldnames = ["name", "number"]
            writer = csv.DictWriter(final_csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for group, items in grouping:
                total = sum(int(item["number"]) for item in items)
                writer.writerow({"name": group, "number": str(total)})


if __name__ == "__main__":
    main()
Note that the sorted() line from the question is gone: the rows now stream straight from disk through groupby. I think this is a much more memory-friendly solution.
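If shelling out to sort is ever not an option, the same external merge sort can be done in pure Python: sort fixed-size chunks in memory, spill each sorted run to a temporary file, then stream-merge the runs with heapq.merge. This is my own sketch, not part of the answer above; external_sort_csv and the chunk size are arbitrary names/choices:

import csv
import heapq
import itertools
import tempfile
from operator import itemgetter


def external_sort_csv(path, sorted_path, chunk_size=1_000_000):
    # Phase 1: split the file into sorted runs on disk.
    runs = []
    with open(path, newline="") as src:
        reader = csv.reader(src)
        header = next(reader)
        while True:
            chunk = list(itertools.islice(reader, chunk_size))
            if not chunk:
                break
            chunk.sort(key=itemgetter(0))  # sort each run by the name column
            run = tempfile.TemporaryFile("w+", newline="")
            csv.writer(run).writerows(chunk)
            run.seek(0)
            runs.append(run)
    # Phase 2: merge the sorted runs in a single streaming pass.
    with open(sorted_path, "w", newline="") as out:
        writer = csv.writer(out)
        writer.writerow(header)
        writer.writerows(heapq.merge(*(csv.reader(r) for r in runs), key=itemgetter(0)))
    for run in runs:
        run.close()

Memory use is then bounded by chunk_size rows at a time, no matter how large the input is.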