Is anyone using Numba with PySpark?
I’m just starting to experiment with this. It seems that Numba can be used in PySpark pandas UDFs, for example, as long as the Series is converted to an ndarray (see numba-pyspark-udfs/udf.py at main · gmarkall/numba-pyspark-udfs · GitHub). However, in my naive first implementation I noticed that PySpark called the Numba function multiple times, elementwise.
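For concreteness, here is a minimal sketch of the pattern I mean (hypothetical names and column types, not the code from the linked repo; it assumes the Spark 3.x pandas_udf API with type hints):

import numpy as np
import pandas as pd
from numba import njit
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@njit
def scale(values):
    # operates on a plain ndarray, so Numba can compile it in nopython mode
    out = np.empty_like(values)
    for i in range(len(values)):
        out[i] = values[i] * 2.0
    return out

@pandas_udf(DoubleType())
def scale_udf(s: pd.Series) -> pd.Series:
    # adapt the pandas Series to an ndarray before calling the jitted function
    return pd.Series(scale(s.to_numpy()))

# hypothetical usage, assuming a DoubleType column named 'value':
# df.withColumn('doubled', scale_udf(df['value']))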
We use Numba extensively for simulations with PySpark, though quite differently from the way you’re doing it with pandas UDFs.
If you’re interested in a mapPartitions-style example with Numba, let me know and I’ll put together a little demo.
Thanks for the reply! If you’re able to and it’s not too much trouble, a small demo of the way you’re using it would be great to see. Many thanks in advance.
Here’s a mini-demo; let me know if you have questions:
from numba import njit
import numpy
from functools import partial
df = spark.sparkContext.parallelize([(i, i * i) for i in range(10)]).toDF(['i', 'i_squared']).cache()
print('original dataframe')
df.show()
# convert the dataframe to a dict[string, ndarray]
# this is inefficient, just an example
def columns_to_ndarray(partition_number, iter):
    # this function runs in plain-python space on the executor
    res = {}
    for row in iter:
        row = row.asDict()
        if not res:
            res = {key: [] for key in row}
        for key, val in row.items():
            res[key].append(val)
    # yield one dict[column name, ndarray] per partition
    yield {key: numpy.asarray(lst) for key, lst in res.items()}
ndarray_rdd = df.rdd.mapPartitionsWithIndex(columns_to_ndarray).cache()
print('converted to dict[column name, ndarray]')
for partition, data in enumerate(ndarray_rdd.collect()):
    print('partition=', partition, 'data=', data)
@njit
def numba_invert(input_array):
    # Numba-compiled loop that negates every element of the ndarray
    result = numpy.empty_like(input_array)
    for i in range(len(input_array)):
        result[i] = input_array[i] * -1
    return result

def invert_column(column_name, partition_number, iter):
    for item in iter:  # loop over the dictionaries, one per partition
        item[column_name] = numba_invert(item[column_name])
        yield item
inverted_rdd = ndarray_rdd.mapPartitionsWithIndex(partial(invert_column, 'i')).cache()
print('numba inverted column i')
for partition, data in enumerate(inverted_rdd.collect()):
    print('partition=', partition, 'data=', data)