Is anyone using Numba with PySpark?
I’m just starting to experiment with this. It seems that Numba can be used in PySpark pandas UDFs, as long as the Series is adapted to an ndarray first (see numba-pyspark-udfs/udf.py at main · gmarkall/numba-pyspark-udfs · GitHub). However, in my naive first implementation I noticed that PySpark called the Numba function multiple times, elementwise.
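For context, what I’m experimenting with looks roughly like this (a minimal sketch, not the exact code from the repo; `numba_double`, `double_udf`, and the double-typed column are made up for illustration):

from numba import njit
import numpy
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@njit
def numba_double(arr):
    # operates on a whole ndarray at a time
    result = numpy.empty_like(arr)
    for i in range(arr.shape[0]):
        result[i] = arr[i] * 2.0
    return result

@pandas_udf(DoubleType())
def double_udf(s: pd.Series) -> pd.Series:
    # adapt the pandas Series to an ndarray before handing it to Numba
    return pd.Series(numba_double(s.to_numpy()))

which I then apply with something like df.withColumn('doubled', double_udf('value')).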
We use numba extensively for simulations with pyspark, though quite differently from the way you’re doing it with pandas UDFs.
If you’re interested in a mapPartitions-style example with numba, let me know and I’ll make a little demo.
Thanks for the reply! If it’s not too much trouble, a small demo of the way you’re using it would be great to see. Many thanks in advance.
Here’s a mini-demo; let me know if you have questions.
from numba import njit
import numpy
from functools import partial

# `spark` is an existing SparkSession (e.g. from a pyspark shell or notebook)
df = spark.sparkContext.parallelize([(i, i * i) for i in range(10)]).toDF(['i', 'i_squared']).cache()

print('original dataframe')
df.show()
# convert the dataframe to a dict[column name, ndarray]
# this is inefficient, just an example
def columns_to_ndarray(partition_number, rows):
    # this function runs in plain-python space on the executor
    res = {}
    for row in rows:
        row = row.asDict()
        if not res:
            res = {key: [] for key in row}
        for key, val in row.items():
            res[key].append(val)
    if res:  # skip empty partitions
        # one dict per partition: column name -> ndarray of that column's values
        yield {key: numpy.asarray(lst) for key, lst in res.items()}

ndarray_rdd = df.rdd.mapPartitionsWithIndex(columns_to_ndarray).cache()
print('converted to dict[column name, ndarray]')
for partition, data in enumerate(ndarray_rdd.collect()):
    print('partition=', partition, 'data=', data)
@njit
def numba_invert(input_array):
    # element-wise negation; numba compiles this loop to machine code
    result = numpy.empty_like(input_array)
    for i in range(len(input_array)):
        result[i] = input_array[i] * -1
    return result
def invert_column(column_name, partition_number, dicts):
    for item in dicts:  # loop over the dictionaries, one per partition
        item[column_name] = numba_invert(item[column_name])
        yield item
inverted_rdd = ndarray_rdd.mapPartitionsWithIndex(partial(invert_column, 'i')).cache()

print('numba inverted column i')
for partition, data in enumerate(inverted_rdd.collect()):
    print('partition=', partition, 'data=', data)
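And if you want to get back to a DataFrame afterwards, here is one possible way (my own sketch on top of the demo above, assuming integer columns and that the per-partition dicts preserve the original column order):

def ndarray_to_rows(partition_number, dicts):
    # expand each per-partition dict back into rows of plain Python ints
    for d in dicts:
        keys = list(d.keys())
        for i in range(len(d[keys[0]])):
            yield tuple(int(d[k][i]) for k in keys)

roundtrip_df = inverted_rdd.mapPartitionsWithIndex(ndarray_to_rows).toDF(['i', 'i_squared'])
roundtrip_df.show()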