I have a function that compares images from the same folder against one another, outputting a similarity prediction. The function runs fine in plain Python, but I want to leverage the power of PySpark parallelisation.
Here, I use Spark by simply parallelizing the list, i.e. turning it into an RDD.
img_list = sc.parallelize(os.listdir(folder_dir))
f_img_list = img_list.filter(lambda f: f.endswith('.jpg') or f.endswith('.png'))
Defining the function:
def compare_images(x1, x2):
    # Preprocess images
    img_array1 = preprocess_image2(x1)
    img_array2 = preprocess_image2(x2)
    pred = compare(img_array1, img_array2)
    return pred
At this point I want to apply operations on the RDD, with the requirement that an image in the folder should not be compared against itself.
My attempt is to use "map", but I'm unsure how to do that, since map passes each element as a single argument:
prediction = f_img_list.map(compare_images)
prediction.collect()
I'm also aware that my attempt does not enforce the requirement that an image should not be compared against itself - assistance with that would also be appreciated.
You could create a list of pairs of distinct image filenames, parallelize that list, and modify your compare_images function to take a single pair argument instead of two.
Edit: let's use the RDD's filter method to keep only the files that end with '.jpg' or '.png':
import os
import itertools
from pyspark import SparkContext
sc = SparkContext()
def preprocess_image2(image_path):
    # Placeholder for your actual preprocessing
    pass

def compare(img_array1, img_array2):
    # Placeholder for your actual comparison
    pass
# folder_dir is assumed to be defined elsewhere
img_list = sc.parallelize(os.listdir(folder_dir))
f_img_list = img_list.filter(lambda f: f.endswith('.jpg') or f.endswith('.png'))

# Collect the filenames to the driver and build every distinct pair;
# itertools.combinations excludes self-pairs and mirrored duplicates.
f_img_list_local = f_img_list.collect()
image_pairs = list(itertools.combinations(f_img_list_local, 2))
image_pairs_rdd = sc.parallelize(image_pairs)
def compare_images(image_pair):
    x1, x2 = image_pair
    img_array1 = preprocess_image2(x1)
    img_array2 = preprocess_image2(x2)
    pred = compare(img_array1, img_array2)
    return pred
predictions = image_pairs_rdd.map(compare_images)
results = predictions.collect()
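If the file list is large, the pairs don't have to be built on the driver at all: a possible sketch keeps everything inside Spark with `cartesian` plus a filter predicate. The predicate below is plain Python, so it works both in an RDD `filter` and locally; the Spark line is shown commented because it assumes the `sc` and `f_img_list` from above.

```python
# Sketch: generate distinct pairs without collecting filenames to the driver.
# cartesian() pairs every file with every other file (including itself);
# the predicate keeps each unordered pair exactly once and drops
# self-comparisons.
def distinct_pair(pair):
    x1, x2 = pair
    # Lexicographic ordering deduplicates ('a', 'b') vs ('b', 'a')
    # and rejects ('a', 'a').
    return x1 < x2

# Assumes `f_img_list` is the filtered RDD of filenames from above:
# image_pairs_rdd = f_img_list.cartesian(f_img_list).filter(distinct_pair)

# Local check of the predicate on a toy file list:
files = ['a.jpg', 'b.jpg', 'c.jpg']
pairs = [p for p in ((x, y) for x in files for y in files) if distinct_pair(p)]
print(pairs)  # [('a.jpg', 'b.jpg'), ('a.jpg', 'c.jpg'), ('b.jpg', 'c.jpg')]
```

This trades the driver-side `collect` for a shuffle-heavy `cartesian`, so for a small folder the `itertools.combinations` approach above is simpler; for many files this keeps memory use off the driver.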