# Source code for pyspark.sql.plot.core

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math

from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence
from types import ModuleType
from pyspark.errors import PySparkValueError
from pyspark.sql import Column, functions as F
from pyspark.sql.internal import InternalFunction as SF
from pyspark.sql.pandas.utils import require_minimum_pandas_version
from pyspark.sql.utils import NumpyHelper, require_minimum_plotly_version

if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row
    import pandas as pd
    from plotly.graph_objs import Figure


class PySparkTopNPlotBase:
    """
    Helper that collects at most ``spark.sql.pyspark.plotting.max_rows`` leading rows.

    After :meth:`get_top_n` runs, ``self.partial`` tells whether the input had more rows
    than were returned.
    """

    def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
        """
        Return the first ``max_rows`` rows of ``sdf`` as a pandas DataFrame.

        One extra row is fetched so that truncation can be detected without counting the
        whole DataFrame; ``self.partial`` is set to ``True`` when that happened.
        """
        raw_limit = sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows")
        limit = int(raw_limit)  # type: ignore[arg-type]

        # The presence of row ``limit + 1`` is the signal that the data was cut off.
        head = sdf.limit(limit + 1).toPandas()

        self.partial = len(head) > limit
        if not self.partial:
            return head
        return head.iloc[:limit]


class PySparkSampledPlotBase:
    """
    Helper that collects a random sample of at most ``spark.sql.pyspark.plotting.max_rows``
    rows for plotting.

    After :meth:`get_sampled` runs, ``self.fraction`` is ``1.0`` when no rows were dropped,
    ``max_rows / total_row_count`` when the data was sampled, or ``None`` when the data was
    sampled but the total row count could not be read.
    """

    def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
        """
        Return a random sample of at most ``max_rows`` rows of ``sdf`` as a pandas DataFrame.

        The kept rows are re-sorted by a ``monotonically_increasing_id`` assigned before the
        random shuffle, so they keep their original relative order.
        """
        from pyspark.sql import Observation, functions as F

        max_rows = int(
            sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )
        # Records the total input row count as a metric of the sampling query itself.
        observation = Observation("pyspark plotting")

        rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__"
        id_col_name = "__pyspark_plotting_sampled_plot_base_id__"

        sampled_sdf = (
            sdf.observe(observation, F.count(F.lit(1)).alias("count"))
            .select(
                "*",
                F.rand().alias(rand_col_name),
                F.monotonically_increasing_id().alias(id_col_name),
            )
            # Random order; keep one extra row so that truncation can be detected.
            .sort(rand_col_name)
            .limit(max_rows + 1)
            # Restore the original relative order of the kept rows in a single partition.
            .coalesce(1)
            .sortWithinPartitions(id_col_name)
            .drop(rand_col_name, id_col_name)
        )
        pdf = sampled_sdf.toPandas()

        if len(pdf) <= max_rows:
            self.fraction: Optional[float] = 1.0
            return pdf

        try:
            self.fraction = float(max_rows) / observation.get["count"]
        except Exception:
            # Best effort: the observed count is not always retrievable. Record that the
            # fraction is unknown instead of silently keeping a value from an earlier call
            # (instances may be shared, e.g. via ``PySparkPlotAccessor.plot_data_map``).
            self.fraction = None
        return pdf.iloc[:max_rows]


