Python Data Source API
Overview
The Python Data Source API is a new feature introduced in Spark 4.0, enabling developers to read from custom data sources and write to custom data sinks in Python. This guide provides a comprehensive overview of the API and instructions on how to create, use, and manage Python data sources.
Simple Example
Here’s a simple Python data source that generates exactly two rows of synthetic data. This example demonstrates how to set up a custom data source without using external libraries, focusing on the essentials needed to get it up and running quickly.
Step 1: Define the data source
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
class SimpleDataSource(DataSource):
    """
    Minimal PySpark data source, registered under the short name "simple",
    that always produces the same two rows of synthetic data.
    """

    @classmethod
    def name(cls):
        # Short name used with spark.read.format(...).
        return "simple"

    def schema(self):
        # Two columns: a string "name" and an integer "age".
        name_field = StructField("name", StringType())
        age_field = StructField("age", IntegerType())
        return StructType([name_field, age_field])

    def reader(self, schema: StructType):
        # The requested schema is not consulted: the reader emits fixed rows.
        return SimpleDataSourceReader()
class SimpleDataSourceReader(DataSourceReader):
    """Reader for SimpleDataSource that emits two hard-coded rows."""

    def read(self, partition):
        # The partition argument is unused; every call yields the same two rows.
        yield from (("Alice", 20), ("Bob", 30))
Step 2: Register the data source
from pyspark.sql import SparkSession
# Get (or create) the session, then register the data source class so it can be
# referenced by its short name "simple" in format().
spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(SimpleDataSource)
Step 3: Read from the data source
# Load the two rows produced by SimpleDataSourceReader and print them:
spark.read.format("simple").load().show()
# +-----+---+
# | name|age|
# +-----+---+
# |Alice| 20|
# |  Bob| 30|
# +-----+---+
Creating a Python Data Source
To create a custom Python data source, subclass the DataSource base class (together with the relevant reader and writer base classes) and implement the methods needed for reading and writing data.
This example demonstrates creating a simple data source to generate synthetic data using the faker library. Ensure the faker library is installed and accessible in your Python environment.
Define the Data Source
Start by creating a new subclass of DataSource that defines the source name and schema.
To be used as a source or sink in a batch or streaming query, the data source must implement the corresponding DataSource method.
Methods that must be implemented for each capability:

| | source | sink |
|---|---|---|
| batch | reader() | writer() |
| streaming | streamReader() or simpleStreamReader() | streamWriter() |
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType
class FakeDataSource(DataSource):
    """
    Data source registered as "fake" that generates synthetic data with the
    `faker` library. It supports batch and streaming reads and writes.

    Options:
      - numRows: number of rows a batch read generates. Defaults to 3.
    """

    @classmethod
    def name(cls):
        # Short name used with format("fake").
        return "fake"

    def schema(self):
        # Default schema as a DDL string; a user-supplied .schema(...) replaces it.
        return "name string, date string, zipcode string, state string"

    # ---- batch ----------------------------------------------------------
    def reader(self, schema: StructType):
        return FakeDataSourceReader(schema, self.options)

    def writer(self, schema: StructType, overwrite: bool):
        return FakeDataSourceWriter(self.options)

    # ---- streaming ------------------------------------------------------
    def streamReader(self, schema: StructType):
        return FakeStreamReader(schema, self.options)

    # Spark calls simpleStreamReader() only when streamReader() is not
    # implemented; a real source would define just one of the two.
    def simpleStreamReader(self, schema: StructType):
        return SimpleStreamReader()

    def streamWriter(self, schema: StructType, overwrite: bool):
        return FakeStreamWriter(self.options)
Implementing a Batch Reader and Writer for a Python Data Source
Implement the Reader
Define the reader logic to generate synthetic data. Use the faker library to populate each field in the schema.
class FakeDataSourceReader(DataSourceReader):
    """Batch reader that fills each schema column with a value from `faker`."""

    def __init__(self, schema, options):
        self.schema: StructType = schema
        self.options = options

    def read(self, partition):
        # Imported inside the method so the reader stays picklable.
        from faker import Faker

        fake = Faker()
        # Option values always arrive as strings, hence the int() conversion.
        row_count = int(self.options.get("numRows", 3))
        for _ in range(row_count):
            # Each column name doubles as the faker method that generates its
            # value (e.g. fake.name(), fake.zipcode()).
            yield tuple(getattr(fake, field.name)() for field in self.schema.fields)
