Search code examples
pythonapache-sparkpysparkparallel-processingvideo-processing

Reading multiple videos in parallel with PySpark


I need to write a PySpark script that reads multiple videos from mp4 files in parallel and then processes them in PySpark. (These 4 video files stand in for multiple RTSP video streams that I will capture in the future.)

My first approach was to use the multiprocessing library to read the 4 video files in parallel. However, this approach generated many Spark-related errors, such as "Only one SparkContext should be running in this JVM", "Java Heap Space (not enough memory)", etc.

So my question is: is there any other approach to reading multiple mp4 video files in parallel without using the multiprocessing library? I recently asked Google Bard, and it said that I can read video streams in parallel directly into a SparkContext RDD like this:

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

def video_receiver(iterator):
    """Yield decoded frames from every video path in one Spark partition.

    Intended for ``RDD.mapPartitions``: ``iterator`` yields the video paths
    assigned to this partition, and this generator yields every frame returned
    by OpenCV's ``VideoCapture.read`` for each path in turn.

    Each capture is released once its video is exhausted (or the consumer stops
    early), so file handles / stream connections are not leaked.
    """
    from cv2 import VideoCapture

    # A plain ``for`` loop ends cleanly when the partition runs out of paths.
    # Calling ``next()`` inside ``while True`` would instead let StopIteration
    # escape the generator, which Python 3.7+ turns into a RuntimeError (PEP 479).
    for video_path in iterator:
        cap = VideoCapture(video_path)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                yield frame
        finally:
            cap.release()

sc = SparkContext()

# No StreamingContext here: the original referenced an undefined ``batch_interval``
# (NameError), and nothing below uses a DStream -- ``video_stream`` is a plain RDD.

video_paths = ['video1.mp4', 'video2.mp4', 'video3.mp4', 'video4.mp4']
# One partition per path, so each video is decoded by its own task (in parallel).
# The paths must be readable from wherever the executors run.
video_rdd = sc.parallelize(video_paths, numSlices=len(video_paths))
# Lazy: frames are only decoded once an action (count, foreach, ...) is called.
video_stream = video_rdd.mapPartitions(video_receiver)

I have not run this script yet, but I doubt it will work. If anyone has encountered the same issues or has any kind of solution, please help me with this.


