Vectorized Python User-defined Table Functions (UDTFs)#
Spark 4.1 introduces the Vectorized Python user-defined table function (UDTF), a new type of user-defined table-valued function.
It can be used via the @arrow_udtf decorator.
Unlike scalar functions that return a single result value from each call, a UDTF is invoked in the FROM clause of a query and returns an entire table as output.
Unlike the traditional Python UDTF, which evaluates row by row, the Vectorized Python UDTF operates directly on Apache Arrow arrays and record batches.
This lets you leverage vectorized operations to improve the performance of your UDTF.
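For example, where a row-based UDTF would loop over rows in Python, a vectorized UDTF can apply a single pyarrow.compute kernel to a whole column per batch. A minimal sketch (the class and column names are illustrative, not part of the API):
import pyarrow as pa
import pyarrow.compute as pc
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="doubled bigint")
class DoubleIds:
    def eval(self, batch: pa.RecordBatch):
        # One vectorized kernel call per Arrow batch, instead of a
        # Python-level loop over individual rows.
        yield pa.table({"doubled": pc.multiply(batch.column("id"), 2)})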
Vectorized Python UDTF Interface#
import pyarrow as pa
from typing import Any, Iterator

class NameYourArrowPythonUDTF:

    def __init__(self) -> None:
        """
        Initializes the user-defined table function (UDTF). This is optional.

        This method serves as the default constructor and is called once when the
        UDTF is instantiated on the executor side.

        Any class fields assigned in this method will be available for subsequent
        calls to the `eval`, `terminate` and `cleanup` methods.

        Notes
        -----
        - You cannot create or reference the Spark session within the UDTF. Any
          attempt to do so will result in a serialization error.
        """
        ...

    def eval(self, *args: Any) -> Iterator[pa.RecordBatch | pa.Table]:
        """
        Evaluates the function using the given input arguments.

        This method is required and must be implemented.

        Argument Mapping:
        - Each provided scalar expression maps to exactly one value in the
          `*args` list with type `pa.Array`.
        - Each provided table argument maps to a `pa.RecordBatch` object containing
          the columns in the order they appear in the provided input table,
          and with the names computed by the query analyzer.

        This method is called on every batch of input rows, and can produce zero or more
        output pyarrow record batches or pyarrow tables. The schema of each yielded
        record batch or table must match the return type of the UDTF.

        Parameters
        ----------
        *args : Any
            Arbitrary positional arguments representing the input to the UDTF.

        Yields
        ------
        iterator
            An iterator of `pa.RecordBatch` or `pa.Table` objects representing a batch of rows
            in the UDTF result table. Yield as many times as needed to produce multiple batches.

        Notes
        -----
        - UDTFs can instead accept keyword arguments during the function call if needed.
        - The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
          UDTF wants to skip consuming all remaining rows from the current partition of the
          input table. This will cause the UDTF to proceed directly to the `terminate` method.
        - The `eval` method can raise any other exception to indicate that the UDTF should be
          aborted entirely. This will cause the UDTF to skip the `terminate` method and proceed
          directly to the `cleanup` method, and then the exception will be propagated to the
          query processor, causing the invoking query to fail.

        Examples
        --------
        This `eval` method takes a table argument and yields an arrow record batch for each input batch.

        >>> def eval(self, batch: pa.RecordBatch):
        ...     yield batch

        This `eval` method takes a table argument and yields a pyarrow table for each input batch.

        >>> def eval(self, batch: pa.RecordBatch):
        ...     yield pa.table({"x": batch.column(0), "y": batch.column(1)})

        This `eval` method takes both table and scalar arguments and yields a pyarrow table for each input batch.

        >>> def eval(self, batch: pa.RecordBatch, x: pa.Array):
        ...     yield pa.table({"x": x, "y": batch.column(0)})
        """
        ...

    def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
        """
        Called when the UDTF has successfully processed all input rows.

        This method is optional to implement and is useful for performing any
        finalization operations after the UDTF has finished processing
        all rows. It can also be used to yield additional rows if needed.
        Table functions that consume all rows in the entire input partition
        and then compute and return the entire output table can do so from
        this method as well (please be mindful of memory usage when doing
        this).

        If any exceptions occur during input row processing, this method
        won't be called.

        Yields
        ------
        iterator
            An iterator of `pa.RecordBatch` or `pa.Table` objects representing a batch of rows
            in the UDTF result table. Yield as many times as needed to produce multiple batches.

        Examples
        --------
        >>> def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
        ...     yield pa.table({"x": pa.array([1, 2, 3])})
        """
        ...

    def cleanup(self) -> None:
        """
        Invoked after the UDTF completes processing input rows.

        This method is optional to implement and is useful for final cleanup,
        regardless of whether the UDTF processed all input rows successfully
        or was aborted due to exceptions.

        Examples
        --------
        >>> def cleanup(self) -> None:
        ...     self.conn.close()
        """
        ...
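Putting the interface together, here is a minimal sketch of a UDTF that uses all four methods: it counts rows per batch in eval and emits a grand total from terminate (the class name, column name, and counting logic are illustrative, not part of the API):
import pyarrow as pa
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="n bigint")
class BatchAndTotalCounts:
    def __init__(self) -> None:
        # Per-instance state, available to eval/terminate/cleanup.
        self._total = 0

    def eval(self, batch: pa.RecordBatch):
        self._total += batch.num_rows
        # One output row per input batch: that batch's row count.
        yield pa.table({"n": pa.array([batch.num_rows], type=pa.int64())})

    def terminate(self):
        # One final row once all input is consumed: the grand total.
        yield pa.table({"n": pa.array([self._total], type=pa.int64())})

    def cleanup(self) -> None:
        # Nothing external to release in this sketch.
        self._total = 0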
Defining the Output Schema#
The return type of the UDTF defines the schema of the table it outputs.
You can specify it in the @arrow_udtf decorator.
It must be either a StructType:
from pyspark.sql.types import IntegerType, StringType, StructType

@arrow_udtf(returnType=StructType().add("c1", StringType()).add("c2", IntegerType()))
class YourArrowPythonUDTF:
    ...
or a DDL string representing a struct type:
@arrow_udtf(returnType="c1 string, c2 int")
class YourArrowPythonUDTF:
...
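Assuming the DDL form accepts the same struct-type syntax Spark uses elsewhere (an assumption, not something this page states), nested column types can be declared the same way:
# Illustrative only: a schema with a nested array column, assuming
# standard Spark DDL type syntax is accepted by returnType.
@arrow_udtf(returnType="name string, scores array<int>")
class YourArrowPythonUDTF:
    ...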
Emitting Output Rows#
The eval and terminate methods then emit zero or more output batches conforming to this schema by yielding pa.RecordBatch or pa.Table objects.
@arrow_udtf(returnType="c1 int, c2 int")
class YourArrowPythonUDTF:
def eval(self, batch: pa.RecordBatch):
yield pa.table({"c1": batch.column(0), "c2": batch.column(1)})
You can also yield multiple pyarrow tables in the eval method.
@arrow_udtf(returnType="c1 int")
class YourArrowPythonUDTF:
def eval(self, batch: pa.RecordBatch):
yield pa.table({"c1": batch.column(0)})
yield pa.table({"c1": batch.column(1)})
You can also yield multiple pyarrow record batches in the eval method.
@arrow_udtf(returnType="c1 int")
class YourArrowPythonUDTF:
def eval(self, batch: pa.RecordBatch):
new_batch = pa.record_batch(
{"c1": batch.column(0).slice(0, len(batch) // 2)})
yield new_batch
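Output batches also need not preserve the input row count. For instance, a UDTF could emit only a filtered subset of each batch; a sketch using a pyarrow.compute predicate (the column choice and predicate are illustrative):
import pyarrow.compute as pc

@arrow_udtf(returnType="c1 int")
class YourArrowPythonUDTF:
    def eval(self, batch: pa.RecordBatch):
        # Keep only the rows whose first column is positive.
        mask = pc.greater(batch.column(0), 0)
        yield pa.table({"c1": batch.column(0).filter(mask)})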
Usage Examples#
Here’s how to use these UDTFs with the DataFrame API:
import pyarrow as pa
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="c1 string")
class MyArrowPythonUDTF:
    def eval(self, batch: pa.RecordBatch):
        yield pa.table({"c1": batch.column("value")})

df = spark.range(10).selectExpr("id", "cast(id as string) as value")
MyArrowPythonUDTF(df.asTable()).show()
# Register the UDTF
spark.udtf.register("my_arrow_udtf", MyArrowPythonUDTF)

# Use in SQL queries
df = spark.sql("""
    SELECT * FROM my_arrow_udtf(TABLE(SELECT id, cast(id as string) as value FROM range(10)))
""")