When connecting to a databricks cluster with databricks-connect, I get a Py4JJavaError
exception when I do a repartition
on a simple rdd:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
rdd = spark.sparkContext.parallelize(range(0, 10), 3)
print(rdd.sum())
print(rdd.repartition(5).sum())
The first print statement gets executed fine and prints 45
, but the second print statement fails with the following error:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
in
5 rdd = spark.sparkContext.parallelize(range(0, 10), 3)
6 print(rdd.sum())
----> 7 print(rdd.repartition(5).sum())
d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\rdd.py in sum(self)
1256 6.0
1257 """
-> 1258 return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
1259
1260 def count(self):
d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\rdd.py in fold(self, zeroValue, op)
1110 # zeroValue provided to each partition is unique from the one provided
1111 # to the final reduce call
-> 1112 vals = self.mapPartitions(func).collect()
1113 return reduce(op, vals, zeroValue)
1114
d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\rdd.py in collect(self)
964 # Default path used in OSS Spark / for non-credential passthrough clusters:
965 with SCCallSiteSync(self.context) as css:
--> 966 sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
967 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
968
d:\source\repos\...\.venv_3_8_10\lib\site-packages\py4j\java_gateway.py in __call__(self, *args)
1302
1303 answer = self.gateway_client.send_command(command)
-> 1304 return_value = get_return_value(
1305 answer, self.gateway_client, self.target_id, self.name)
1306
d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\sql\utils.py in deco(*a, **kw)
115 def deco(*a, **kw):
116 try:
--> 117 return f(*a, **kw)
118 except py4j.protocol.Py4JJavaError as e:
119 converted = convert_exception(e.java_exception)
d:\source\repos\...\.venv_3_8_10\lib\site-packages\py4j\protocol.py in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
328 format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: java.lang.ClassCastException: cannot assign instance of org.apache.spark.serializer.KryoSerializer to field org.apache.spark.ShuffleDependency.rowBasedChecksums of type [Lorg.apache.spark.shuffle.checksum.RowBasedChecksum; in instance of org.apache.spark.ShuffleDependency
at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2301)
at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1431)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2437)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.readArray(ObjectInputStream.java:2119)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1657)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
at org.apache.spark.sql.util.ProtoSerializer.$anonfun$deserializeObject$1(ProtoSerializer.scala:7058)
at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
at org.apache.spark.sql.util.ProtoSerializer.deserializeObject(ProtoSerializer.scala:7043)
at com.databricks.service.SparkServiceRPCHandler.execute0(SparkServiceRPCHandler.scala:728)
at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC0$1(SparkServiceRPCHandler.scala:477)
at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
at com.databricks.service.SparkServiceRPCHandler.executeRPC0(SparkServiceRPCHandler.scala:372)
at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:323)
at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:309)
at java.util.concurrent.FutureTask.run(FutureTask.java:266)
at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC$1(SparkServiceRPCHandler.scala:359)
at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
at com.databricks.service.SparkServiceRPCHandler.executeRPC(SparkServiceRPCHandler.scala:336)
at com.databricks.service.SparkServiceRPCServlet.doPost(SparkServiceRPCServer.scala:167)
at javax.servlet.http.HttpServlet.service(HttpServlet.java:523)
at javax.servlet.http.HttpServlet.service(HttpServlet.java:590)
at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:799)
at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:550)
at org.eclipse.jetty.server.handler.ScopedHandler.nextScope(ScopedHandler.java:190)
at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:501)
at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)
at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:127)
at org.eclipse.jetty.server.Server.handle(Server.java:516)
at org.eclipse.jetty.server.HttpChannel.lambda$handle$1(HttpChannel.java:388)
at org.eclipse.jetty.server.HttpChannel.dispatch(HttpChannel.java:633)
at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:380)
at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:277)
at org.eclipse.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:311)
at org.eclipse.jetty.io.FillInterest.fillable(FillInterest.java:105)
at org.eclipse.jetty.io.ChannelEndPoint$1.run(ChannelEndPoint.java:104)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.runTask(EatWhatYouKill.java:338)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.doProduce(EatWhatYouKill.java:315)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.tryProduce(EatWhatYouKill.java:173)
at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.run(EatWhatYouKill.java:131)
at org.eclipse.jetty.util.thread.ReservedThreadExecutor$ReservedThread.run(ReservedThreadExecutor.java:386)
at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:883)
at org.eclipse.jetty.util.thread.QueuedThreadPool$Runner.run(QueuedThreadPool.java:1034)
at java.lang.Thread.run(Thread.java:750)
Instead of the error, I would like the second print statement to get executed and print 45
.
I can run the snippet fine in a databricks notebook in the webinterface.
I can do repartition
on dataframes fine without errors. The below snippet runs fine with databricks connect and prints 45
:
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.range(0, 10)
df.agg(F.sum("id")).collect()[0]["sum(id)"]
df_repar = df.repartition(5)
print(df_repar.agg(F.sum("id")).collect()[0]["sum(id)"])
I have a quite big setup using dataframes that runs without any apparent issues. I would like to move to databricks runtime 11.3 LTS but this issue is preventing me from upgrading.
I run python 3.8.10 and have asserted that version numbers of the packages on the cluster match the locally installed ones.
I run databricks-connect==10.4.22
and connect to a databricks cluster running databricks runtime 10.4 LTS.
This issue was resolved by upgrading to the newly released databricks-connect==10.4.25
.