class PySparkPlotAccessor:
    """
    Accessor for DataFrame plotting functionality in PySpark.

    Users can call the accessor as ``df.plot(kind="line")`` or use the dedicated
    methods like ``df.plot.line(...)`` to generate plots.
    """

    # Maps a plot kind to the helper that collects its data into pandas: "top N" kinds keep
    # the first ``max_rows`` rows, sampled kinds draw a random sample. Each entry is a bound
    # method of an instance created once here, so that instance is shared across calls.
    plot_data_map = {
        "area": PySparkSampledPlotBase().get_sampled,
        "bar": PySparkTopNPlotBase().get_top_n,
        "barh": PySparkTopNPlotBase().get_top_n,
        "line": PySparkSampledPlotBase().get_sampled,
        "pie": PySparkTopNPlotBase().get_top_n,
        "scatter": PySparkSampledPlotBase().get_sampled,
    }
    # Cache of already-resolved plotting backend modules, keyed by backend name.
    _backends = {}  # type: ignore[var-annotated]

    def __init__(self, data: "DataFrame"):
        self.data = data

    def __call__(
        self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
    ) -> "Figure":
        """
        Plot the DataFrame as the given ``kind`` using the selected backend.

        Parameters
        ----------
        kind : str, default "line"
            The kind of plot to produce, e.g. "line", "bar", "scatter".
        backend : str, optional
            Name of the plotting backend. Defaults to "plotly", currently the only
            supported backend.
        **kwargs : optional
            Additional keyword arguments forwarded to the backend.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`
        """
        plot_backend = PySparkPlotAccessor._get_plot_backend(backend)

        return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)

    @staticmethod
    def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
        """
        Resolve the module implementing ``backend`` (default "plotly"), caching the result.

        Raises
        ------
        PySparkValueError
            If ``backend`` is not a supported backend name.
        """
        backend = backend or "plotly"

        if backend in PySparkPlotAccessor._backends:
            return PySparkPlotAccessor._backends[backend]

        if backend == "plotly":
            require_minimum_plotly_version()
        else:
            raise PySparkValueError(
                errorClass="UNSUPPORTED_PLOT_BACKEND",
                messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
            )
        from pyspark.sql.plot import plotly as module

        # Populate the cache so later calls skip the version check and import.
        PySparkPlotAccessor._backends[backend] = module
        return module

    def line(self, x: str, y: Union[str, List[str]], **kwargs: Any) -> "Figure":
        """
        Plot DataFrame as lines.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.line(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="line", x=x, y=y, **kwargs)

    def bar(self, x: str, y: Union[str, List[str]], **kwargs: Any) -> "Figure":
        """
        Vertical bar plot.

        A bar plot is a plot that presents categorical data with rectangular bars with
        lengths proportional to the values that they represent. A bar plot shows
        comparisons among discrete categories. One axis of the plot shows the specific
        categories being compared, and the other axis represents a measured value.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.bar(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.bar(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="bar", x=x, y=y, **kwargs)

    def barh(
        self, x: Union[str, List[str]], y: Union[str, List[str]], **kwargs: Any
    ) -> "Figure":
        """
        Make a horizontal bar plot.

        A horizontal bar plot is a plot that presents quantitative data with rectangular
        bars with lengths proportional to the values that they represent. A bar plot shows
        comparisons among discrete categories. One axis of the plot shows the specific
        categories being compared, and the other axis represents a measured value.

        Parameters
        ----------
        x : str or list of str
            Name(s) of the column(s) to use for the horizontal axis.
            Multiple columns can be plotted.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Notes
        -----
        In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs.
        In Plotly, `x` refers to the values and `y` refers to the categories.
        In Matplotlib, `x` refers to the categories and `y` refers to the values.
        Ensure correct axis labeling based on the backend used.

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.barh(x="int_val", y="category")  # doctest: +SKIP
        >>> df.plot.barh(
        ...     x=["int_val", "float_val"], y="category"
        ... )  # doctest: +SKIP
        """
        return self(kind="barh", x=x, y=y, **kwargs)

    def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure":
        """
        Create a scatter plot with varying marker point size and color.

        The coordinates of each point are defined by two dataframe columns and
        filled circles are used to represent each point. This kind of plot is
        useful to see complex correlations between two variables. Points could
        be for instance natural 2D coordinates like longitude and latitude in a
        map or, in general, any pair of metrics that can be plotted against each
        other.

        Parameters
        ----------
        x : str
            Name of column to use as horizontal coordinates for each point.
        y : str
            Name of column to use as vertical coordinates for each point.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ['length', 'width', 'species']
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.scatter(x='length', y='width')  # doctest: +SKIP
        """
        return self(kind="scatter", x=x, y=y, **kwargs)

    def area(self, x: str, y: Union[str, List[str]], **kwargs: Any) -> "Figure":
        """
        Draw a stacked area plot.

        An area plot displays quantitative data visually.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to plot.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> from datetime import datetime
        >>> data = [
        ...     (3, 5, 20, datetime(2018, 1, 31)),
        ...     (2, 5, 42, datetime(2018, 2, 28)),
        ...     (3, 6, 28, datetime(2018, 3, 31)),
        ...     (9, 12, 62, datetime(2018, 4, 30))
        ... ]
        >>> columns = ["sales", "signups", "visits", "date"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.area(x='date', y=['sales', 'signups', 'visits'])  # doctest: +SKIP
        """
        return self(kind="area", x=x, y=y, **kwargs)

    def pie(self, x: str, y: Optional[str] = None, **kwargs: Any) -> "Figure":
        """
        Generate a pie plot.

        A pie plot is a proportional representation of the numerical data in a
        column.

        Parameters
        ----------
        x : str
            Name of column to be used as the category labels for the pie plot.
        y : str, optional
            Name of the column to plot. If not provided, `subplots=True` must be
            passed in `kwargs`.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> from datetime import datetime
        >>> data = [
        ...     (3, 5, 20, datetime(2018, 1, 31)),
        ...     (2, 5, 42, datetime(2018, 2, 28)),
        ...     (3, 6, 28, datetime(2018, 3, 31)),
        ...     (9, 12, 62, datetime(2018, 4, 30))
        ... ]
        >>> columns = ["sales", "signups", "visits", "date"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.pie(x='date', y='sales')  # doctest: +SKIP
        >>> df.plot.pie(x='date', subplots=True)  # doctest: +SKIP
        """
        return self(kind="pie", x=x, y=y, **kwargs)

    def box(self, column: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> "Figure":
        """
        Make a box plot of the DataFrame columns.

        Make a box-and-whisker plot from DataFrame columns, optionally grouped by some
        other columns. A box plot is a method for graphically depicting groups of
        numerical data through their quartiles. The box extends from the Q1 to Q3
        quartile values of the data, with a line at the median (Q2). The whiskers
        extend from the edges of box to show the range of the data. By default, they
        extend no more than 1.5 * IQR (IQR = Q3 - Q1) from the edges of the box, ending
        at the farthest data point within that interval. Outliers are plotted as
        separate dots.

        Parameters
        ----------
        column: str or list of str, optional
            Column name or list of names to be used for creating the box plot.
            If None (default), all numeric columns will be used.
        **kwargs
            Extra arguments. `precision` is a float used by pyspark to compute
            approximate statistics for building a boxplot. The default value is 0.01.
            Use smaller values to get more precise statistics.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [
        ...     ("A", 50, 55),
        ...     ("B", 55, 60),
        ...     ("C", 60, 65),
        ...     ("D", 65, 70),
        ...     ("E", 70, 75),
        ...     ("F", 10, 15),
        ...     ("G", 85, 90),
        ...     ("H", 5, 150),
        ... ]
        >>> columns = ["student", "math_score", "english_score"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.box()  # doctest: +SKIP
        >>> df.plot.box(column="math_score")  # doctest: +SKIP
        >>> df.plot.box(column=["math_score", "english_score"])  # doctest: +SKIP
        """
        return self(kind="box", column=column, **kwargs)

    def kde(
        self,
        bw_method: Union[int, float],
        column: Optional[Union[str, List[str]]] = None,
        ind: Optional[Union[Sequence[float], int]] = None,
        **kwargs: Any,
    ) -> "Figure":
        """
        Generate Kernel Density Estimate plot using Gaussian kernels.

        In statistics, kernel density estimation (KDE) is a non-parametric way to
        estimate the probability density function (PDF) of a random variable. This
        function uses Gaussian kernels and includes automatic bandwidth determination.

        Parameters
        ----------
        bw_method : int or float
            The method used to calculate the estimator bandwidth.
            See KernelDensity in PySpark for more information.
        column: str or list of str, optional
            Column name or list of names to be used for creating the kde plot.
            If None (default), all numeric columns will be used.
        ind : List of float, NumPy array or integer, optional
            Evaluation points for the estimated PDF. If None (default),
            1000 equally spaced points are used. If `ind` is a NumPy array, the
            KDE is evaluated at the points passed. If `ind` is an integer,
            `ind` number of equally spaced points are used.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ["length", "width", "species"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.kde(bw_method=0.3, ind=100)  # doctest: +SKIP
        >>> df.plot.kde(column=["length", "width"], bw_method=0.3, ind=100)  # doctest: +SKIP
        >>> df.plot.kde(column="length", bw_method=0.3, ind=100)  # doctest: +SKIP
        """
        return self(kind="kde", column=column, bw_method=bw_method, ind=ind, **kwargs)

    def hist(
        self, column: Optional[Union[str, List[str]]] = None, bins: int = 10, **kwargs: Any
    ) -> "Figure":
        """
        Draw one histogram of the DataFrame's columns.

        A `histogram`_ is a representation of the distribution of data.

        .. _histogram: https://en.wikipedia.org/wiki/Histogram

        Parameters
        ----------
        column: str or list of str, optional
            Column name or list of names to be used for creating the histogram plot.
            If None (default), all numeric columns will be used.
        bins : integer, default 10
            Number of histogram bins to be used.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ["length", "width", "species"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.hist(bins=4)  # doctest: +SKIP
        >>> df.plot.hist(column=["length", "width"])  # doctest: +SKIP
        >>> df.plot.hist(column="length", bins=4)  # doctest: +SKIP
        """
        return self(kind="hist", column=column, bins=bins, **kwargs)
class PySparkKdePlotBase:
    """Helpers that compute Gaussian kernel density estimates with Spark expressions."""

    @staticmethod
    def get_ind(sdf: "DataFrame", ind: Optional[Union[Sequence[float], int]]) -> Sequence[float]:
        """
        Return the points at which the KDE is evaluated.

        If ``ind`` is None, 1000 equally spaced points are generated; if it is an int, that
        many points are generated; otherwise ``ind`` is returned unchanged. Generated points
        span the [min, max] range of all columns of ``sdf``, extended on each side by half
        of that range.
        """

        def calc_min_max() -> "Row":
            # (min, max) over every column of ``sdf`` in a single aggregation.
            if len(sdf.columns) > 1:
                min_col = F.least(*map(F.min, sdf))  # type: ignore
                max_col = F.greatest(*map(F.max, sdf))  # type: ignore
            else:
                min_col = F.min(sdf.columns[-1])
                max_col = F.max(sdf.columns[-1])
            return sdf.select(min_col, max_col).first()  # type: ignore

        if ind is None or isinstance(ind, int):
            num = 1000 if ind is None else ind
            min_val, max_val = calc_min_max()
            sample_range = max_val - min_val
            ind = NumpyHelper.linspace(
                min_val - 0.5 * sample_range,
                max_val + 0.5 * sample_range,
                num,
            )
        return ind

    @staticmethod
    def compute_kde_col(
        input_col: Column,
        bw_method: Union[int, float],
        ind: Sequence[float],
    ) -> Column:
        """
        Build an aggregate column holding the KDE of ``input_col`` at every point of ``ind``.

        Each array element is the average, over all rows, of a Gaussian density centred on
        the row's value with standard deviation ``bw_method``, evaluated at that point.
        """
        # refers to org.apache.spark.mllib.stat.KernelDensity
        assert bw_method is not None and isinstance(
            bw_method, (int, float)
        ), "'bw_method' must be set as a scalar number."

        assert ind is not None, "'ind' must be a scalar array."

        bandwidth = float(bw_method)
        # Precomputed on the driver: log(std) + 0.5 * log(2 * pi) of the normal density.
        log_std_plus_half_log2_pi = math.log(bandwidth) + 0.5 * math.log(2 * math.pi)

        def norm_pdf(
            mean: Column,
            std: Column,
            log_std_plus_half_log2_pi: Column,
            x: Column,
        ) -> Column:
            # Normal density evaluated in log space, then exponentiated.
            x0 = x - mean
            x1 = x0 / std
            log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi
            return F.exp(log_density)

        return F.array(
            [
                F.avg(
                    norm_pdf(
                        input_col.cast("double"),
                        F.lit(bandwidth),
                        F.lit(log_std_plus_half_log2_pi),
                        F.lit(point),
                    )
                )
                for point in ind
            ]
        )


class PySparkHistogramPlotBase:
    """Helpers that compute histogram bins and bucket counts with Spark."""

    @staticmethod
    def get_bins(sdf: "DataFrame", bins: int) -> Sequence[float]:
        """
        Return ``bins + 1`` equally spaced bin edges spanning [min, max] of all columns.

        When min equals max, the range is widened by 0.5 on each side.
        """
        if len(sdf.columns) > 1:
            min_col = F.least(*map(F.min, sdf))  # type: ignore
            max_col = F.greatest(*map(F.max, sdf))  # type: ignore
        else:
            min_col = F.min(sdf.columns[-1])
            max_col = F.max(sdf.columns[-1])
        boundaries = sdf.select(min_col, max_col).first()

        if boundaries[0] == boundaries[1]:  # type: ignore
            boundaries = (boundaries[0] - 0.5, boundaries[1] + 0.5)  # type: ignore

        return NumpyHelper.linspace(boundaries[0], boundaries[1], bins + 1)  # type: ignore

    @staticmethod
    def compute_hist(sdf: "DataFrame", bins: Sequence[float]) -> List["pd.Series"]:
        """
        Count the values of each column of ``sdf`` into the buckets delimited by ``bins``.

        Returns one pandas Series per column of ``sdf``, named after that column, holding
        the bucket counts in bucket order (empty buckets count as 0). NULL and NaN values
        are ignored; a value outside [bins[0], bins[-1]] raises an error when the query runs.

        Raises
        ------
        AssertionError
            If ``bins`` is not a list.
        """
        require_minimum_pandas_version()

        assert isinstance(bins, list)

        spark = sdf._session
        assert spark is not None

        # 1. Make the bucket output flat to:
        # +----------+--------+
        # |__group_id|__bucket|
        # +----------+--------+
        # |0         |0       |
        # |0         |0       |
        # |0         |1       |
        # |0         |2       |
        # |0         |3       |
        # |0         |3       |
        # |1         |0       |
        # |1         |1       |
        # |1         |1       |
        # |1         |2       |
        # |1         |1       |
        # |1         |0       |
        # +----------+--------+
        colnames = sdf.columns

        # determines which bucket a given value falls into, based on predefined bin intervals
        # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
        def binary_search_for_buckets(value: Column) -> Column:
            index = SF.array_binary_search(F.lit(bins), value)
            bucket = F.when(index >= 0, index).otherwise(-index - 2)
            unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]")
            return (
                # The upper edge belongs to the last bucket (closed last interval).
                F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2))
                .when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
                .otherwise(F.raise_error(F.printf(unboundErrMsg, value)))
            )

        output_df = (
            sdf.select(
                F.posexplode(
                    F.array([F.col(colname).cast("double") for colname in colnames])
                ).alias("__group_id", "__value")
            )
            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
            .select(
                F.col("__group_id"),
                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
            )
        )

        # 2. Calculate the count based on each group and bucket, also fill empty bins.
        # +----------+--------+------+
        # |__group_id|__bucket| count|
        # +----------+--------+------+
        # |0         |0       |2     |
        # |0         |1       |1     |
        # |0         |2       |1     |
        # |0         |3       |2     |
        # |1         |0       |2     |
        # |1         |1       |3     |
        # |1         |2       |1     |
        # |1         |3       |0     | <- fill empty bins with zeros (by joining with bin_df)
        # +----------+--------+------+
        output_df = output_df.groupby("__group_id", "__bucket").agg(F.count("*").alias("count"))

        # Generate all possible combinations of group id and bucket
        bin_df = (
            spark.range(len(colnames))
            .select(
                F.col("id").alias("__group_id"),
                F.explode(F.lit(list(range(len(bins) - 1)))).alias("__bucket"),
            )
            .hint("broadcast")
        )

        output_df = (
            bin_df.join(output_df, ["__group_id", "__bucket"], "left")
            .select("__group_id", "__bucket", F.nvl(F.col("count"), F.lit(0)).alias("count"))
            .coalesce(1)
            .sortWithinPartitions("__group_id", "__bucket")
            .select("__group_id", "count")
        )

        # 3. Split the counts by group id into one pandas Series per input column. From:
        # +----------+--------+------+
        # |__group_id|__bucket| count|
        # +----------+--------+------+
        # |0         |0       |2     |
        # |0         |1       |1     |
        # |0         |2       |1     |
        # |0         |3       |2     |
        # +----------+--------+------+
        # +----------+--------+------+
        # |__group_id|__bucket| count|
        # +----------+--------+------+
        # |1         |0       |2     |
        # |1         |1       |3     |
        # |1         |2       |1     |
        # |1         |3       |0     |
        # +----------+--------+------+
        #
        # to one Series per input column, named after that column:
        # +-------+
        # |column1|
        # +-------+
        # |2      |
        # |1      |
        # |1      |
        # |2      |
        # +-------+
        # +-------+
        # |column2|
        # +-------+
        # |2      |
        # |3      |
        # |1      |
        # |0      |
        # +-------+
        result = output_df.toPandas()
        output_series = []
        for i, input_column_name in enumerate(colnames):
            pdf = result[result["__group_id"] == i]
            pdf = pdf[["count"]]
            pdf.columns = [input_column_name]
            output_series.append(pdf[input_column_name])

        return output_series


