Skip to content

Commit

Permalink
[SPARK-50939][ML][PYTHON][CONNECT] Support Word2Vec on Connect
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support Word2Vec on Connect

### Why are the changes needed?
for feature parity

### Does this PR introduce _any_ user-facing change?
yes, new algorithm supported

### How was this patch tested?
added test

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49614 from zhengruifeng/ml_connect_w2v.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Jan 23, 2025
1 parent 83961bc commit 1a49237
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ org.apache.spark.ml.feature.MinMaxScaler
org.apache.spark.ml.feature.RobustScaler
org.apache.spark.ml.feature.StringIndexer
org.apache.spark.ml.feature.PCA
org.apache.spark.ml.feature.Word2Vec
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ org.apache.spark.ml.feature.MinMaxScalerModel
org.apache.spark.ml.feature.RobustScalerModel
org.apache.spark.ml.feature.StringIndexerModel
org.apache.spark.ml.feature.PCAModel
org.apache.spark.ml.feature.Word2VecModel
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ class Word2VecModel private[ml] (

import Word2VecModel._

// Auxiliary no-argument constructor: delegates to the primary constructor with a UID
// obtained from Identifiable.randomUID("w2v") and null as the second argument.
// NOTE(review): presumably added so the Spark Connect ML layer can instantiate the model
// without fitted state -- confirm against the Connect model-loading code.
private[ml] def this() = this(Identifiable.randomUID("w2v"), null)

/**
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String
* and the vector the DenseVector that it is mapped to.
Expand Down
4 changes: 3 additions & 1 deletion python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
Param,
Params,
)
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.util import JavaMLReadable, JavaMLWritable, try_remote_attribute_relation
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
from pyspark.ml.common import inherit_doc

Expand Down Expand Up @@ -6381,6 +6381,7 @@ class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable["Word2VecModel"],
"""

@since("1.5.0")
@try_remote_attribute_relation
def getVectors(self) -> DataFrame:
"""
Returns the vector representation of the words as a dataframe
Expand All @@ -6401,6 +6402,7 @@ def setOutputCol(self, value: str) -> "Word2VecModel":
return self._set(outputCol=value)

@since("1.5.0")
@try_remote_attribute_relation
def findSynonyms(self, word: Union[str, Vector], num: int) -> DataFrame:
"""
Find "num" number of words closest in similarity to "word".
Expand Down
44 changes: 44 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
VectorAssembler,
PCA,
PCAModel,
Word2Vec,
Word2VecModel,
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
Expand Down Expand Up @@ -357,6 +359,48 @@ def test_robust_scaler(self):
self.assertEqual(str(model), str(model2))
self.assertEqual(model2.getOutputCol(), "scaled")

def test_word2vec(self):
    """Cover Word2Vec params, fitting, model relations, transform and persistence."""
    # Two identical rows built from the same token list, kept in a single partition.
    tokens = ("a b " * 100 + "a c " * 10).split(" ")
    rows = [(tokens,), (tokens,)]
    df = self.spark.createDataFrame(rows, ["sentence"]).coalesce(1)

    # Estimator configuration: constructor kwargs plus one setter call.
    estimator = Word2Vec(vectorSize=3, seed=42, inputCol="sentence", outputCol="model")
    estimator.setMaxIter(1)
    self.assertEqual(estimator.getInputCol(), "sentence")
    self.assertEqual(estimator.getOutputCol(), "model")
    self.assertEqual(estimator.getVectorSize(), 3)
    self.assertEqual(estimator.getSeed(), 42)
    self.assertEqual(estimator.getMaxIter(), 1)

    # Fit and inspect the DataFrame-returning model attributes.
    fitted = estimator.fit(df)
    expected_vector_cols = ["word", "vector"]
    self.assertEqual(fitted.getVectors().columns, expected_vector_cols)
    self.assertEqual(fitted.getVectors().count(), 3)

    nearest = fitted.findSynonyms("a", 2)
    expected_synonym_cols = ["word", "similarity"]
    self.assertEqual(nearest.columns, expected_synonym_cols)
    self.assertEqual(nearest.count(), 2)

    # TODO(SPARK-50958): Support Word2VecModel.findSynonymsArray
    # synonyms = model.findSynonymsArray("a", 2)
    # self.assertEqual(len(synonyms), 2)
    # self.assertEqual(synonyms[0][0], "b")
    # self.assertTrue(np.allclose(synonyms[0][1], -0.024012837558984756, atol=1e-4))
    # self.assertEqual(synonyms[1][0], "c")
    # self.assertTrue(np.allclose(synonyms[1][1], -0.19355154037475586, atol=1e-4))

    # transform() should append the output column and keep both input rows.
    transformed = fitted.transform(df)
    self.assertEqual(transformed.columns, ["sentence", "model"])
    self.assertEqual(transformed.count(), 2)

    # Persistence round trip: the estimator first, then the model, reusing one directory.
    with tempfile.TemporaryDirectory(prefix="word2vec") as tmp_dir:
        estimator.write().overwrite().save(tmp_dir)
        reloaded_estimator = Word2Vec.load(tmp_dir)
        self.assertEqual(str(estimator), str(reloaded_estimator))

        fitted.write().overwrite().save(tmp_dir)
        reloaded_model = Word2VecModel.load(tmp_dir)
        self.assertEqual(str(fitted), str(reloaded_model))

def test_binarizer(self):
b0 = Binarizer()
self.assertListEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ private[ml] object MLUtils {
(classOf[MaxAbsScalerModel], Set("maxAbs")),
(classOf[MinMaxScalerModel], Set("originalMax", "originalMin")),
(classOf[RobustScalerModel], Set("range", "median")),
(classOf[PCAModel], Set("pc", "explainedVariance")))
(classOf[PCAModel], Set("pc", "explainedVariance")),
(classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")))

private def validate(obj: Any, method: String): Unit = {
assert(obj != null)
Expand Down

0 comments on commit 1a49237

Please sign in to comment.