python apache-spark pyspark

Load a NumPy array into PySpark


I have an input JSON file where each row holds an ID and a corresponding NumPy array stored as a base64 string. How can I load this file into PySpark?

I have tried creating a UDF to do this:

import base64
import json
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType

def decode_base64_and_convert_to_numpy(base64_string):
    decoded_bytes = base64.b64decode(base64_string)
    decoded_str = decoded_bytes.decode('utf-8')
    decoded_list = json.loads(decoded_str)
    return np.array(decoded_list)
    
decode_udf = udf(decode_base64_and_convert_to_numpy, ArrayType(DoubleType()))

But when I invoke it, I get a decoding error:

numpy_loaded_embeddings = raw_input.withColumn('numpy_embedding', decode_udf('model_output'))
An error was encountered:

  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "<stdin>", line 7, in decode_base64_and_convert_to_numpy
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x93 in position 0: invalid start byte

Traceback (most recent call last):
  File "/mnt/yarn/usercache/livy/appcache/application_1705979808797_0003/container_1705979808797_0003_01_000001/pyspark.zip/pyspark/sql/dataframe.py", line 607, in show
    print(self._jdf.showString(n, 20, vertical))
  File "/mnt/yarn/usercache/livy/appcache/application_1705979808797_0003/container_1705979808797_0003_01_000001/py4j-0.10.9.5-src.zip/py4j/java_gateway.py", line 1322, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/mnt/yarn/usercache/livy/appcache/application_1705979808797_0003/container_1705979808797_0003_01_000001/pyspark.zip/pyspark/sql/utils.py", line 196, in deco
    raise converted from None
pyspark.sql.utils.PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "<stdin>", line 7, in decode_base64_and_convert_to_numpy
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x93 in position 0: invalid start byte

Solution

import base64
import io
import numpy as np

def decode_base64_and_convert_to_numpy(base64_string):
    decoded_bytes = base64.b64decode(base64_string)
    numpy_array = np.load(io.BytesIO(decoded_bytes))
    return numpy_array.tolist()

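This assumes the base64 string holds an array serialized in NumPy's binary .npy format (for example, written with np.save) rather than JSON text. The traceback supports that assumption: the first decoded byte, 0x93, is the first byte of the .npy magic string \x93NUMPY, and it is not a valid UTF-8 start byte, so decode('utf-8') fails. The function returns numpy_array.tolist() because a UDF declared with ArrayType(DoubleType()) should return a plain Python list rather than an ndarray. It also assumes pickle was not used when the data was dumped. If it was, pass allow_pickle=True to np.load, and only do so for data you trust, since unpickling can execute arbitrary code.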
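
To check the assumption without a Spark cluster, the round trip can be reproduced locally. This is a minimal sketch, assuming the producer wrote each array with np.save and base64-encoded the bytes; encode_array, decode_array and the sample values are hypothetical names and data used only for illustration, not taken from the question.

```python
import base64
import io

import numpy as np


def encode_array(arr):
    """Serialize an array with np.save and wrap the bytes in base64 text.

    Assumption: this mirrors how the input file was produced.
    """
    buf = io.BytesIO()
    np.save(buf, arr)  # writes the binary .npy format
    return base64.b64encode(buf.getvalue()).decode("ascii")


def decode_array(b64_string):
    """Reverse of encode_array; returns a plain list of Python floats."""
    raw_bytes = base64.b64decode(b64_string)
    arr = np.load(io.BytesIO(raw_bytes))
    # Cast to float64 so tolist() yields Python floats, matching DoubleType.
    return arr.astype(np.float64).tolist()


sample = np.array([0.5, 1.5, 2.5], dtype=np.float32)  # made-up data
encoded = encode_array(sample)
raw = base64.b64decode(encoded)

# The payload starts with the .npy magic string, whose first byte is 0x93.
print(raw[:6])

# Decoding those bytes as UTF-8 fails the same way as in the question.
try:
    raw.decode("utf-8")
    utf8_failed = False
except UnicodeDecodeError:
    utf8_failed = True
print(utf8_failed)

print(decode_array(encoded))
```

If the first six bytes of your own decoded payload match, wrapping decode_array with udf(decode_array, ArrayType(DoubleType())) as in the question should work for 1-D arrays. Multi-dimensional arrays would come back from tolist() as nested lists and need a nested ArrayType instead.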