
Why do I get ConnectionResetError when reading and writing from and to s3 using smart_open?


The following code reads from and writes back to S3 on the fly, following the discussion here:

from smart_open import open
import os

bucket_dir = "s3://my-bucket/annotations/"

with open(os.path.join(bucket_dir, "in.tsv.gz"), "rb") as fin:
    with open(
        os.path.join(bucket_dir, "out.tsv.gz"), "wb"
    ) as fout:
        for line in fin:
            # strip each tab-separated field, then write the cleaned line back out
            l = [i.strip() for i in line.decode().split("\t")]
            string = "\t".join(l) + "\n"
            fout.write(string.encode())

The issue is that after a few thousand lines have been processed (a few minutes), I get a "connection reset by peer" error:

    raise ProtocolError("Connection broken: %r" % e, e)
urllib3.exceptions.ProtocolError: ("Connection broken: ConnectionResetError(104, 'Connection reset by peer')", ConnectionResetError(104, 'Connection reset by peer'))

What can I do? I tried calling fout.flush() after every fout.write(string.encode()), but it doesn't help. Is there a better approach to processing a .tsv file with about 200 million lines?


Solution

  • I implemented a producer-consumer approach on top of smart_open. This mitigates the connection-reset error, but doesn't resolve it completely in some cases.

    import os
    import threading
    import time
    from queue import Queue

    from smart_open import open
    from tqdm import tqdm


    class Producer:
        def __init__(self, queue, bucket_dir, input_file):
            self.queue = queue
            self.bucket_dir = bucket_dir
            self.input_file = input_file

        def run(self):
            with open(os.path.join(self.bucket_dir, self.input_file), "rb") as fin:
                for line in tqdm(fin):
                    while self.queue.full():
                        time.sleep(0.05)
                    # decode here so the consumer only ever deals with strings
                    self.queue.put(line.decode())
            self.queue.put("DONE")
    
    
    class Consumer:
        def __init__(self, queue, bucket_dir, output_file):
            self.queue = queue
            self.bucket_dir = bucket_dir
            self.output_file = output_file
    
        def run(self):
            to_write = ""
            count = 0
            with open(os.path.join(self.bucket_dir, self.output_file), "wb") as fout:
                while True:
                    while self.queue.empty():
                        time.sleep(0.05)
                    item = self.queue.get()
                    if item == "DONE":
                        # write out whatever is still buffered before exiting
                        fout.write(to_write.encode())
                        fout.flush()
                        self.queue.task_done()
                        return

                    count += 1
                    to_write += item
                    self.queue.task_done()
                    if count % 256 == 0:  # batch write
                        fout.write(to_write.encode())
                        fout.flush()
                        to_write = ""  # reset the buffer after each batch
    
    
    def main(args):
        q = Queue(1024)
    
        producer = Producer(q, args.bucket_dir, args.input_file)
        producer_thread = threading.Thread(target=producer.run)
    
        consumer = Consumer(q, args.bucket_dir, args.output_file)
        consumer_thread = threading.Thread(target=consumer.run)
    
        producer_thread.start()
        consumer_thread.start()
    
        producer_thread.join()
        consumer_thread.join()
        q.join()
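
    For reference, here is a minimal sketch of how main() can be invoked from the command line; the flag names below are only illustrative and simply map onto the args.bucket_dir, args.input_file and args.output_file attributes used above:

    import argparse

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--bucket_dir", default="s3://my-bucket/annotations/")
        parser.add_argument("--input_file", default="in.tsv.gz")
        parser.add_argument("--output_file", default="out.tsv.gz")
        main(parser.parse_args())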