Tags: amazon-web-services, apache-spark, amazon-s3, boto3, aws-cli

High-volume read of many keys in an S3 bucket


I am trying to read and process the keys in the AWS S3 bucket/prefix s3://sorel-20m/09-DEC-2020/binaries/. This is a public bucket used for cybersecurity machine learning.

There are over 13 million binary files in this prefix.

Here are example keys in the data:

09-DEC-2020/binaries/0000029bfead495a003e43a7ab8406c6209ffb7d5e59dd212607aa358bfd66ea
09-DEC-2020/binaries/000003b99c3d4b9860ad0b0ca43450603e5322f2cca3c9b3d543a2d6440305a0
09-DEC-2020/binaries/00000533148c26bcc09ab44b1acafe32dde93773d4a7e3dbd06c8232db5e437f
...
09-DEC-2020/binaries/fffffac77abc5f22baefd850a753b0e32a8c106f983f84f6b83fb20df465c7ab
09-DEC-2020/binaries/fffffd86f00a5b4547d3b99963cae39781fa015b3f869b3e232858dd6011d062
09-DEC-2020/binaries/fffffee23b47f84cfdf25c43af7707c8ffa94a974e5af9659e3ed67e2e30b80b

Just listing the files with an AWS CLI command such as aws s3 ls takes hours.

I tried filtering using the CLI exclude and include parameters:

aws s3 cp s3://sorel-20m/09-DEC-2020/binaries/ . --recursive --dryrun --exclude '*' --include '0000029*'

This returned data quickly, but then did not complete. It seems that the CLI reads the keys in alphabetical order, because when I search for keys at the end of an alphabetical sort (beginning with 'fff'), the command below takes a long time to return data:

aws s3 cp s3://sorel-20m/09-DEC-2020/binaries/ . --recursive --dryrun --exclude '*' --include 'fff*'
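A rough back-of-the-envelope sketch of why the two `--include` runs behave so differently: ListObjectsV2 returns keys in ascending order, at most 1,000 per call, and each call needs the continuation token from the previous one. If the include pattern is only applied to keys after they have been listed (an assumption consistent with the behaviour above), the CLI has to page through every key that sorts before the first match. The sketch assumes the keys are uniformly distributed lowercase hex strings (they look like SHA-256 digests) and uses 13,000,000 as a round stand-in for the key count; `total_pages` and `pages_before` are hypothetical helper names.

```python
import math

PAGE_SIZE = 1000  # ListObjectsV2 returns at most 1,000 keys per call


def total_pages(total_keys: int, page_size: int = PAGE_SIZE) -> int:
    """Sequential ListObjectsV2 calls needed to list every key once."""
    return math.ceil(total_keys / page_size)


def pages_before(hex_prefix: str, total_keys: int, page_size: int = PAGE_SIZE) -> int:
    """Approximate number of full pages fetched before the page holding the
    first key that starts with `hex_prefix`, assuming keys are uniformly
    distributed lowercase hex strings (an assumption, not a guarantee)."""
    # Share of the hex key space that sorts before the prefix, e.g. 'fff' -> 4095/4096
    fraction_before = int(hex_prefix, 16) / 16 ** len(hex_prefix)
    return math.floor(fraction_before * total_keys / page_size)


if __name__ == "__main__":
    n = 13_000_000  # "over 13 million" keys, rounded for illustration
    print(total_pages(n))              # 13000
    print(pages_before("0000029", n))  # 0: matches appear on the first page
    print(pages_before("fff", n))      # 12996: almost the whole listing comes first
```

With roughly 13,000 strictly sequential requests, per-request round-trip latency is likely the limiting factor, which is why splitting the key space into prefixes that can be listed independently helps.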

I also tried the following AWS Glue (PySpark) script, which timed out after 1 hour:

from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.job import Job

sc = SparkContext.getOrCreate()
glueContext = GlueContext(sc)
spark = glueContext.spark_session
job = Job(glueContext)

df = spark.read.format("binaryFile").option('pathGlobFilter', '0000029*').option("wholeFile","true").load("s3://sorel-20m/09-DEC-2020/binaries")
#print(df.count())

df.select('Path').write.csv('s3://my-bucket')
job.commit()

If I knew the smallest and largest value of the first character of each key I'm interested in, I could use boto3 per https://stackoverflow.com/a/52450389/11262633. This would allow me to launch parallel processes, each using a filter:

bucket.objects.filter(Prefix=f'09-DEC-2020/binaries/{first_char}')

As you can see from the data above, the first character after the prefix 09-DEC-2020/binaries/ ranges from 0 to f.

So, I could launch 16 parallel processes, one for each character between 0 and f:

import boto3
import sys

session = boto3.Session()
s3 = session.resource('s3')
bucket = s3.Bucket('sorel-20m')

# Assume this script is called with an argument between `0` and `f`
# (sys.argv[0] is the script name, so the argument is sys.argv[1])
first_char = sys.argv[1]

prefix = f'09-DEC-2020/binaries/{first_char}'

current_objects = bucket.objects.filter(Prefix=prefix)

...
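The one-character split in the script above generalizes to any number of leading characters. A minimal sketch, assuming every key after the prefix is a lowercase hex string (which the sample keys suggest but do not guarantee); `hex_prefixes` is a hypothetical helper, not part of boto3:

```python
from itertools import product
from typing import List

HEX_DIGITS = "0123456789abcdef"


def hex_prefixes(base: str, depth: int) -> List[str]:
    """Return base + every combination of `depth` lowercase hex characters,
    in ascending order. Each prefix can be listed independently with
    Prefix=...; together they cover only keys whose next `depth`
    characters are lowercase hex, so any other key would be missed."""
    return [base + "".join(chars) for chars in product(HEX_DIGITS, repeat=depth)]


if __name__ == "__main__":
    base = "09-DEC-2020/binaries/"
    one = hex_prefixes(base, 1)
    two = hex_prefixes(base, 2)
    print(len(one), one[0], one[-1])  # 16 09-DEC-2020/binaries/0 09-DEC-2020/binaries/f
    print(len(two), two[1])           # 256 09-DEC-2020/binaries/01
```

Passing one of these prefixes per process in place of `first_char` gives the 16-way split at depth 1 and 256 independent jobs at depth 2.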

All processing was done on EC2, so my local computer's network bandwidth is not a bottleneck here.

My question

Would you recommend this approach? If so, how could I ensure that all keys will start with a character between 0 and f?


Solution

  • The underlying API that S3 uses to enumerate the contents of a bucket, ListObjectsV2, returns at most 1,000 items per call, and to fetch the next page it requires an opaque continuation token returned by the previous call. This means it is inherently not possible to fetch the pages of a single listing in parallel. However, when objects are named with a predictable pattern, as they are in this bucket, you can, as you suggest, request multiple sections of the key space in parallel.

    The only other thing I would suggest is to split the key space further than just the first character, to allow more than 16 parallel workers. Additionally, check the characters before and after the expected range to verify that there are no objects outside of it.

    Putting it all together would look something like this:

    import boto3
    import ctypes
    import multiprocessing
    
    def validate_hex_digits(s3, bucket, prefix):
        # Make sure the lowest key under the prefix starts with a character in the '0'-'f' range
        resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
        assert('0' <= resp['Contents'][0]['Key'][len(prefix)] <= 'f')
        # Make sure there's nothing after 'f'
        resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1, StartAfter=prefix+'g')
        assert('Contents' not in resp)
    
    def worker(queue, queue_out, jobs):
        # A single worker, pull a job from 'queue', send results to 'queue_out'
        # Use jobs value to track the number of jobs in flight
        s3 = boto3.client('s3')
        while True:
            with jobs.get_lock():
                if jobs.value == 0:
                    # Nothing left to do
                    queue_out.put(None)
                    break
                jobs.value -= 1
                bucket, prefix, token = queue.get()
            # Build up args for the call to list_objects
            args = {
                "Bucket": bucket,
                "Prefix": prefix,
            }
            if token is not None:
                args["ContinuationToken"] = token
            resp = s3.list_objects_v2(**args)
            if 'Contents' in resp:
                queue_out.put(resp['Contents'])
            if 'NextContinuationToken' in resp:
                # There's another page for this prefix, send it off
                # for the next available worker to pick up
                with jobs.get_lock():
                    queue.put((bucket, prefix, resp['NextContinuationToken']))
                    jobs.value += 1
    
    def main():
        bucket = 'sorel-20m'
        prefix = '09-DEC-2020/binaries/'
    
        s3 = boto3.client('s3')
    
        # Check that the first two characters after the prefix fall within the '0'-'f' range
        validate_hex_digits(s3, bucket, prefix)
        for i in range(16):
            validate_hex_digits(s3, bucket, prefix + f"{i:x}")
    
        # If we get here, all the keys follow the pattern we expect for at 
        # least two digits.  Go ahead and use multi processing to pull down 
        # the list of objects as fast as possible
    
        # A queue to store work items
        queue = multiprocessing.Queue()
        # A queue to get pages of results
        queue_out = multiprocessing.Queue()
        # How many jobs are left to process?
        jobs = multiprocessing.Value(ctypes.c_int, 0)
        # Place some seeds in the queue for the first two hex characters
        for i in range(256):
            queue.put((bucket, prefix + f"{i:02x}", None))
            jobs.value += 1
        # Create and start worker processes, two per CPU core,
        # to allow for network wait time
        workers = multiprocessing.cpu_count() * 2
        procs = []
        for _ in range(workers):
            proc = multiprocessing.Process(target=worker, args=(queue, queue_out, jobs))
            proc.start()
            procs.append(proc)
    
        # While the workers are working, pull down pages and do something with them
        while workers > 0:
            result = queue_out.get()
            if result is None:
                # Signal that a worker finished
                workers -= 1
            else:
                for cur in result:
                    # Just show the results like the AWS CLI does
                    print(f"{cur['LastModified'].strftime('%Y-%m-%d %H:%M:%S')} {cur['Size']:10d} {cur['Key'][len(prefix):]}")
    
        # Clean up
        for proc in procs:
            proc.join()
    
    if __name__ == "__main__":
        main()
    

    On my machine, this takes about 5 minutes to enumerate the objects in this bucket, compared to nearly an hour for the AWS CLI to do the same thing. Note that the results arrive in an arbitrary order, because pages from different prefixes are interleaved as the workers finish them; if you need them sorted, sort them after enumeration. Short of the bucket owner enabling and publishing an S3 Inventory report, you probably won't be able to get a list of objects much faster.
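One caveat on `validate_hex_digits`: it only inspects the lowest key and probes after `g`, so keys whose next character sorts between `9` and `a` (for example `:`, `@`, or uppercase `A`-`Z`) would pass unnoticed and would not be covered by the 256 two-character prefixes. A stricter check can probe that gap with the same `StartAfter` trick. A minimal sketch; `first_key_after`, `check_hex_level`, and `check_two_levels` are hypothetical helpers that accept anything with a boto3-style `list_objects_v2` method (such as `boto3.client('s3')`), and `FakeS3` is only an offline stand-in for demonstration. Like the answer's `prefix + 'g'` probe, a key exactly equal to a `StartAfter` value is not seen.

```python
HEX_DIGITS = set("0123456789abcdef")
HEX_LETTERS = set("abcdef")


def first_key_after(s3, bucket, prefix, start_after=None):
    """Return the lowest key under `prefix` (strictly after `start_after`,
    if given), or None. `s3` needs a boto3-style list_objects_v2 method."""
    args = {"Bucket": bucket, "Prefix": prefix, "MaxKeys": 1}
    if start_after is not None:
        args["StartAfter"] = start_after
    contents = s3.list_objects_v2(**args).get("Contents", [])
    return contents[0]["Key"] if contents else None


def check_hex_level(s3, bucket, prefix):
    """Raise ValueError if, for some key, the character right after `prefix`
    is not a lowercase hex digit. Three probes: the lowest key, the lowest
    key after '9' (StartAfter=prefix + ':'), and anything after 'f'
    (StartAfter=prefix + 'g')."""
    def next_char(key):
        # "" when the key equals the prefix itself, which is also flagged
        return key[len(prefix):len(prefix) + 1]

    lowest = first_key_after(s3, bucket, prefix)
    if lowest is not None and next_char(lowest) not in HEX_DIGITS:
        raise ValueError(f"non-hex key found: {lowest!r}")
    # ':' sorts right after '9', so every key continuing with 0-9 is skipped
    gap = first_key_after(s3, bucket, prefix, start_after=prefix + ":")
    if gap is not None and next_char(gap) not in HEX_LETTERS:
        raise ValueError(f"non-hex key found: {gap!r}")
    above = first_key_after(s3, bucket, prefix, start_after=prefix + "g")
    if above is not None:
        raise ValueError(f"non-hex key found: {above!r}")


def check_two_levels(s3, bucket, prefix):
    """Apply the check to the first two characters, like the answer's main()."""
    check_hex_level(s3, bucket, prefix)
    for c in "0123456789abcdef":
        check_hex_level(s3, bucket, prefix + c)


class FakeS3:
    """Offline stand-in mimicking list_objects_v2's Prefix/StartAfter/MaxKeys."""

    def __init__(self, keys):
        # Python's str order matches S3's UTF-8 binary order for these keys
        self.keys = sorted(keys)

    def list_objects_v2(self, Bucket, Prefix, MaxKeys=1000, StartAfter=None):
        hits = [k for k in self.keys
                if k.startswith(Prefix) and (StartAfter is None or k > StartAfter)]
        return {"Contents": [{"Key": k} for k in hits[:MaxKeys]]} if hits else {}


if __name__ == "__main__":
    check_two_levels(FakeS3(["p/00", "p/0f", "p/a1", "p/ff"]), "bucket", "p/")
    try:
        # 'A' sorts between '9' and 'a', so the answer's probes would miss it
        check_hex_level(FakeS3(["p/00", "p/Ab", "p/ff"]), "bucket", "p/")
    except ValueError as err:
        print(err)  # non-hex key found: 'p/Ab'
```

Against the real bucket, `s3` would be `boto3.client('s3')`, and the check costs one extra request per prefix compared with `validate_hex_digits`.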