Solution

  • I have adapted the following jupyter notebook to show how spark can do video processing at scale.

    https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1969271421694072/3760413548916830/5612335034456173/latest.html

    You need to install the following Python libraries in your conda environment. Also make sure the native ffmpeg command-line tools (ffmpeg and ffprobe) are installed, since the ffmpeg-python package only wraps them:

    pip install ffmpeg-python

    pip install face-recognition

    conda install -c conda-forge opencv

    Download an .mp4 video with a face in it to run face detection with the following code.

    https://www.videezy.com/free-video/face?format-mp4=true

    Here is the PySpark code:

    from pyspark import SQLContext, SparkConf, SparkContext
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F
    
    
    # NOTE(review): ``conf`` is never handed to the session builder below, so its
    # app name ("myApp") is not applied; the builder sets master and memory itself.
    conf = SparkConf().setAppName("myApp").setMaster("local[40]")

    _builder = SparkSession.builder.master("local[40]")
    _builder = _builder.config("spark.driver.memory", "30g")
    spark = _builder.getOrCreate()

    sc = spark.sparkContext
    # SQLContext wrapper used below for createDataFrame calls.
    sqlContext = SQLContext(sc)
    
    import cv2
    import os
    import uuid
    import ffmpeg
    import subprocess
    import numpy as np
    
    from scipy.optimize import linear_sum_assignment
    import pyspark.sql.functions as F
    from pyspark.sql import Row
    from pyspark.sql.types import (StructType, StructField,
                                   IntegerType, FloatType,
                                   ArrayType, BinaryType,
                                   MapType, DoubleType, StringType)
    
    from pyspark.sql.window import Window
    from pyspark.ml.feature import StringIndexer
    from pyspark.sql import Row, DataFrame, SparkSession
    
    import pathlib
    
    videos = []  # NOTE(review): not used by the visible code.

    input_dir = "../data/video_files/faces/"

    # Sort so the row order does not depend on the filesystem's listing order.
    pathlist = sorted(pathlib.Path(input_dir).glob('*.mp4'))

    pathlist = [Row(str(ele)) for ele in pathlist]
    print(pathlist)

    column_name = ["video_uri"]

    # An explicit schema (rather than just column names) lets createDataFrame work
    # even when no .mp4 file is found: schema inference fails on an empty list.
    input_schema = StructType([StructField(column_name[0], StringType(), False)])
    df = sqlContext.createDataFrame(data=pathlist, schema=input_schema)

    print("Initial dataframe")
    df.show(10, truncate=False)
    
    # Struct returned by the ``video_probe`` UDF.
    video_metadata = StructType(
        [
            StructField(field_name, field_type, False)
            for field_name, field_type in (
                ("width", IntegerType()),
                ("height", IntegerType()),
                ("num_frames", IntegerType()),
                ("duration", FloatType()),
            )
        ]
    )

    # Array of {start, end} float pairs (not used by the visible code).
    _segment = StructType(
        [
            StructField("start", FloatType(), False),
            StructField("end", FloatType(), False),
        ]
    )
    shots_schema = ArrayType(_segment)
    
    
    @F.udf(returnType=video_metadata)
    def video_probe(uri):
        """Probe ``uri`` with ffprobe and return its ``video_metadata`` fields.

        Returns ``(width, height, num_frames, duration)`` for the first stream whose
        ``codec_type`` is ``"video"``.

        ``nb_frames`` and the per-stream ``duration`` are not reported for every
        container/stream. When the stream duration is missing, the container-level
        ``format.duration`` is used (0.0 if that is missing too). When ``nb_frames``
        is missing, the frame count is estimated as ``avg_frame_rate * duration``.

        Raises ValueError if the input has no video stream.
        """
        probe = ffmpeg.probe(uri, threads=1)
        video_stream = next(
            (s for s in probe["streams"] if s.get("codec_type") == "video"),
            None,
        )
        if video_stream is None:
            raise ValueError("No video stream found in {!r}".format(uri))

        width = int(video_stream["width"])
        height = int(video_stream["height"])

        duration = video_stream.get("duration")
        if duration is None:
            duration = probe.get("format", {}).get("duration", 0.0)
        duration = float(duration)

        nb_frames = video_stream.get("nb_frames")
        if nb_frames is not None:
            num_frames = int(nb_frames)
        else:
            # avg_frame_rate is a "num/den" string such as "30000/1001"; "0/0" means
            # unknown, in which case the estimate degrades to 0 frames.
            num, _, den = str(video_stream.get("avg_frame_rate", "0/1")).partition("/")
            fps = float(num) / float(den) if den and float(den) else float(num or 0)
            num_frames = int(round(fps * duration))

        return (width, height, num_frames, duration)
    
    
    @F.udf(returnType=ArrayType(BinaryType()))
    def video2images(uri, width, height,
                     sample_rate: int = 5,
                     start: float = 0.0,
                     end: float = -1.0,
                     n_channels: int = 3):
        """
        Use FFmpeg to decode a time segment of a video into raw frame byte strings.

        FFmpeg seeks to ``start`` and, when ``end`` is non-negative, decodes
        ``end - start`` of video. A negative ``end`` (the default) means "until
        the end of the video". Output is resampled with ``r = 1 / sample_rate``
        and emitted as packed ``rgb24`` pixels. The byte stream is then split into
        frames of ``height * width * n_channels`` bytes.

        NOTE: ``pix_fmt`` is fixed to ``rgb24``, so ``n_channels`` must be 3 for
        the frame size to match what FFmpeg emits.
        """
        output_kwargs = {
            "format": "rawvideo",
            "pix_fmt": "rgb24",
            "ss": start,
            "r": 1 / sample_rate,
        }
        # Omit -t for the "until the end" sentinel instead of passing a negative duration.
        if end is not None and end >= 0:
            output_kwargs["t"] = end - start

        video_data, _ = (
            ffmpeg.input(uri, threads=1)
            .output("pipe:", **output_kwargs)
            .run(capture_stdout=True))

        img_size = height * width * n_channels
        if img_size <= 0:
            return []
        # Keep only complete frames so every element can be reshaped downstream.
        n_full_frames = len(video_data) // img_size
        return [video_data[i * img_size:(i + 1) * img_size] for i in range(n_full_frames)]
    
    
    _metadata_col = video_probe(F.col("video_uri"))
    df = df.withColumn("metadata", _metadata_col)
    print("With Metadata")
    df.show(10, truncate=False)

    # Decode frames with sample_rate=1 from start=0.0 to end=5.0; explode turns
    # the returned array into one row per frame.
    _frame_bytes = video2images(
        F.col("video_uri"),
        F.col("metadata.width"),
        F.col("metadata.height"),
        F.lit(1),
        F.lit(0.0),
        F.lit(5.0),
    )
    df = df.withColumn("frame", F.explode(_frame_bytes))
    
    import face_recognition
    
    # Integer bounding box with fields xmin, ymin, xmax, ymax (all non-nullable).
    box_struct = StructType(
        [
            StructField(corner, IntegerType(), False)
            for corner in ("xmin", "ymin", "xmax", "ymax")
        ]
    )
    
    
    def bbox_helper(bbox):
        """Convert a face_recognition box into ``box_struct`` field order.

        ``face_recognition.face_locations`` returns boxes as
        ``(top, right, bottom, left)``, while ``box_struct`` is
        ``(xmin, ymin, xmax, ymax)``, i.e. ``(left, top, right, bottom)``.
        Negative coordinates are clamped to 0.
        """
        top, right, bottom, left = bbox
        return [max(coord, 0) for coord in (left, top, right, bottom)]
    
    
    @F.udf(returnType=ArrayType(box_struct))
    def face_detector(img_data, width=1920, height=1080, n_channels=3):
        """Detect faces in one raw frame and return their boxes.

        ``img_data`` holds uint8 pixels that are reshaped to
        ``(height, width, n_channels)``. Each location found by
        ``face_recognition.face_locations`` is converted with ``bbox_helper``.
        """
        pixels = np.frombuffer(img_data, np.uint8)
        image = pixels.reshape(height, width, n_channels)
        locations = face_recognition.face_locations(image)
        return list(map(bbox_helper, locations))
    
    
    df = df.withColumn(
        "faces",
        face_detector(
            F.col("frame"),
            F.col("metadata.width"),
            F.col("metadata.height"),
        ),
    )

    # Per-frame annotations: array of {bbox: box_struct, tracker_id: string}.
    annot_schema = ArrayType(
        StructType(
            [
                StructField(field_name, field_type, False)
                for field_name, field_type in (
                    ("bbox", box_struct),
                    ("tracker_id", StringType()),
                )
            ]
        )
    )
    
    
    def bbox_iou(b1, b2):
        """Return the intersection-over-union of two boxes.

        Boxes are 4-sequences in ``box_struct`` order ``(xmin, ymin, xmax, ymax)``.
        Spark Rows also work, since they are tuples. Returns 0.0 for disjoint
        boxes, and also when the union area is zero (e.g. two identical degenerate
        boxes), instead of dividing by zero.
        """
        left, top = max(b1[0], b2[0]), max(b1[1], b2[1])
        right, bottom = min(b1[2], b2[2]), min(b1[3], b2[3])
        if right < left or bottom < top:
            return 0.0

        def _area(b):
            # Width * height of a box given as (xmin, ymin, xmax, ymax).
            return (b[2] - b[0]) * (b[3] - b[1])

        inter_area = (right - left) * (bottom - top)
        union_area = _area(b1) + _area(b2) - inter_area
        if union_area <= 0:
            return 0.0
        return inter_area / float(union_area)
    
    
    @F.udf(returnType=MapType(IntegerType(), IntegerType()))
    def tracker_match(trackers, detections, bbox_col="bbox", threshold=0.3):
        """
        Match bounding boxes across successive image frames.

        Parameters
        ----------
        trackers : list of Row
            Objects of one frame, each with a ``bbox_col`` field.
        detections : list of Row
            Objects of the other frame, each with a ``bbox_col`` field.
        bbox_col : str
            Name of the bounding-box field inside each element.
        threshold : float
            Minimum IoU for an assigned pair to count as a match.

        Returns
        -------
        dict[int, int]
            ``{detection_index: tracker_index}`` for every pair chosen by the
            optimal (Hungarian) assignment whose IoU is at least ``threshold``.
            Empty when either side is empty or None.

        In this file the call sites pass the current frame's trackables as
        ``trackers`` and the previous frame's as ``detections``. The map is
        therefore previous-index -> current-index, which is how
        ``match_annotations`` reads it.
        """
        from scipy.optimize import linear_sum_assignment

        if not trackers or not detections:
            return {}

        # Rows index detections, columns index trackers.
        sim_mat = np.array(
            [
                [bbox_iou(tracked[bbox_col], detection[bbox_col]) for tracked in trackers]
                for detection in detections
            ],
            dtype=np.float32,
        )

        # linear_sum_assignment returns (row_ind, col_ind); negate the similarity
        # so the minimum-cost assignment maximises total IoU.
        det_idx, trk_idx = linear_sum_assignment(-sim_mat)

        # Cast to plain ints: Spark's UDF result conversion cannot handle numpy scalars.
        return {
            int(d): int(t)
            for d, t in zip(det_idx, trk_idx)
            if sim_mat[d, t] >= threshold
        }
    
    
    @F.udf(returnType=ArrayType(box_struct))
    def OFMotionModel(frame, prev_frame, bboxes, height, width):
        """Shift each box by the mean dense optical flow inside it.

        ``frame`` and ``prev_frame`` are raw 3-channel uint8 frames reshaped to
        ``(height, width, 3)``. When there is no previous frame, the current frame
        is used in its place. Flow from ``prev_frame`` to ``frame`` is computed with
        OpenCV's DIS optical flow. Each ``(xmin, ymin, xmax, ymax)`` box is then
        moved by the average flow over its pixels and clamped to the frame.
        Returns an empty list when there are no boxes.
        """
        if not bboxes:
            return []
        if not prev_frame:
            prev_frame = frame

        def _to_gray(raw):
            # video2images emits rgb24 pixels, so convert from RGB (not BGR).
            rgb = np.frombuffer(raw, np.uint8).reshape(height, width, 3)
            return cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

        gray = _to_gray(frame)
        prev_gray = _to_gray(prev_frame)

        inst = cv2.DISOpticalFlow.create(cv2.DISOPTICAL_FLOW_PRESET_MEDIUM)
        inst.setUseSpatialPropagation(False)

        flow = inst.calc(prev_gray, gray, None)

        h, w = flow.shape[:2]
        shifted_boxes = []
        for xmin, ymin, xmax, ymax in bboxes:
            region = flow[int(ymin):int(ymax), int(xmin):int(xmax)]
            if region.size:
                # OpenCV flow fields store the horizontal (x) displacement in
                # channel 0 and the vertical (y) displacement in channel 1.
                avg_x = float(region[..., 0].mean())
                avg_y = float(region[..., 1].mean())
            else:
                # Empty/out-of-frame box: nothing to average (mean would be NaN).
                avg_x = avg_y = 0.0

            shifted_boxes.append(
                {"xmin": int(max(0, xmin + avg_x)), "ymin": int(max(0, ymin + avg_y)), "xmax": int(min(w, xmax + avg_x)),
                 "ymax": int(min(h, ymax + avg_y))})
        return shifted_boxes
    
    
    def match_annotations(iterator, segment_id="video_uri", id_col="tracker_id"):
        """
        Used by mapPartitions to assign tracker ids frame by frame within a partition.

        Each element of ``iterator`` is ``(key, (segment, trackables, matched))``:
        ``trackables`` is the current frame's list of Rows, and ``matched`` is a
        map from indices in the previous frame's trackables to indices in the
        current one (the ``tracker_match`` output as called in this file), or None.

        A detection matched to a previous-frame detection inherits that detection's
        tracker id. Every other detection, including all detections on the first
        frame of a segment, gets a fresh ``uuid4`` id stored as a string. Ids are
        never carried across a change of ``segment``.

        Returns one list of annotation Rows per input element, in input order.
        ``segment_id`` is accepted for backward compatibility and is not used.
        """
        matched_annots = []
        prev_ids = {}  # previous frame: trackable index -> tracker id
        prev_segment = None
        for _, (segment, trackables, matched) in iterator:
            if segment != prev_segment:
                prev_ids = {}  # new video: nothing to inherit from
            # ``matched`` is previous-index -> current-index; invert it for lookup.
            prev_of = {cur: prev for prev, cur in (matched or {}).items()}

            curr_ids = {}
            annots = []
            for cur_idx, trackable in enumerate(trackables or []):
                detection = trackable.asDict()
                tracker_id = prev_ids.get(prev_of.get(cur_idx))
                if tracker_id is None:
                    tracker_id = str(uuid.uuid4())
                detection[id_col] = tracker_id
                curr_ids[cur_idx] = tracker_id
                annots.append(Row(**detection))

            matched_annots.append(annots)
            prev_ids = curr_ids
            prev_segment = segment
        return matched_annots
    
    
    def track_detections(df, segment_id="video_uri", frames="frame", detections="faces", optical_flow=True):
        """Assign tracker ids to per-frame detections and regroup them per frame.

        Parameters
        ----------
        df : DataFrame
            Must contain ``segment_id``, ``frames`` and ``detections`` columns, plus
            ``metadata.height`` / ``metadata.width`` when ``optical_flow`` is true.
        segment_id : str
            Column identifying the video a frame belongs to.
        frames : str
            Column holding the raw frame bytes.
        detections : str
            Column holding the frame's array of ``box_struct`` boxes.
        optical_flow : bool
            When true, detections are first shifted with ``OFMotionModel``.

        Returns
        -------
        DataFrame
            One row per (segment, frame, detections), with a ``tracked_detections``
            array of ``{bbox, tracker_id}`` structs. ``tracker_id`` is the SHA-256
            hex digest of the segment id concatenated with the raw id.

        NOTE(review): the windows below order by ``frames``, i.e. by the raw frame
        bytes, not by position in the video. "Previous frame" therefore means
        previous in byte order; true temporal tracking needs a positional frame
        index column.
        """
        id_col = "tracker_id"
        frame_window = Window().orderBy(frames)
        annot_window = Window.partitionBy(segment_id).orderBy(segment_id, frames)
        indexer = StringIndexer(inputCol=segment_id, outputCol="vidIndex")

        # adjust detections w/ optical flow
        if optical_flow:
            df = (
                df.withColumn("prev_frames", F.lag(F.col(frames)).over(annot_window))
                .withColumn(
                    detections,
                    OFMotionModel(
                        F.col(frames),
                        F.col("prev_frames"),
                        F.col(detections),
                        F.col("metadata.height"),
                        F.col("metadata.width"),
                    ),
                )
            )

        # One row per frame: its boxes as {bbox, tracker_id} structs, the previous
        # frame's boxes, their IoU matching, and a global frame_index.
        df = (
            df.select(segment_id, frames, detections)
            .withColumn("bbox", F.explode(detections))
            .withColumn(id_col, F.lit(""))
            .withColumn("trackables", F.struct([F.col("bbox"), F.col(id_col)]))
            .groupBy(segment_id, frames, detections)
            .agg(F.collect_list("trackables").alias("trackables"))
            .withColumn(
                "old_trackables", F.lag(F.col("trackables")).over(annot_window)
            )
            .withColumn(
                "matched",
                tracker_match(F.col("trackables"), F.col("old_trackables")),
            )
            .withColumn("frame_index", F.row_number().over(frame_window))
        )

        df = (
            indexer.fit(df)
            .transform(df)
            .withColumn("vidIndex", F.col("vidIndex").cast(StringType()))
        )
        unique_ids = df.select("vidIndex").distinct().count()

        def _match_partition(rows):
            # Run match_annotations over one video's rows and tag each result with
            # the frame_index of its source row, so it can be joined back exactly.
            rows = list(rows)
            annotations = match_annotations((key, value[:3]) for key, value in rows)
            return [(value[3], annots) for (_, value), annots in zip(rows, annotations)]

        matched = (
            df.select("vidIndex", segment_id, "trackables", "matched", "frame_index")
            .rdd.map(lambda x: (x[0], x[1:]))
            # vidIndex is a stringified double ("0.0", "1.0", ...). Parse the whole
            # value so videos 10+ are not folded onto their first digit's partition.
            .partitionBy(unique_ids, lambda key: int(float(key)))
            .mapPartitions(_match_partition)
        )
        matched_schema = StructType(
            [
                StructField("frame_index", IntegerType(), False),
                StructField("value", annot_schema, True),
            ]
        )
        matched_annotations = sqlContext.createDataFrame(matched, matched_schema)

        return (
            df.join(matched_annotations, on="frame_index")
            .withColumnRenamed("value", "trackers_matched")
            .withColumn("tracked", F.explode(F.col("trackers_matched")))
            .select(
                segment_id,
                frames,
                detections,
                F.col("tracked.{}".format("bbox")).alias("bbox"),
                F.col("tracked.{}".format(id_col)).alias(id_col),
            )
            .withColumn(id_col, F.sha2(F.concat(F.col(segment_id), F.col(id_col)), 256))
            .withColumn("tracked_detections", F.struct([F.col("bbox"), F.col(id_col)]))
            .groupBy(segment_id, frames, detections)
            .agg(F.collect_list("tracked_detections").alias("tracked_detections"))
            .orderBy(segment_id, frames, detections)
        )
    
    
    from pyspark import keyword_only
    from pyspark.ml.pipeline import Transformer
    from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param
    
    
    class DetectionTracker(Transformer, HasInputCol, HasOutputCol):
        """Spark ML Transformer that detects-and-tracks boxes across video frames.

        ``inputCol`` names the column identifying each video, and ``outputCol`` the
        column that receives the array of ``{bbox, tracker_id}`` structs.
        ``framesCol`` / ``detectionsCol`` default to ``"frame"`` / ``"faces"``, and
        ``optical_flow`` defaults to False. With optical flow enabled, the input
        must also contain ``metadata.height`` and ``metadata.width``.

        The work is delegated to ``track_detections``, so the Transformer and the
        plain function share a single implementation.
        """

        @keyword_only
        def __init__(self, inputCol=None, outputCol=None, framesCol=None, detectionsCol=None, optical_flow=None):
            """Initialize."""
            super(DetectionTracker, self).__init__()
            self.framesCol = Param(self, "framesCol", "Column containing frames.")
            self.detectionsCol = Param(self, "detectionsCol", "Column containing detections.")
            self.optical_flow = Param(self, "optical_flow", "Use optical flow for tracker correction. Default is False")
            self._setDefault(framesCol="frame", detectionsCol="faces", optical_flow=False)
            kwargs = self._input_kwargs
            self.setParams(**kwargs)

        @keyword_only
        def setParams(self, inputCol=None, outputCol=None, framesCol=None, detectionsCol=None, optical_flow=None):
            """Set params."""
            kwargs = self._input_kwargs
            return self._set(**kwargs)

        def setFramesCol(self, value):
            """Set framesCol."""
            return self._set(framesCol=value)

        def getFramesCol(self):
            """Get framesCol."""
            return self.getOrDefault(self.framesCol)

        def setDetectionsCol(self, value):
            """Set detectionsCol."""
            return self._set(detectionsCol=value)

        def getDetectionsCol(self):
            """Get detectionsCol."""
            return self.getOrDefault(self.detectionsCol)

        def setOpticalflow(self, value):
            """Set optical_flow."""
            return self._set(optical_flow=value)

        def getOpticalflow(self):
            """Get optical_flow."""
            return self.getOrDefault(self.optical_flow)

        def _transform(self, dataframe):
            """Track detections in ``dataframe``; the result column is ``outputCol``."""
            tracked = track_detections(
                dataframe,
                segment_id=self.getInputCol(),
                frames=self.getFramesCol(),
                detections=self.getDetectionsCol(),
                optical_flow=self.getOpticalflow(),
            )
            # track_detections always names its result column "tracked_detections".
            return tracked.withColumnRenamed("tracked_detections", self.getOutputCol())
    
    
    detectTracker = DetectionTracker(inputCol="video_uri", outputCol="tracked_detections")
    print(type(detectTracker))

    # The Transformer is only constructed here. Calling detectTracker.transform(df)
    # and discarding the result would still run Spark jobs eagerly (the
    # StringIndexer fit and count() inside the tracking code), i.e. face detection
    # over every frame, for nothing.
    final = track_detections(df)

    print("Final dataframe")
    final.select("tracked_detections").show(100, truncate=False)
    

    Output:

    [<Row('../data/video_files/faces/production_id_3761466 (2160p).mp4')>]
    Initial dataframe
    +-----------------------------------------------------------+
    |video_uri                                                  |
    +-----------------------------------------------------------+
    |../data/video_files/faces/production_id_3761466 (2160p).mp4|
    +-----------------------------------------------------------+
    
    With Metadata
    +-----------------------------------------------------------+------------------------+
    |video_uri                                                  |metadata                |
    +-----------------------------------------------------------+------------------------+
    |../data/video_files/faces/production_id_3761466 (2160p).mp4|{3840, 2160, 288, 11.52}|
    +-----------------------------------------------------------+------------------------+
    
    <class '__main__.DetectionTracker'>
    
    +---------------------------------------------------------------------------------------------+
    |tracked_detections                                                                           |
    +---------------------------------------------------------------------------------------------+
    |[{{649, 1810, 1204, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]|
    |[{{678, 1777, 1233, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]|
    |[{{725, 1774, 1280, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]|
    |[{{728, 1760, 1283, 2160}, 56943f0cdeb96031c966fac39ef82dc8cc9761a5a2cf9cbf5740f9aeae842c17}]|
    +---------------------------------------------------------------------------------------------+