class PySparkBoxPlotBase:
    """Helpers that compute box-plot statistics with Spark."""

    @staticmethod
    def compute_box(
        sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool
    ) -> Optional["Row"]:
        """
        Compute box-plot statistics for every column in ``colnames``.

        Returns a single Row with one struct field ``_box_plot_results_{i}`` per column,
        containing ``mean``, ``med``, ``q1``, ``q3``, ``upper_whisker``, ``lower_whisker``
        and ``fliers``. Quartiles come from ``percentile_approx`` with accuracy
        ``int(1.0 / precision)``; the fences are ``q1 - whis * IQR`` and ``q3 + whis * IQR``.
        ``fliers`` is null unless ``showfliers`` is True.
        """
        assert len(colnames) > 0
        # Backtick-quote names so they are resolved literally (e.g. names containing dots).
        formatted_colnames = ["`{}`".format(colname) for colname in colnames]

        # First pass: per-column mean, quartiles and fences.
        stats_scols = []
        for i, colname in enumerate(formatted_colnames):
            percentiles = F.percentile_approx(colname, [0.25, 0.50, 0.75], int(1.0 / precision))
            q1 = F.get(percentiles, 0)
            med = F.get(percentiles, 1)
            q3 = F.get(percentiles, 2)
            iqr = q3 - q1
            lfence = q1 - F.lit(whis) * iqr
            ufence = q3 + F.lit(whis) * iqr

            stats_scols.append(
                F.struct(
                    F.mean(colname).alias("mean"),
                    med.alias("med"),
                    q1.alias("q1"),
                    q3.alias("q3"),
                    lfence.alias("lfence"),
                    ufence.alias("ufence"),
                ).alias(f"_box_plot_stats_{i}")
            )

        sdf_stats = sdf.select(*stats_scols)

        # Second pass: whiskers and fliers, using the first-pass stats joined to every row.
        result_scols = []
        for i, colname in enumerate(formatted_colnames):
            value = F.col(colname)

            lfence = F.col(f"_box_plot_stats_{i}.lfence")
            ufence = F.col(f"_box_plot_stats_{i}.ufence")
            mean = F.col(f"_box_plot_stats_{i}.mean")
            med = F.col(f"_box_plot_stats_{i}.med")
            q1 = F.col(f"_box_plot_stats_{i}.q1")
            q3 = F.col(f"_box_plot_stats_{i}.q3")

            outlier = ~value.between(lfence, ufence)

            # Computes min and max values of non-outliers - the whiskers
            upper_whisker = F.max(F.when(~outlier, value).otherwise(F.lit(None)))
            lower_whisker = F.min(F.when(~outlier, value).otherwise(F.lit(None)))

            # If it shows fliers, take the top 1k with the highest absolute values
            # Here we normalize the values by subtracting the median.
            # NOTE(review): the limit passed is 1001 and the ordering is whatever
            # collect_top_k(..., False) defines; presumably largest |value - med| first.
            if showfliers:
                pair = F.when(
                    outlier,
                    F.struct(F.abs(value - med), value.alias("val")),
                ).otherwise(F.lit(None))
                topk = SF.collect_top_k(pair, 1001, False)
                fliers = F.when(F.size(topk) > 0, topk["val"]).otherwise(F.lit(None))
            else:
                fliers = F.lit(None)

            result_scols.append(
                F.struct(
                    F.first(mean).alias("mean"),
                    F.first(med).alias("med"),
                    F.first(q1).alias("q1"),
                    F.first(q3).alias("q3"),
                    upper_whisker.alias("upper_whisker"),
                    lower_whisker.alias("lower_whisker"),
                    fliers.alias("fliers"),
                ).alias(f"_box_plot_results_{i}")
            )

        sdf_result = sdf.join(sdf_stats.hint("broadcast")).select(*result_scols)
        return sdf_result.first()