MLlib (DataFrame-based)

Pipeline APIs

Transformer()

Abstract class for transformers that transform one dataset into another.

UnaryTransformer()

Abstract class for transformers that take one input column, apply a transformation, and output the result as a new column.

Estimator()

Abstract class for estimators that fit models to data.

Model()

Abstract class for models that are fitted by estimators.

Predictor()

Estimator for prediction tasks (regression and classification).

PredictionModel()

Model for prediction tasks (regression and classification).

Pipeline(*[, stages])

A simple pipeline, which acts as an estimator.

PipelineModel(stages)

Represents a compiled pipeline with transformers and fitted models.

Parameters

Param(parent, name, doc[, typeConverter])

A param with self-contained documentation.

Params()

Components that take parameters.

TypeConverters()

Factory methods for common type conversion functions for Param.typeConverter.

Feature

Binarizer(*[, threshold, inputCol, ...])

Binarize a column of continuous features given a threshold.
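As a minimal pure-Python sketch of the thresholding rule, assuming Binarizer's documented behavior (values strictly greater than threshold become 1.0, all others 0.0, default threshold 0.0). The helper `binarize` is hypothetical and not part of pyspark.ml:

```python
def binarize(values, threshold=0.0):
    """Map each value to 1.0 if it is strictly greater than threshold, else 0.0."""
    # Values equal to the threshold fall into the 0.0 class.
    return [1.0 if v > threshold else 0.0 for v in values]


print(binarize([0.1, 0.5, 0.8], threshold=0.5))
```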

BucketedRandomProjectionLSH(*[, inputCol, ...])

LSH class for Euclidean distance metrics.

BucketedRandomProjectionLSHModel([java_model])

Model fitted by BucketedRandomProjectionLSH, where multiple random vectors are stored.

Bucketizer(*[, splits, inputCol, outputCol, ...])

Maps a column of continuous features to a column of feature buckets.
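A pure-Python sketch of how split points define buckets, assuming the rule that bucket i covers [splits[i], splits[i+1]) and the last bucket also includes its upper bound. The helper `bucketize` is hypothetical; raising on out-of-range values only approximates Bucketizer's default handleInvalid="error":

```python
import bisect


def bucketize(value, splits):
    """Return the bucket index (as a float, like Spark's double output).

    splits must be strictly increasing. Bucket i covers
    [splits[i], splits[i + 1]); the last bucket also includes splits[-1].
    """
    if value == splits[-1]:
        # Upper bound belongs to the last bucket.
        return float(len(splits) - 2)
    if not splits[0] <= value < splits[-1]:
        # Also catches NaN, since NaN comparisons are always False.
        raise ValueError(f"{value} is outside [{splits[0]}, {splits[-1]}]")
    # bisect_right counts split points <= value; subtract 1 for the bucket index.
    return float(bisect.bisect_right(splits, value) - 1)


splits = [float("-inf"), -0.5, 0.0, 0.5, float("inf")]
print([bucketize(v, splits) for v in [-999.9, -0.5, -0.3, 0.0, 0.2, 999.9]])
```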

ChiSqSelector(*[, numTopFeatures, ...])

Chi-Squared feature selection, which selects categorical features to use for predicting a categorical label.

ChiSqSelectorModel([java_model])

Model fitted by ChiSqSelector.

CountVectorizer(*[, minTF, minDF, maxDF, ...])

Extracts a vocabulary from document collections and generates a CountVectorizerModel.

CountVectorizerModel([java_model])

Model fitted by CountVectorizer.

DCT(*[, inverse, inputCol, outputCol])

A feature transformer that takes the 1D discrete cosine transform of a real vector.

ElementwiseProduct(*[, scalingVec, ...])

Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided "weight" vector.

FeatureHasher(*[, numFeatures, inputCols, ...])

Feature hashing projects a set of categorical or numerical features into a feature vector of specified dimension (typically substantially smaller than that of the original feature space).

HashingTF(*[, numFeatures, binary, ...])

Maps a sequence of terms to their term frequencies using the hashing trick.
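The hashing trick can be sketched in pure Python as below. This is an illustration under stated assumptions: `hashing_tf` is a hypothetical helper, and `zlib.crc32` is a stand-in hash; Spark's HashingTF uses MurmurHash3, so its bucket indices will differ. The defaults mirror HashingTF's numFeatures (2^18) and binary (False):

```python
import zlib
from collections import Counter


def hashing_tf(terms, num_features=1 << 18, binary=False):
    """Map each term to a bucket via a hash and count occurrences per bucket.

    Returns a sparse {index: value} dict. Distinct terms may collide in the
    same bucket, which is the accepted trade-off of the hashing trick.
    """
    # Stand-in hash (crc32), NOT the MurmurHash3 that Spark uses.
    counts = Counter(zlib.crc32(t.encode("utf-8")) % num_features for t in terms)
    if binary:
        # binary=True records presence (1.0) instead of the raw count.
        return {i: 1.0 for i in counts}
    return {i: float(c) for i, c in counts.items()}


print(hashing_tf(["spark", "ml", "spark"], num_features=32))
```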

IDF(*[, minDocFreq, inputCol, outputCol])

Compute the Inverse Document Frequency (IDF) given a collection of documents.
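A pure-Python sketch of the smoothed IDF weight, assuming the formula idf(t) = log((m + 1) / (df(t) + 1)), where m is the number of documents and df(t) the number containing term t, and assuming terms seen in fewer than minDocFreq documents get weight 0. IDFModel then scales term-frequency vectors by these weights. The helper `idf_weights` is hypothetical:

```python
import math


def idf_weights(docs, min_doc_freq=0):
    """Return {term: idf} for a list of tokenized documents."""
    m = len(docs)
    df = {}
    for doc in docs:
        # Count each term at most once per document.
        for term in set(doc):
            df[term] = df.get(term, 0) + 1
    return {
        t: math.log((m + 1.0) / (d + 1.0)) if d >= min_doc_freq else 0.0
        for t, d in df.items()
    }


print(idf_weights([["a", "b"], ["a"], ["a", "c"]]))
```

A term that occurs in every document (here "a") gets weight 0, so it contributes nothing after scaling.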

IDFModel([java_model])

Model fitted by IDF.

Imputer(*[, strategy, missingValue, ...])

Imputation estimator for completing missing values, using the mean, median or mode of the columns in which the missing values are located.

ImputerModel([java_model])

Model fitted by Imputer.

IndexToString(*[, inputCol, outputCol, labels])

A pyspark.ml.base.Transformer that maps a column of indices back to a new column of corresponding string values.

Interaction(*[, inputCols, outputCol])

Implements the feature interaction transform.

MaxAbsScaler(*[, inputCol, outputCol])

Rescale each feature individually to the range [-1, 1] by dividing by the maximum absolute value of that feature.
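A minimal pure-Python sketch of the per-feature scaling, treating the data as a list of dense rows. The helper `max_abs_scale` is hypothetical; leaving all-zero columns at 0.0 is this sketch's choice:

```python
def max_abs_scale(rows):
    """Divide each column by its maximum absolute value."""
    n = len(rows[0])
    # Per-feature maximum absolute value, computed over all rows.
    max_abs = [max(abs(r[j]) for r in rows) for j in range(n)]
    return [
        [r[j] / max_abs[j] if max_abs[j] != 0 else 0.0 for j in range(n)]
        for r in rows
    ]


print(max_abs_scale([[1.0, -8.0], [2.0, 4.0], [-4.0, 0.0]]))
```

Because only division is involved (no shifting), zero entries stay zero, so sparsity is preserved.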

MaxAbsScalerModel([java_model])

Model fitted by MaxAbsScaler.

MinHashLSH(*[, inputCol, outputCol, seed, ...])

LSH class for Jaccard distance.
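The distance that MinHashLSH approximates can be written directly in pure Python, assuming each input vector is viewed as the set of its non-zero indices. The helper `jaccard_distance` is hypothetical, and returning 0.0 for two empty sets is a convention of this sketch only:

```python
def jaccard_distance(a, b):
    """Jaccard distance between two index sets: 1 - |A & B| / |A | B|."""
    a, b = set(a), set(b)
    union = a | b
    if not union:
        return 0.0  # convention for this sketch only
    return 1.0 - len(a & b) / len(union)


# Non-zero indices of two sparse vectors.
print(jaccard_distance({0, 1, 2}, {1, 2, 3}))
```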

MinHashLSHModel([java_model])

Model produced by MinHashLSH, where multiple hash functions are stored.

MinMaxScaler(*[, min, max, inputCol, outputCol])

Rescale each feature individually and linearly to a common range [min, max] using column summary statistics, also known as min-max normalization or rescaling.
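A pure-Python sketch of the min-max formula for a single feature column, assuming rescaled(e) = (e - E_min) / (E_max - E_min) * (max - min) + min, with a constant column mapped to 0.5 * (max + min). The helper `min_max_scale` is hypothetical:

```python
def min_max_scale(column, new_min=0.0, new_max=1.0):
    """Linearly rescale one feature column to [new_min, new_max]."""
    e_min, e_max = min(column), max(column)
    if e_max == e_min:
        # Constant column: every value maps to the midpoint of the target range.
        return [0.5 * (new_max + new_min)] * len(column)
    return [(e - e_min) / (e_max - e_min) * (new_max - new_min) + new_min for e in column]


print(min_max_scale([2.0, 4.0, 6.0]))
```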

MinMaxScalerModel([java_model])

Model fitted by MinMaxScaler.

NGram(*[, n, inputCol, outputCol])

A feature transformer that converts the input array of strings into an array of n-grams.
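A pure-Python sketch of the n-gram construction, assuming NGram joins each window of n consecutive tokens with a single space and returns an empty array when the input has fewer than n tokens (default n=2). The helper `ngrams` is hypothetical:

```python
def ngrams(tokens, n=2):
    """Return space-joined n-grams from a list of tokens."""
    # range() is empty when len(tokens) < n, giving an empty result.
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


print(ngrams(["Hi", "I", "heard", "about", "Spark"]))
```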

Normalizer(*[, p, inputCol, outputCol])

Normalize a vector to have unit norm using the given p-norm.
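A pure-Python sketch of p-norm normalization for one vector, assuming p >= 1 (default 2.0), p = infinity meaning the maximum absolute value, and zero vectors returned unchanged. The helper `normalize` is hypothetical:

```python
import math


def normalize(vec, p=2.0):
    """Scale vec so that its p-norm is 1 (zero vectors are returned unchanged)."""
    if math.isinf(p):
        norm = max(abs(x) for x in vec)
    else:
        norm = sum(abs(x) ** p for x in vec) ** (1.0 / p)
    if norm == 0.0:
        return list(vec)
    return [x / norm for x in vec]


print(normalize([3.0, 4.0]))
```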

OneHotEncoder(*[, inputCols, outputCols, ...])

A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index.
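A pure-Python sketch of the encoding for a single index, assuming OneHotEncoder's default dropLast=True, under which the last category maps to an all-zero vector of length numCategories - 1. Spark emits sparse vectors; this sketch uses dense lists for readability, and the helper `one_hot` is hypothetical:

```python
def one_hot(index, num_categories, drop_last=True):
    """Encode a category index as a dense 0/1 list."""
    if not 0 <= index < num_categories:
        raise ValueError(f"index {index} out of range for {num_categories} categories")
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if index < size:
        # With drop_last, the final category has no slot and stays all zeros.
        vec[index] = 1.0
    return vec


print([one_hot(i, 3) for i in range(3)])
```

Dropping the last category keeps the encoded columns linearly independent, which matters for models with an intercept.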

OneHotEncoderModel([java_model])

Model fitted by OneHotEncoder.

PCA(*[, k, inputCol, outputCol])

PCA trains a model to project vectors to a lower dimensional space of the top k principal components.

PCAModel([java_model])

Model fitted by PCA.

PolynomialExpansion(*[, degree, inputCol, ...])

Perform feature expansion in a polynomial space.

QuantileDiscretizer(*[, numBuckets, ...])

QuantileDiscretizer takes a column with continuous features and outputs a column with binned categorical features.

RobustScaler(*[, lower, upper, ...])

RobustScaler removes the median and scales the data according to the quantile range.

RobustScalerModel([java_model])

Model fitted by RobustScaler.

RegexTokenizer(*[, minTokenLength, gaps, ...])

A regex-based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (the default) or by repeatedly matching the regex (if gaps is false).

RFormula(*[, formula, featuresCol, ...])

Implements the transforms required for fitting a dataset against an R model formula.

RFormulaModel([java_model])

Model fitted by RFormula.

SQLTransformer(*[, statement])

Implements the transforms defined by a SQL statement.

StandardScaler(*[, withMean, withStd, ...])

Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set.
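A pure-Python sketch for one feature column, assuming StandardScaler's defaults withMean=False and withStd=True, the corrected (n - 1) sample standard deviation, and a zero-variance feature mapped to 0.0. The helper `standard_scale` is hypothetical:

```python
import statistics


def standard_scale(column, with_mean=False, with_std=True):
    """Center and/or scale one feature column (needs at least two values)."""
    mean = statistics.fmean(column)
    std = statistics.stdev(column)  # corrected (n - 1) sample standard deviation
    out = []
    for x in column:
        if with_mean:
            x = x - mean
        if with_std:
            # A zero-variance feature is mapped to 0.0 rather than dividing by zero.
            x = x / std if std != 0.0 else 0.0
        out.append(x)
    return out


print(standard_scale([1.0, 2.0, 3.0], with_mean=True))
```

Centering is off by default because subtracting the mean turns sparse input into dense output.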

StandardScalerModel([java_model])

Model fitted by StandardScaler.

StopWordsRemover(*[, inputCol, outputCol, ...])

A feature transformer that filters out stop words from input.

StringIndexer(*[, inputCol, outputCol, ...])

A label indexer that maps a string column of labels to an ML column of label indices.
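A pure-Python sketch of how labels are ordered into indices, assuming the default stringOrderType "frequencyDesc" (most frequent label gets index 0) with ties broken alphabetically. The helper `fit_string_indexer` is hypothetical:

```python
from collections import Counter


def fit_string_indexer(labels):
    """Return {label: index} ordered by descending frequency, then alphabetically."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    # Indices are doubles in Spark, so use floats here as well.
    return {label: float(i) for i, label in enumerate(ordered)}


print(fit_string_indexer(["a", "b", "c", "a", "a", "c"]))
```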

StringIndexerModel([java_model])

Model fitted by StringIndexer.

TargetEncoder(*[, inputCols, outputCols, ...])

Target Encoding maps a column of categorical indices into a numerical feature derived from the target.

TargetEncoderModel([java_model])

Model fitted by TargetEncoder.

Tokenizer(*[, inputCol, outputCol])

A tokenizer that converts the input string to lowercase and then splits it on whitespace.

UnivariateFeatureSelector(*[, featuresCol, ...])

Feature selector based on univariate statistical tests against labels.

UnivariateFeatureSelectorModel([java_model])

Model fitted by UnivariateFeatureSelector.

VarianceThresholdSelector(*[, featuresCol, ...])

Feature selector that removes all low-variance features.

VarianceThresholdSelectorModel([java_model])

Model fitted by VarianceThresholdSelector.

VectorAssembler(*[, inputCols, outputCol, ...])

A feature transformer that merges multiple columns into a vector column.

VectorIndexer(*[, maxCategories, inputCol, ...])

Class for indexing categorical feature columns in a dataset of Vector.

VectorIndexerModel([java_model])

Model fitted by VectorIndexer.

VectorSizeHint(*[, inputCol, size, ...])

A feature transformer that adds size information to the metadata of a vector column.

VectorSlicer(*[, inputCol, outputCol, ...])

This class takes a feature vector and outputs a new feature vector with a subarray of the original features.

Word2Vec(*[, vectorSize, minCount, ...])

Word2Vec trains a model of Map(String, Vector), i.e., it transforms each word into a vector (code) for further natural language processing or machine learning.

Word2VecModel([java_model])

Model fitted by Word2Vec.

Classification

LinearSVC(*[, featuresCol, labelCol, ...])

This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.

LinearSVCModel([java_model])

Model fitted by LinearSVC.

LinearSVCSummary([java_obj])

Abstraction for LinearSVC Results for a given model.

LinearSVCTrainingSummary([java_obj])

Abstraction for LinearSVC Training results.

LogisticRegression(*[, featuresCol, ...])

Logistic regression.

LogisticRegressionModel([java_model])

Model fitted by LogisticRegression.

LogisticRegressionSummary([java_obj])

Abstraction for Logistic Regression Results for a given model.

LogisticRegressionTrainingSummary([java_obj])

Abstraction for multinomial Logistic Regression Training results.

BinaryLogisticRegressionSummary([java_obj])

Binary Logistic regression results for a given model.

BinaryLogisticRegressionTrainingSummary([...])

Binary Logistic regression training results for a given model.

DecisionTreeClassifier(*[, featuresCol, ...])

Decision tree learning algorithm for classification.

DecisionTreeClassificationModel([java_model])

Model fitted by DecisionTreeClassifier.

GBTClassifier(*[, featuresCol, labelCol, ...])

Gradient-Boosted Trees (GBTs) learning algorithm for classification.

GBTClassificationModel([java_model])

Model fitted by GBTClassifier.

RandomForestClassifier(*[, featuresCol, ...])

Random Forest learning algorithm for classification.

RandomForestClassificationModel([java_model])

Model fitted by RandomForestClassifier.

RandomForestClassificationSummary([java_obj])

Abstraction for RandomForestClassification Results for a given model.

RandomForestClassificationTrainingSummary([...])

Abstraction for RandomForestClassification training results.

BinaryRandomForestClassificationSummary([...])

BinaryRandomForestClassification results for a given model.

BinaryRandomForestClassificationTrainingSummary([...])

BinaryRandomForestClassification training results for a given model.

NaiveBayes(*[, featuresCol, labelCol, ...])

Naive Bayes Classifiers.

NaiveBayesModel([java_model])

Model fitted by NaiveBayes.

MultilayerPerceptronClassifier(*[, ...])

Classifier trainer based on the Multilayer Perceptron.

MultilayerPerceptronClassificationModel([...])

Model fitted by MultilayerPerceptronClassifier.

MultilayerPerceptronClassificationSummary([...])

Abstraction for MultilayerPerceptronClassifier Results for a given model.

MultilayerPerceptronClassificationTrainingSummary([...])

Abstraction for MultilayerPerceptronClassifier Training results.

OneVsRest(*[, featuresCol, labelCol, ...])

Reduction of Multiclass Classification to Binary Classification.

OneVsRestModel(models)

Model fitted by OneVsRest.

FMClassifier(*[, featuresCol, labelCol, ...])

Factorization Machines learning algorithm for classification.

FMClassificationModel([java_model])

Model fitted by FMClassifier.

FMClassificationSummary([java_obj])

Abstraction for FMClassifier Results for a given model.

FMClassificationTrainingSummary([java_obj])

Abstraction for FMClassifier Training results.

Clustering

BisectingKMeans(*[, featuresCol, ...])

A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark.

BisectingKMeansModel([java_model])

Model fitted by BisectingKMeans.

BisectingKMeansSummary([java_obj])

Bisecting KMeans clustering results for a given model.

KMeans(*[, featuresCol, predictionCol, k, ...])

K-means clustering with a k-means++-like initialization mode (the k-means|| algorithm by Bahmani et al.).

KMeansModel([java_model])

Model fitted by KMeans.

KMeansSummary([java_obj])

Summary of KMeans.

GaussianMixture(*[, featuresCol, ...])

GaussianMixture clustering.

GaussianMixtureModel([java_model])

Model fitted by GaussianMixture.

GaussianMixtureSummary([java_obj])

Gaussian mixture clustering results for a given model.

LDA(*[, featuresCol, maxIter, seed, ...])

Latent Dirichlet Allocation (LDA), a topic model designed for text documents.

LDAModel([java_model])

Latent Dirichlet Allocation (LDA) model.

LocalLDAModel([java_model])

Local (non-distributed) model fitted by LDA.

DistributedLDAModel([java_model])

Distributed model fitted by LDA.

PowerIterationClustering(*[, k, maxIter, ...])

Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and Cohen.

Functions

array_to_vector(col)

Converts a column of arrays of numeric type into a column of pyspark.ml.linalg.DenseVector instances.

vector_to_array(col[, dtype])

Converts a column of MLlib sparse/dense vectors into a column of dense arrays.

predict_batch_udf(make_predict_fn, *, ...[, ...])

Given a function which loads a model and returns a predict function for inference over a batch of numpy inputs, returns a Pandas UDF wrapper for inference over a Spark DataFrame.

Vector and Matrix

Vector()

DenseVector(ar)

A dense vector represented by a value array.

SparseVector(size, *args)

A simple sparse vector class for passing data to MLlib.

Vectors()

Factory methods for working with vectors.

Matrix(numRows, numCols[, isTransposed])

DenseMatrix(numRows, numCols, values[, ...])

Column-major dense matrix.

SparseMatrix(numRows, numCols, colPtrs, ...)

Sparse Matrix stored in CSC format.

Matrices()

Recommendation

ALS(*[, rank, maxIter, regParam, ...])

Alternating Least Squares (ALS) matrix factorization.

ALSModel([java_model])

Model fitted by ALS.

Regression

AFTSurvivalRegression(*[, featuresCol, ...])

Accelerated Failure Time (AFT) Model Survival Regression.

AFTSurvivalRegressionModel([java_model])

Model fitted by AFTSurvivalRegression.

DecisionTreeRegressor(*[, featuresCol, ...])

Decision tree learning algorithm for regression.

DecisionTreeRegressionModel([java_model])

Model fitted by DecisionTreeRegressor.

GBTRegressor(*[, featuresCol, labelCol, ...])

Gradient-Boosted Trees (GBTs) learning algorithm for regression.

GBTRegressionModel([java_model])

Model fitted by GBTRegressor.

GeneralizedLinearRegression(*[, labelCol, ...])

Generalized Linear Regression.

GeneralizedLinearRegressionModel([java_model])

Model fitted by GeneralizedLinearRegression.

GeneralizedLinearRegressionSummary([java_obj])

Generalized linear regression results evaluated on a dataset.

GeneralizedLinearRegressionTrainingSummary([...])

Generalized linear regression training results.

IsotonicRegression(*[, featuresCol, ...])

Isotonic regression, currently implemented using the parallelized pool adjacent violators algorithm.
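A sequential pure-Python sketch of the pool adjacent violators (PAV) idea for an increasing fit. This is not Spark's parallelized implementation; it assumes the labels are already sorted by the single feature and that weights are positive. The helper `pool_adjacent_violators` is hypothetical:

```python
def pool_adjacent_violators(y, weights=None):
    """Fit a non-decreasing sequence to y (labels already ordered by feature)."""
    if weights is None:
        weights = [1.0] * len(y)
    # Each block holds [weighted mean, total weight, number of points].
    blocks = []
    for value, w in zip(y, weights):
        blocks.append([value, w, 1])
        # Merge backwards while the previous block's mean exceeds the current one.
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            v2, w2, n2 = blocks.pop()
            v1, w1, n1 = blocks.pop()
            total = w1 + w2
            blocks.append([(v1 * w1 + v2 * w2) / total, total, n1 + n2])
    # Expand each pooled block back to one fitted value per input point.
    fitted = []
    for value, _, count in blocks:
        fitted.extend([value] * count)
    return fitted


print(pool_adjacent_violators([1.0, 3.0, 2.0, 4.0]))
```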

IsotonicRegressionModel([java_model])

Model fitted by IsotonicRegression.

LinearRegression(*[, featuresCol, labelCol, ...])

Linear regression.

LinearRegressionModel([java_model])

Model fitted by LinearRegression.

LinearRegressionSummary([java_obj])

Linear regression results evaluated on a dataset.

LinearRegressionTrainingSummary([java_obj])

Linear regression training results.

RandomForestRegressor(*[, featuresCol, ...])

Random Forest learning algorithm for regression.

RandomForestRegressionModel([java_model])

Model fitted by RandomForestRegressor.

FMRegressor(*[, featuresCol, labelCol, ...])

Factorization Machines learning algorithm for regression.

FMRegressionModel([java_model])

Model fitted by FMRegressor.

Statistics

ChiSquareTest()

Conduct Pearson's independence test for every feature against the label.

Correlation()

Compute the correlation matrix for the input dataset of Vectors using the specified method.

KolmogorovSmirnovTest()

Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a continuous distribution.
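A pure-Python sketch of the one-sample two-sided KS statistic D = sup |F_n(x) - F(x)|, here against a standard normal CDF by default. It computes only the statistic, not the p-value that Spark also reports. The helpers `normal_cdf` and `ks_statistic` are hypothetical:

```python
import math


def normal_cdf(x, mean=0.0, std=1.0):
    """CDF of a normal distribution, via the error function."""
    return 0.5 * (1.0 + math.erf((x - mean) / (std * math.sqrt(2.0))))


def ks_statistic(sample, cdf=normal_cdf):
    """Largest gap between the empirical CDF of sample and the reference cdf."""
    xs = sorted(sample)
    n = len(xs)
    d = 0.0
    for i, x in enumerate(xs, start=1):
        f = cdf(x)
        # The empirical CDF jumps from (i - 1)/n to i/n at x; check both sides.
        d = max(d, i / n - f, f - (i - 1) / n)
    return d


print(ks_statistic([-1.2, 0.1, 0.4, 1.5]))
```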

MultivariateGaussian(mean, cov)

Represents a (mean, cov) tuple.

Summarizer()

Tools for vectorized statistics on MLlib Vectors.

SummaryBuilder(jSummaryBuilder)

A builder object that provides summary statistics about a given column.

Tuning

ParamGridBuilder()

Builder for a param grid used in grid search-based model selection.

CrossValidator(*[, estimator, ...])

K-fold cross validation performs model selection by splitting the dataset into a set of non-overlapping, randomly partitioned folds that are used as separate training and test datasets. For example, with k=3 folds, K-fold cross validation generates 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
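A pure-Python sketch of the fold construction described above, producing k (training, test) pairs from non-overlapping random folds. The helper `k_fold_pairs` and its seed are hypothetical; it cuts folds of exactly equal size by striding over a shuffled copy, whereas Spark's random partitioning yields folds of only approximately equal size:

```python
import random


def k_fold_pairs(data, k=3, seed=42):
    """Return k (training, test) pairs built from non-overlapping random folds."""
    rng = random.Random(seed)
    shuffled = list(data)
    rng.shuffle(shuffled)
    # Every element lands in exactly one fold.
    folds = [shuffled[i::k] for i in range(k)]
    pairs = []
    for i in range(k):
        test = folds[i]
        # Training data is the union of all other folds.
        train = [x for j, fold in enumerate(folds) if j != i for x in fold]
        pairs.append((train, test))
    return pairs


for train, test in k_fold_pairs(range(9), k=3):
    print(len(train), len(test))
```

Each model configuration is trained k times, once per pair, and its metric is averaged over the k test folds.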

CrossValidatorModel(bestModel[, avgMetrics, ...])

CrossValidatorModel contains the model with the highest average cross-validation metric across folds and uses this model to transform input data.

TrainValidationSplit(*[, estimator, ...])

Validation for hyper-parameter tuning.

TrainValidationSplitModel(bestModel[, ...])

Model from train validation split.

Evaluation

Evaluator()

Base class for evaluators that compute metrics from predictions.

BinaryClassificationEvaluator(*[, ...])

Evaluator for binary classification, which expects input columns rawPrediction, label and an optional weight column.

RegressionEvaluator(*[, predictionCol, ...])

Evaluator for Regression, which expects input columns prediction, label and an optional weight column.

MulticlassClassificationEvaluator(*[, ...])

Evaluator for Multiclass Classification, which expects input columns: prediction, label, weight (optional) and probabilityCol (only for logLoss).

MultilabelClassificationEvaluator(*[, ...])

Evaluator for Multilabel Classification, which expects two input columns: prediction and label.

ClusteringEvaluator(*[, predictionCol, ...])

Evaluator for Clustering results, which expects two input columns: prediction and features.

RankingEvaluator(*[, predictionCol, ...])

Evaluator for Ranking, which expects two input columns: prediction and label.

Frequency Pattern Mining

FPGrowth(*[, minSupport, minConfidence, ...])

A parallel FP-growth algorithm to mine frequent itemsets.

FPGrowthModel([java_model])

Model fitted by FPGrowth.

PrefixSpan(*[, minSupport, ...])

A parallel PrefixSpan algorithm to mine frequent sequential patterns.

Image

ImageSchema

Instance of _ImageSchema exposed as the pyspark.ml.image.ImageSchema attribute.

_ImageSchema()

Internal class for pyspark.ml.image.ImageSchema attribute.

Distributor

TorchDistributor([num_processes, ...])

A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.

DeepspeedTorchDistributor([numGpus, nnodes, ...])

Utilities

BaseReadWrite()

Base class for MLWriter and MLReader.

DefaultParamsReadable()

Helper trait for making simple Params types readable.

DefaultParamsReader(cls)

Specialization of MLReader for Params types.

DefaultParamsWritable()

Helper trait for making simple Params types writable.

DefaultParamsWriter(instance)

Specialization of MLWriter for Params types.

GeneralMLWriter()

Utility class that can save ML instances in different formats.

HasTrainingSummary()

Base class for models that provide a training summary.

Identifiable()

Object with a unique ID.

MLReadable()

Mixin for instances that provide MLReader.

MLReader()

Utility class that can load ML instances.

MLWritable()

Mixin for ML instances that provide MLWriter.

MLWriter()

Utility class that can save ML instances.