Implement the Writer
Create a fake data source writer that processes each partition of data, counts the rows, and either prints the total count of rows after a successful write or the number of failed tasks if the writing process fails.
from dataclasses import dataclass
from typing import Iterator, List
from pyspark.sql.types import Row
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage
@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    # Message each write task returns; commit()/abort() receive a list of them.
    partition_id: int  # index of the partition the task wrote
    count: int  # number of rows the task wrote


class FakeDataSourceWriter(DataSourceWriter):
    """
    Batch writer that only counts rows. Each task reports how many rows it
    received. After a successful write, commit() prints the total row count;
    after a failure, abort() prints the number of failed tasks.
    """

    def __init__(self, options=None):
        # FakeDataSource.writer() passes its options in. Without an __init__
        # that accepts them, constructing the writer fails. The default keeps
        # a no-argument FakeDataSourceWriter() call working.
        self.options = {} if options is None else options

    def write(self, rows: Iterator[Row]) -> SimpleCommitMessage:
        """Count the rows of one partition and return them as a commit message."""
        # Imported inside the method so the writer stays picklable.
        from pyspark import TaskContext

        partition_id = TaskContext.get().partitionId()
        row_count = sum(1 for _ in rows)
        return SimpleCommitMessage(partition_id=partition_id, count=row_count)

    def commit(self, messages: List[SimpleCommitMessage]) -> None:
        """Called after every write task succeeded; prints the total row count."""
        total_count = sum(message.count for message in messages)
        print(f"Total number of rows: {total_count}")

    def abort(self, messages: List[SimpleCommitMessage]) -> None:
        """Called when the write failed; failed tasks appear as None in `messages`."""
        failed_count = sum(message is None for message in messages)
        print(f"Number of failed tasks: {failed_count}")
Implementing a Streaming Reader and Writer for a Python Data Source
Implement the Stream Reader
This is a dummy streaming data reader that generates two rows in every microbatch. The stream reader keeps an integer offset that increases by 2 in every microbatch.
from typing import Iterator, Tuple

from pyspark.sql.datasource import DataSourceStreamReader, InputPartition


