GaussianMixtureModel

class pyspark.mllib.clustering.GaussianMixtureModel(java_model)

A clustering model derived from the Gaussian Mixture Model method.

New in version 1.3.0.

Examples

>>> # Assumes `sc` is an existing SparkContext.
>>> from numpy import array
>>> from numpy.testing import assert_equal
>>> from pyspark.mllib.clustering import GaussianMixture, GaussianMixtureModel
>>> from shutil import rmtree
>>> import tempfile
>>> clusterdata_1 = sc.parallelize(array([-0.1, -0.05, -0.01, -0.1,
...                                       0.9, 0.8, 0.75, 0.935,
...                                       -0.83, -0.68, -0.91, -0.76]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
...                                 maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
False
>>> labels[4]==labels[5]
True
>>> model.predict([-0.1,-0.05])
0
>>> softPredicted = model.predictSoft([-0.1,-0.05])
>>> abs(softPredicted[0] - 1.0) < 0.03
True
>>> abs(softPredicted[1] - 0.0) < 0.03
True
>>> abs(softPredicted[2] - 0.0) < 0.03
True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = GaussianMixtureModel.load(sc, path)
>>> assert_equal(model.weights, sameModel.weights)
>>> mus, sigmas = list(
...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
>>> sameMus, sameSigmas = list(
...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
>>> mus == sameMus
True
>>> sigmas == sameSigmas
True
>>> try:
...     rmtree(path)
... except OSError:
...     pass
>>> data = array([-5.1971, -2.5359, -3.8220,
...               -5.2211, -5.0602, 4.7118,
...               6.8989, 3.4592, 4.6322,
...               5.7048, 4.6567, 5.5026,
...               4.5605, 5.2043, 6.2734])
>>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
...                               maxIterations=150, seed=4)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]
True
>>> labels[2]==labels[3]==labels[4]
True

Methods

call(name, *a)

Call a method of java_model.

load(sc, path)

Load the GaussianMixtureModel from disk.

predict(x)

Find the cluster to which the point 'x' or each point in RDD 'x' has maximum membership in this model.

predictSoft(x)

Find the membership of point 'x' or each point in RDD 'x' to all mixture components.

save(sc, path)

Save this model to the given path.

Attributes

gaussians

Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i.

k

Number of Gaussians in the mixture.

weights

Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i and the weights sum to 1.

Methods Documentation

call(name, *a)

Call a method of java_model.

classmethod load(sc, path)

Load the GaussianMixtureModel from disk.

New in version 1.5.0.

Parameters
sc : SparkContext
path : str

Path to where the model is stored.

predict(x)

Find the cluster to which the point ‘x’ or each point in RDD ‘x’ has maximum membership in this model.

New in version 1.3.0.

Parameters
x : pyspark.mllib.linalg.Vector or pyspark.RDD

A feature vector or an RDD of vectors representing data points.

Returns
numpy.float64 or pyspark.RDD of int

Predicted cluster label or an RDD of predicted cluster labels if the input is an RDD.
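
The relationship between a membership vector and a hard label can be sketched with plain NumPy. This is only an illustration, assuming the label is the index of the largest membership value, as the description above suggests. The membership numbers are hypothetical and no Spark objects are involved.

```python
import numpy as np

# Hypothetical membership values for two points across three components.
memberships = np.array([[0.05, 0.90, 0.05],
                        [0.70, 0.20, 0.10]])

# Single point: the hard label is the index of the largest membership.
label = int(np.argmax(memberships[0]))

# Many points: take the argmax along each row.
labels = memberships.argmax(axis=1).tolist()
```

With an RDD input the same idea applies per element, and the result stays distributed rather than being returned as a list.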

predictSoft(x)

Find the membership of point ‘x’ or each point in RDD ‘x’ to all mixture components.

New in version 1.3.0.

Parameters
x : pyspark.mllib.linalg.Vector or pyspark.RDD

A feature vector or an RDD of vectors representing data points.

Returns
numpy.ndarray or pyspark.RDD

The membership value to all mixture components for vector ‘x’ or each vector in RDD ‘x’.
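
To make "membership" concrete, here is a minimal NumPy sketch of the standard mixture-model formula: each component's weight times its Gaussian density at the point, normalized so the values sum to 1. It is not Spark's implementation. The helper names `gaussian_pdf` and `soft_membership` and the mixture parameters are hypothetical, chosen only for illustration.

```python
import numpy as np

def gaussian_pdf(x, mu, sigma):
    """Density of a multivariate normal N(mu, sigma) at point x."""
    d = len(mu)
    diff = x - mu
    # Solve sigma * y = diff instead of forming the explicit inverse.
    mahalanobis = diff @ np.linalg.solve(sigma, diff)
    norm = np.sqrt((2.0 * np.pi) ** d * np.linalg.det(sigma))
    return np.exp(-0.5 * mahalanobis) / norm

def soft_membership(x, weights, mus, sigmas):
    """Normalized weighted densities: one membership value per component."""
    x = np.asarray(x, dtype=float)
    scores = np.array([w * gaussian_pdf(x, mu, s)
                       for w, mu, s in zip(weights, mus, sigmas)])
    return scores / scores.sum()

# Hypothetical two-component mixture in two dimensions.
weights = np.array([0.4, 0.6])
mus = [np.array([0.0, 0.0]), np.array([5.0, 5.0])]
sigmas = [np.eye(2), np.eye(2)]

membership = soft_membership([0.1, -0.2], weights, mus, sigmas)
```

Because the point lies close to the first mean and far from the second, nearly all of the membership goes to component 0.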

save(sc, path)

Save this model to the given path.

New in version 1.3.0.

Attributes Documentation

gaussians

Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i.

New in version 1.4.0.

k

Number of Gaussians in the mixture.

New in version 1.4.0.

weights

Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i and the weights sum to 1.

New in version 1.4.0.
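
Since the weights are floating-point values, a check that they sum to 1 is safer with a tolerance than with exact equality. The sketch below uses a hypothetical weight vector rather than a trained model, and the variable name `is_valid` is illustrative only.

```python
import numpy as np

# Hypothetical mixture weights for a three-component model.
weights = np.array([0.1, 0.2, 0.7])

# Valid mixture weights are non-negative and sum to 1 within tolerance.
is_valid = bool(np.all(weights >= 0) and np.isclose(weights.sum(), 1.0))
```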