#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from typing import Any, TYPE_CHECKING, List, Optional, Union, Sequence
from types import ModuleType
from pyspark.errors import PySparkValueError
from pyspark.sql import Column, functions as F
from pyspark.sql.internal import InternalFunction as SF
from pyspark.sql.pandas.utils import require_minimum_pandas_version
from pyspark.sql.utils import NumpyHelper, require_minimum_plotly_version

if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row
    import pandas as pd
    from plotly.graph_objs import Figure


class PySparkTopNPlotBase:
    def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
        max_rows = int(
            sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )
        pdf = sdf.limit(max_rows + 1).toPandas()

        self.partial = False
        if len(pdf) > max_rows:
            self.partial = True
            pdf = pdf.iloc[:max_rows]
        return pdf


class PySparkSampledPlotBase:
    def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
        from pyspark.sql import Observation, functions as F

        max_rows = int(
            sdf._session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )

        observation = Observation("pyspark plotting")

        rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__"
        id_col_name = "__pyspark_plotting_sampled_plot_base_id__"

        sampled_sdf = (
            sdf.observe(observation, F.count(F.lit(1)).alias("count"))
            .select(
                "*",
                F.rand().alias(rand_col_name),
                F.monotonically_increasing_id().alias(id_col_name),
            )
            .sort(rand_col_name)
            .limit(max_rows + 1)
            .coalesce(1)
            .sortWithinPartitions(id_col_name)
            .drop(rand_col_name, id_col_name)
        )
        pdf = sampled_sdf.toPandas()

        if len(pdf) > max_rows:
            try:
                self.fraction = float(max_rows) / observation.get["count"]
            except Exception:
                pass
            return pdf[:max_rows]
        else:
            self.fraction = 1.0
            return pdf


class PySparkPlotAccessor:
    """
    Accessor for DataFrame plotting functionality in PySpark.

    Users can call the accessor as ``df.plot(kind="line")`` or use the dedicated
    methods like ``df.plot.line(...)`` to generate plots.
"""plot_data_map={"area":PySparkSampledPlotBase().get_sampled,"bar":PySparkTopNPlotBase().get_top_n,"barh":PySparkTopNPlotBase().get_top_n,"line":PySparkSampledPlotBase().get_sampled,"pie":PySparkTopNPlotBase().get_top_n,"scatter":PySparkSampledPlotBase().get_sampled,}_backends={}# type: ignore[var-annotated]def__init__(self,data:"DataFrame"):self.data=datadef__call__(self,kind:str="line",backend:Optional[str]=None,**kwargs:Any)->"Figure":plot_backend=PySparkPlotAccessor._get_plot_backend(backend)returnplot_backend.plot_pyspark(self.data,kind=kind,**kwargs)@staticmethoddef_get_plot_backend(backend:Optional[str]=None)->ModuleType:backend=backendor"plotly"ifbackendinPySparkPlotAccessor._backends:returnPySparkPlotAccessor._backends[backend]ifbackend=="plotly":require_minimum_plotly_version()else:raisePySparkValueError(errorClass="UNSUPPORTED_PLOT_BACKEND",messageParameters={"backend":backend,"supported_backends":", ".join(["plotly"])},)frompyspark.sql.plotimportplotlyasmodulereturnmodule

    def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
        """
        Plot DataFrame as lines.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.line(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="line", x=x, y=y, **kwargs)

    def bar(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
        """
        Vertical bar plot.

        A bar plot is a plot that presents categorical data with rectangular bars
        with lengths proportional to the values that they represent. A bar plot shows
        comparisons among discrete categories. One axis of the plot shows the specific
        categories being compared, and the other axis represents a measured value.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.bar(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.bar(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="bar", x=x, y=y, **kwargs)

    def barh(
        self, x: Union[str, list[str]], y: Union[str, list[str]], **kwargs: Any
    ) -> "Figure":
        """
        Make a horizontal bar plot.

        A horizontal bar plot is a plot that presents quantitative data with rectangular
        bars with lengths proportional to the values that they represent. A bar plot shows
        comparisons among discrete categories. One axis of the plot shows the specific
        categories being compared, and the other axis represents a measured value.

        Parameters
        ----------
        x : str or list of str
            Name(s) of the column(s) to use for the horizontal axis.
            Multiple columns can be plotted.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis.
            Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Notes
        -----
        In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots
        differs. In Plotly, `x` refers to the values and `y` refers to the categories.
        In Matplotlib, `x` refers to the categories and `y` refers to the values.
        Ensure correct axis labeling based on the backend used.

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.barh(x="int_val", y="category")  # doctest: +SKIP
        >>> df.plot.barh(
        ...     x=["int_val", "float_val"], y="category"
        ... )  # doctest: +SKIP
        """
        return self(kind="barh", x=x, y=y, **kwargs)

    def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure":
        """
        Create a scatter plot with varying marker point size and color.

        The coordinates of each point are defined by two dataframe columns and filled
        circles are used to represent each point. This kind of plot is useful to see
        complex correlations between two variables. Points could be for instance
        natural 2D coordinates like longitude and latitude in a map or, in general,
        any pair of metrics that can be plotted against each other.

        Parameters
        ----------
        x : str
            Name of column to use as horizontal coordinates for each point.
        y : str
            Name of column to use as vertical coordinates for each point.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ['length', 'width', 'species']
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.scatter(x='length', y='width')  # doctest: +SKIP
        """
        return self(kind="scatter", x=x, y=y, **kwargs)

    def area(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
        """
        Draw a stacked area plot.

        An area plot displays quantitative data visually.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to plot.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> from datetime import datetime
        >>> data = [
        ...     (3, 5, 20, datetime(2018, 1, 31)),
        ...     (2, 5, 42, datetime(2018, 2, 28)),
        ...     (3, 6, 28, datetime(2018, 3, 31)),
        ...     (9, 12, 62, datetime(2018, 4, 30))
        ... ]
        >>> columns = ["sales", "signups", "visits", "date"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.area(x='date', y=['sales', 'signups', 'visits'])  # doctest: +SKIP
        """
        return self(kind="area", x=x, y=y, **kwargs)

    def pie(self, x: str, y: Optional[str], **kwargs: Any) -> "Figure":
        """
        Generate a pie plot.

        A pie plot is a proportional representation of the numerical data in a column.

        Parameters
        ----------
        x : str
            Name of column to be used as the category labels for the pie plot.
        y : str, optional
            Name of the column to plot. If not provided, `subplots=True` must be
            passed in `kwargs`.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> from datetime import datetime
        >>> data = [
        ...     (3, 5, 20, datetime(2018, 1, 31)),
        ...     (2, 5, 42, datetime(2018, 2, 28)),
        ...     (3, 6, 28, datetime(2018, 3, 31)),
        ...     (9, 12, 62, datetime(2018, 4, 30))
        ... ]
        >>> columns = ["sales", "signups", "visits", "date"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.pie(x='date', y='sales')  # doctest: +SKIP
        >>> df.plot.pie(x='date', subplots=True)  # doctest: +SKIP
        """
        return self(kind="pie", x=x, y=y, **kwargs)

    def box(self, column: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> "Figure":
        """
        Make a box plot of the DataFrame columns.

        Make a box-and-whisker plot from DataFrame columns, optionally grouped by some
        other columns. A box plot is a method for graphically depicting groups of
        numerical data through their quartiles. The box extends from the Q1 to Q3
        quartile values of the data, with a line at the median (Q2). The whiskers extend
        from the edges of the box to show the range of the data. By default, they extend
        no more than 1.5 * IQR (IQR = Q3 - Q1) from the edges of the box, ending at the
        farthest data point within that interval. Outliers are plotted as separate dots.

        Parameters
        ----------
        column : str or list of str, optional
            Column name or list of names to be used for creating the box plot.
            If None (default), all numeric columns will be used.
        **kwargs
            Additional keyword arguments. `precision` is a float used by pyspark to
            compute approximate statistics for building a boxplot. The default value
            is 0.01. Use smaller values to get more precise statistics.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [
        ...     ("A", 50, 55),
        ...     ("B", 55, 60),
        ...     ("C", 60, 65),
        ...     ("D", 65, 70),
        ...     ("E", 70, 75),
        ...     ("F", 10, 15),
        ...     ("G", 85, 90),
        ...     ("H", 5, 150),
        ... ]
        >>> columns = ["student", "math_score", "english_score"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.box()  # doctest: +SKIP
        >>> df.plot.box(column="math_score")  # doctest: +SKIP
        >>> df.plot.box(column=["math_score", "english_score"])  # doctest: +SKIP
        """
        return self(kind="box", column=column, **kwargs)

    def kde(
        self,
        bw_method: Union[int, float],
        column: Optional[Union[str, List[str]]] = None,
        ind: Optional[Union[Sequence[float], int]] = None,
        **kwargs: Any,
    ) -> "Figure":
        """
        Generate Kernel Density Estimate plot using Gaussian kernels.

        In statistics, kernel density estimation (KDE) is a non-parametric way to
        estimate the probability density function (PDF) of a random variable. This
        function uses Gaussian kernels and includes automatic bandwidth determination.

        Parameters
        ----------
        bw_method : int or float
            The method used to calculate the estimator bandwidth.
            See KernelDensity in PySpark for more information.
        column : str or list of str, optional
            Column name or list of names to be used for creating the kde plot.
            If None (default), all numeric columns will be used.
        ind : List of float, NumPy array or integer, optional
            Evaluation points for the estimated PDF. If None (default), 1000 equally
            spaced points are used. If `ind` is a NumPy array, the KDE is evaluated
            at the points passed. If `ind` is an integer, `ind` number of equally
            spaced points are used.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ["length", "width", "species"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.kde(bw_method=0.3, ind=100)  # doctest: +SKIP
        >>> df.plot.kde(column=["length", "width"], bw_method=0.3, ind=100)  # doctest: +SKIP
        >>> df.plot.kde(column="length", bw_method=0.3, ind=100)  # doctest: +SKIP
        """
        return self(kind="kde", column=column, bw_method=bw_method, ind=ind, **kwargs)

    def hist(
        self, column: Optional[Union[str, List[str]]] = None, bins: int = 10, **kwargs: Any
    ) -> "Figure":
        """
        Draw one histogram of the DataFrame's columns.

        A `histogram`_ is a representation of the distribution of data.

        .. _histogram: https://en.wikipedia.org/wiki/Histogram

        Parameters
        ----------
        column : str or list of str, optional
            Column name or list of names to be used for creating the histogram plot.
            If None (default), all numeric columns will be used.
        bins : integer, default 10
            Number of histogram bins to be used.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ["length", "width", "species"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.hist(bins=4)  # doctest: +SKIP
        >>> df.plot.hist(column=["length", "width"])  # doctest: +SKIP
        >>> df.plot.hist(column="length", bins=4)  # doctest: +SKIP
        """
        return self(kind="hist", column=column, bins=bins, **kwargs)
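

# The accessor above only prepares data and dispatches; the real rendering code lives in
# ``pyspark.sql.plot.plotly``. The function below is an illustrative, hypothetical sketch
# (``_example_plot_pyspark`` is not part of PySpark) of what a backend entry point of the
# ``plot_pyspark(data, kind, **kwargs)`` shape could look like for the "line" and "scatter"
# kinds, assuming plotly is installed.
def _example_plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
    import plotly.express as px

    if kind not in ("line", "scatter"):
        raise NotImplementedError(f"kind {kind!r} is not covered by this sketch")
    # Reuse the same preprocessing the accessor registers for these kinds:
    # random sampling capped at spark.sql.pyspark.plotting.max_rows.
    pdf = PySparkPlotAccessor.plot_data_map[kind](data)
    # plotly.express consumes a pandas DataFrame plus x/y column names from kwargs.
    return px.line(pdf, **kwargs) if kind == "line" else px.scatter(pdf, **kwargs)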


class PySparkKdePlotBase:
    @staticmethod
    def get_ind(sdf: "DataFrame", ind: Optional[Union[Sequence[float], int]]) -> Sequence[float]:
        def calc_min_max() -> "Row":
            if len(sdf.columns) > 1:
                min_col = F.least(*map(F.min, sdf))  # type: ignore
                max_col = F.greatest(*map(F.max, sdf))  # type: ignore
            else:
                min_col = F.min(sdf.columns[-1])
                max_col = F.max(sdf.columns[-1])
            return sdf.select(min_col, max_col).first()  # type: ignore

        if ind is None:
            min_val, max_val = calc_min_max()
            sample_range = max_val - min_val
            ind = NumpyHelper.linspace(
                min_val - 0.5 * sample_range,
                max_val + 0.5 * sample_range,
                1000,
            )
        elif isinstance(ind, int):
            min_val, max_val = calc_min_max()
            sample_range = max_val - min_val
            ind = NumpyHelper.linspace(
                min_val - 0.5 * sample_range,
                max_val + 0.5 * sample_range,
                ind,
            )
        return ind

    @staticmethod
    def compute_kde_col(
        input_col: Column,
        bw_method: Union[int, float],
        ind: Sequence[float],
    ) -> Column:
        # refers to org.apache.spark.mllib.stat.KernelDensity
        assert bw_method is not None and isinstance(
            bw_method, (int, float)
        ), "'bw_method' must be set as a scalar number."

        assert ind is not None, "'ind' must be a scalar array."

        bandwidth = float(bw_method)
        log_std_plus_half_log2_pi = math.log(bandwidth) + 0.5 * math.log(2 * math.pi)

        def norm_pdf(
            mean: Column,
            std: Column,
            log_std_plus_half_log2_pi: Column,
            x: Column,
        ) -> Column:
            x0 = x - mean
            x1 = x0 / std
            log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi
            return F.exp(log_density)

        return F.array(
            [
                F.avg(
                    norm_pdf(
                        input_col.cast("double"),
                        F.lit(bandwidth),
                        F.lit(log_std_plus_half_log2_pi),
                        F.lit(point),
                    )
                )
                for point in ind
            ]
        )


class PySparkHistogramPlotBase:
    @staticmethod
    def get_bins(sdf: "DataFrame", bins: int) -> Sequence[float]:
        if len(sdf.columns) > 1:
            min_col = F.least(*map(F.min, sdf))  # type: ignore
            max_col = F.greatest(*map(F.max, sdf))  # type: ignore
        else:
            min_col = F.min(sdf.columns[-1])
            max_col = F.max(sdf.columns[-1])
        boundaries = sdf.select(min_col, max_col).first()

        if boundaries[0] == boundaries[1]:  # type: ignore
            boundaries = (boundaries[0] - 0.5, boundaries[1] + 0.5)  # type: ignore

        return NumpyHelper.linspace(boundaries[0], boundaries[1], bins + 1)  # type: ignore

    @staticmethod
    def compute_hist(sdf: "DataFrame", bins: Sequence[float]) -> List["pd.Series"]:
        require_minimum_pandas_version()

        assert isinstance(bins, list)

        spark = sdf._session
        assert spark is not None

        # 1. Make the bucket output flat to:
        #     +----------+--------+
        #     |__group_id|__bucket|
        #     +----------+--------+
        #     |0         |0       |
        #     |0         |0       |
        #     |0         |1       |
        #     |0         |2       |
        #     |0         |3       |
        #     |0         |3       |
        #     |1         |0       |
        #     |1         |1       |
        #     |1         |1       |
        #     |1         |2       |
        #     |1         |1       |
        #     |1         |0       |
        #     +----------+--------+
        colnames = sdf.columns

        # determines which bucket a given value falls into, based on predefined bin intervals
        # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
        def binary_search_for_buckets(value: Column) -> Column:
            index = SF.array_binary_search(F.lit(bins), value)
            bucket = F.when(index >= 0, index).otherwise(-index - 2)
            unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]")
            return (
                F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2))
                .when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
                .otherwise(F.raise_error(F.printf(unboundErrMsg, value)))
            )

        output_df = (
            sdf.select(
                F.posexplode(
                    F.array([F.col(colname).cast("double") for colname in colnames])
                ).alias("__group_id", "__value")
            )
            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
            .select(
                F.col("__group_id"),
                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
            )
        )
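
        # Worked example (illustrative, not from the original source): with
        # bins = [0.0, 2.0, 4.0, 6.0], the value 3.0 gets a negative insertion point
        # from array_binary_search, which `-index - 2` maps to bucket 1 (the [2, 4)
        # interval); the exact upper bound 6.0 is assigned to the last bucket,
        # len(bins) - 2 = 2, and anything outside [0.0, 6.0] raises an error.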

        # 2. Calculate the count based on each group and bucket, also fill empty bins.
        #     +----------+--------+------+
        #     |__group_id|__bucket| count|
        #     +----------+--------+------+
        #     |0         |0       |2     |
        #     |0         |1       |1     |
        #     |0         |2       |1     |
        #     |0         |3       |2     |
        #     |1         |0       |2     |
        #     |1         |1       |3     |
        #     |1         |2       |1     |
        #     |1         |3       |0     | <- fill empty bins with zeros (by joining with bin_df)
        #     +----------+--------+------+
        output_df = output_df.groupby("__group_id", "__bucket").agg(F.count("*").alias("count"))

        # Generate all possible combinations of group id and bucket
        bin_df = (
            spark.range(len(colnames))
            .select(
                F.col("id").alias("__group_id"),
                F.explode(F.lit(list(range(len(bins) - 1)))).alias("__bucket"),
            )
            .hint("broadcast")
        )

        output_df = (
            bin_df.join(output_df, ["__group_id", "__bucket"], "left")
            .select(
                "__group_id",
                "__bucket",
                F.nvl(F.col("count"), F.lit(0)).alias("count"),
            )
            .coalesce(1)
            .sortWithinPartitions("__group_id", "__bucket")
            .select("__group_id", "count")
        )

        # 3. Calculate based on each group id. From:
        #     +----------+--------+------+
        #     |__group_id|__bucket| count|
        #     +----------+--------+------+
        #     |0         |0       |2     |
        #     |0         |1       |1     |
        #     |0         |2       |1     |
        #     |0         |3       |2     |
        #     +----------+--------+------+
        #
        #     +----------+--------+------+
        #     |__group_id|__bucket| count|
        #     +----------+--------+------+
        #     |1         |0       |2     |
        #     |1         |1       |3     |
        #     |1         |2       |1     |
        #     |1         |3       |0     |
        #     +----------+--------+------+
        #
        # to:
        #     +-----------------+
        #     |__values1__bucket|
        #     +-----------------+
        #     |2                |
        #     |1                |
        #     |1                |
        #     |2                |
        #     |0                |
        #     +-----------------+
        #     +-----------------+
        #     |__values2__bucket|
        #     +-----------------+
        #     |2                |
        #     |3                |
        #     |1                |
        #     |0                |
        #     |0                |
        #     +-----------------+
        result = output_df.toPandas()

        output_series = []
        for i, input_column_name in enumerate(colnames):
            pdf = result[result["__group_id"] == i]
            pdf = pdf[["count"]]
            pdf.columns = [input_column_name]
            output_series.append(pdf[input_column_name])

        return output_series


class PySparkBoxPlotBase:
    @staticmethod
    def compute_box(
        sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool
    ) -> Optional["Row"]:
        assert len(colnames) > 0
        formatted_colnames = ["`{}`".format(colname) for colname in colnames]

        stats_scols = []
        for i, colname in enumerate(formatted_colnames):
            percentiles = F.percentile_approx(colname, [0.25, 0.50, 0.75], int(1.0 / precision))
            q1 = F.get(percentiles, 0)
            med = F.get(percentiles, 1)
            q3 = F.get(percentiles, 2)
            iqr = q3 - q1
            lfence = q1 - F.lit(whis) * iqr
            ufence = q3 + F.lit(whis) * iqr

            stats_scols.append(F.mean(colname).alias(f"mean_{i}"))
            stats_scols.append(med.alias(f"med_{i}"))
            stats_scols.append(q1.alias(f"q1_{i}"))
            stats_scols.append(q3.alias(f"q3_{i}"))
            stats_scols.append(lfence.alias(f"lfence_{i}"))
            stats_scols.append(ufence.alias(f"ufence_{i}"))

        # compute all stats with a scalar subquery
        stats_col = "__pyspark_plotting_box_plot_stats__"
        sdf = sdf.select("*", sdf.select(F.struct(*stats_scols)).scalar().alias(stats_col))

        result_scols = []
        for i, colname in enumerate(formatted_colnames):
            value = F.col(colname)

            lfence = F.col(f"{stats_col}.lfence_{i}")
            ufence = F.col(f"{stats_col}.ufence_{i}")
            mean = F.col(f"{stats_col}.mean_{i}")
            med = F.col(f"{stats_col}.med_{i}")
            q1 = F.col(f"{stats_col}.q1_{i}")
            q3 = F.col(f"{stats_col}.q3_{i}")

            outlier = ~value.between(lfence, ufence)

            # Computes min and max values of non-outliers - the whiskers
            upper_whisker = F.max(F.when(~outlier, value).otherwise(F.lit(None)))
            lower_whisker = F.min(F.when(~outlier, value).otherwise(F.lit(None)))

            # If it shows fliers, take the top 1k with the highest absolute values
            # Here we normalize the values by subtracting the median.
            if showfliers:
                pair = F.when(
                    outlier,
                    F.struct(F.abs(value - med), value.alias("val")),
                ).otherwise(F.lit(None))
                topk = SF.collect_top_k(pair, 1001, False)
                fliers = F.when(F.size(topk) > 0, topk["val"]).otherwise(F.lit(None))
            else:
                fliers = F.lit(None)

            result_scols.append(
                F.struct(
                    F.first(mean).alias("mean"),
                    F.first(med).alias("med"),
                    F.first(q1).alias("q1"),
                    F.first(q3).alias("q3"),
                    upper_whisker.alias("upper_whisker"),
                    lower_whisker.alias("lower_whisker"),
                    fliers.alias("fliers"),
                ).alias(f"_box_plot_results_{i}")
            )

        return sdf.select(*result_scols).first()
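

# A small, self-contained reference sketch (not part of PySpark; the helper name
# ``_example_gaussian_kde`` is hypothetical) of the Gaussian kernel density estimate that
# ``PySparkKdePlotBase.compute_kde_col`` expresses with Spark columns: for each evaluation
# point, the density is the average of the normal pdf N(point; value, bandwidth) over the
# observed values. It can be used to sanity-check the distributed computation on a small
# sample of data.
def _example_gaussian_kde(
    values: Sequence[float], ind: Sequence[float], bw_method: float
) -> List[float]:
    bandwidth = float(bw_method)
    log_std_plus_half_log2_pi = math.log(bandwidth) + 0.5 * math.log(2 * math.pi)

    def norm_pdf(mean: float, x: float) -> float:
        # Same normal pdf as in compute_kde_col, evaluated with plain floats.
        x1 = (x - mean) / bandwidth
        return math.exp(-0.5 * x1 * x1 - log_std_plus_half_log2_pi)

    # One density estimate per evaluation point: the mean of the per-sample kernels.
    return [sum(norm_pdf(v, point) for v in values) / len(values) for point in ind]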