class RangePartition(InputPartition):
    """A partition covering the half-open offset range [start, end)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end


class FakeStreamReader(DataSourceStreamReader):
    """
    Dummy streaming reader that keeps an integer offset and advances it by 2
    each time the latest offset is requested, so each microbatch reads two rows.
    """

    def __init__(self, schema, options):
        # schema and options match the FakeDataSource.streamReader() call but
        # are not used by this dummy reader.
        self.current = 0

    def initialOffset(self) -> dict:
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def latestOffset(self) -> dict:
        """
        Return the current latest offset that the next microbatch will read to.
        """
        # Every call moves the end offset forward by two rows.
        self.current += 2
        return {"offset": self.current}

    def partitions(self, start: dict, end: dict):
        """
        Plan the partitioning of the current microbatch defined by the start and
        end offsets. It must return a sequence of :class:`InputPartition` objects;
        here the whole microbatch becomes a single RangePartition.
        """
        return [RangePartition(start["offset"], end["offset"])]

    def commit(self, end: dict):
        """
        Invoked when the query has finished processing data before the end
        offset; can be used to clean up resources. Nothing to clean up here.
        """
        pass

    def read(self, partition) -> Iterator[Tuple]:
        """
        Read one partition: yield an ``(i, str(i))`` tuple for every offset
        ``i`` in ``[partition.start, partition.end)``.
        """
        # NOTE(review): these two-value rows do not match FakeDataSource's
        # default four-string-column schema; confirm the intended schema.
        for i in range(partition.start, partition.end):
            yield (i, str(i))
Implement the Simple Stream Reader
If the data source has low throughput and doesn’t require partitioning, you can implement SimpleDataSourceStreamReader instead of DataSourceStreamReader.
A readable streaming data source must implement either simpleStreamReader() or streamReader(); simpleStreamReader() is invoked only when streamReader() is not implemented.
This is the same dummy streaming reader that generates two rows in every batch, implemented with the SimpleDataSourceStreamReader interface.
from typing import Iterator, Tuple

from pyspark.sql.datasource import SimpleDataSourceStreamReader


class SimpleStreamReader(SimpleDataSourceStreamReader):
    """
    Dummy stream built on the simpler, non-partitioned interface. Every read()
    returns two single-column rows and advances the offset by 2.
    """

    def initialOffset(self) -> dict:
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        """
        Take the start offset as input and return an iterator of tuples
        together with the start offset of the next read.
        """
        start_idx = start["offset"]
        next_idx = start_idx + 2
        rows = iter([(i,) for i in range(start_idx, next_idx)])
        return rows, {"offset": next_idx}

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        """
        Take the start and end offsets as input and deterministically return
        an iterator over the rows between them. This is called when the query
        replays batches during a restart or after a failure.
        """
        start_idx = start["offset"]
        end_idx = end["offset"]
        return iter([(i,) for i in range(start_idx, end_idx)])

    def commit(self, end: dict) -> None:
        """
        Invoked when the query has finished processing data before the end
        offset; can be used to clean up resources. Nothing to clean up here.
        """
        pass
Implement the Stream Writer
This is a streaming data writer that writes the metadata of each microbatch to a local path.
from dataclasses import dataclass

from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage


@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    # @dataclass is required: without it these fields are bare annotations, and
    # SimpleCommitMessage(partition_id=..., count=...) raises a TypeError.
    partition_id: int  # index of the partition the task wrote
    count: int  # number of rows the task wrote


class FakeStreamWriter(DataSourceStreamWriter):
    """
    Streaming writer that records per-microbatch metadata as files under the
    output path (the "path" option, e.g. set by writeStream...start(path)).
    """

    def __init__(self, options):
        self.options = options
        self.path = self.options.get("path")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.path is None:
            raise ValueError(
                "FakeStreamWriter requires a 'path' option, e.g. start('/output_path')"
            )

    def write(self, iterator):
        """
        Write the data and return the commit message of that partition.
        """
        # Imported inside the method so the writer stays picklable.
        from pyspark import TaskContext

        partition_id = TaskContext.get().partitionId()
        row_count = sum(1 for _ in iterator)
        return SimpleCommitMessage(partition_id=partition_id, count=row_count)

    def commit(self, messages, batchId) -> None:
        """
        Receive a sequence of :class:`WriterCommitMessage` when all write tasks
        succeed and decide what to do with it. Here the microbatch metadata
        (number of partitions and rows) is appended to ``<path>/<batchId>.json``.
        """
        import json
        import os

        status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
        with open(os.path.join(self.path, f"{batchId}.json"), "a", encoding="utf-8") as file:
            file.write(json.dumps(status) + "\n")

    def abort(self, messages, batchId) -> None:
        """
        Receive a sequence of :class:`WriterCommitMessage` from the successful
        tasks when some tasks fail and decide what to do with it. Here a failure
        note is written to ``<path>/<batchId>.txt``.
        """
        import os

        with open(os.path.join(self.path, f"{batchId}.txt"), "w", encoding="utf-8") as file:
            file.write(f"failed in batch {batchId}")
Serialization Requirement
User-defined DataSource, DataSourceReader, DataSourceWriter, DataSourceStreamReader, and DataSourceStreamWriter classes and their methods must be serializable with pickle.
Any library used inside a method must be imported inside that method. For example, TaskContext must be imported inside the read() method in the code below.
def read(self, partition):
    # Import inside the method body, not at module level, so the method can be
    # pickled and run on executors (see the serialization requirement).
    from pyspark import TaskContext
    context = TaskContext.get()
Using a Python Data Source
Use a Python Data Source in Batch Query
After defining your data source, it must be registered before usage.
# After registration, the short name "fake" (from FakeDataSource.name()) works in format().
spark.dataSource.register(FakeDataSource)
Read From a Python Data Source
Read from the fake datasource with the default schema and options:
# Values come from faker, so your output will differ; numRows defaults to 3.
spark.read.format("fake").load().show()
# +-----------+----------+-------+-------+
# |       name|      date|zipcode|  state|
# +-----------+----------+-------+-------+
# |Carlos Cobb|2018-07-15|  73003|Indiana|
# | Eric Scott|1991-08-22|  10085|  Idaho|
# | Amy Martin|1988-10-28|  68076| Oregon|
# +-----------+----------+-------+-------+
Read from the fake datasource with a custom schema:
# Each column name must be a faker method (here fake.name() and fake.company()).
spark.read.format("fake").schema("name string, company string").load().show()
# +---------------------+--------------+
# |name                 |company       |
# +---------------------+--------------+
# |Tanner Brennan       |Adams Group   |
# |Leslie Maxwell       |Santiago Group|
# |Mrs. Jacqueline Brown|Maynard Inc   |
# +---------------------+--------------+
Read from the fake datasource with a different number of rows:
# The reader receives option values as strings and converts numRows with int().
spark.read.format("fake").option("numRows", 5).load().show()
# +--------------+----------+-------+------------+
# |          name|      date|zipcode|       state|
# +--------------+----------+-------+------------+
# |  Pam Mitchell|1988-10-20|  23788|   Tennessee|
# |Melissa Turner|1996-06-14|  30851|      Nevada|
# |  Brian Ramsey|2021-08-21|  55277|  Washington|
# |  Caitlin Reed|1983-06-22|  89813|Pennsylvania|
# | Douglas James|2007-01-18|  46226|     Alabama|
# +--------------+----------+-------+------------+
Write To a Python Data Source
To write data to a custom location, make sure that you specify the mode() clause. Supported modes are append and overwrite.
# 10 rows (ids 0-9) split across 5 partitions; FakeDataSourceWriter counts rows per partition.
df = spark.range(0, 10, 1, 5)
# mode() must be specified; "append" and "overwrite" are supported.
df.write.format("fake").mode("append").save()
# You can check the Spark log (standard error) to see the output of the write operation.
# Total number of rows: 10
Use a Python Data Source in Streaming Query
Once the Python data source is registered, it can also be used in streaming queries, as the source of readStream() or the sink of writeStream(), by passing its short name or full name to format().
Start a query that reads from the fake Python data source and writes to the console:
# The console sink prints each microbatch; the first two microbatches are shown below.
query = spark.readStream.format("fake").load().writeStream.format("console").start()
# +---+
# | id|
# +---+
# |  0|
# |  1|
# +---+
# +---+
# | id|
# +---+
# |  2|
# |  3|
# +---+
The same data source can also be used as both the streaming reader and the streaming writer:
query = (
    spark.readStream.format("fake")
    .load()
    .writeStream.format("fake")
    # Sinks other than console/noop need a checkpoint location unless
    # spark.sql.streaming.checkpointLocation is set for the session.
    .option("checkpointLocation", "/checkpoint_path")
    # The path is handed to FakeStreamWriter as its "path" option.
    .start("/output_path")
)
Python Data Source Reader with Direct Arrow Batch Support for Improved Performance
The Python Data Source reader supports yielding Arrow batches directly, which can significantly improve data processing performance. The efficient Arrow format avoids the overhead of row-by-row processing, resulting in performance improvements of up to one order of magnitude, especially with large datasets.
Enabling Arrow Batch Support: To enable this feature, have the read method of your DataSourceReader (or DataSourceStreamReader) implementation yield pyarrow.RecordBatch objects instead of tuples. This simplifies data handling and reduces the number of I/O operations, which is particularly beneficial for large-scale data processing tasks.
Arrow Batch Example: The following example demonstrates how to implement a basic Data Source using Arrow Batch support.
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
import pyarrow as pa
from pyspark.sql.types import StructType


# Define the ArrowBatchDataSource
class ArrowBatchDataSource(DataSource):
    """
    A data source for testing Arrow batch serialization: its reader yields
    pyarrow.RecordBatch objects instead of tuples.
    """

    @classmethod
    def name(cls):
        return "arrowbatch"

    def schema(self):
        # Kept in sync with the Arrow schema built in ArrowBatchDataSourceReader.read().
        return "key int, value string"

    def reader(self, schema: StructType):
        # Spark passes the resolved schema as a StructType, not a DDL string.
        return ArrowBatchDataSourceReader(schema, self.options)


# Define the ArrowBatchDataSourceReader
class ArrowBatchDataSourceReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema: StructType = schema
        self.options = options

    def read(self, partition):
        """Yield a single Arrow record batch with five (key, value) rows."""
        # Imported inside the method, following the serialization requirement.
        import pyarrow as pa

        keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
        values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
        arrow_schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
        yield pa.RecordBatch.from_arrays([keys, values], schema=arrow_schema)

    def partitions(self):
        # Only one partition is planned.
        num_part = 1
        return [InputPartition(i) for i in range(num_part)]
# Initialize the Spark Session
spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()
# Register the ArrowBatchDataSource
spark.dataSource.register(ArrowBatchDataSource)
# Load data using the custom data source
df = spark.read.format("arrowbatch").load()
# Expected output (the reader's data is fixed):
# +---+-----+
# |key|value|
# +---+-----+
# |  1|  one|
# |  2|  two|
# |  3|three|
# |  4| four|
# |  5| five|
# +---+-----+
df.show()
Usage Notes
During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other Data Sources.