Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Jan 23, 2025
1 parent 1a49237 commit 6e9793c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
# So register the supported transformer here if you're trying to add a new one.
########### Transformers
org.apache.spark.ml.feature.DCT
org.apache.spark.ml.feature.VectorAssembler

########### Model for loading
Expand Down
29 changes: 29 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

from pyspark.ml.feature import (
DCT,
Binarizer,
CountVectorizer,
CountVectorizerModel,
Expand Down Expand Up @@ -56,6 +57,34 @@


class FeatureTestsMixin:
def test_dct(self):
df = self.spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
dct = DCT()
dct.setInverse(False)
dct.setInputCol("vec")
dct.setOutputCol("resultVec")

self.assertFalse(dct.getInverse())
self.assertEqual(dct.getInputCol(), "vec")
self.assertEqual(dct.getOutputCol(), "resultVec")

output = dct.transform(df)
self.assertEqual(output.columns, ["vec", "resultVec"])
self.assertEqual(output.count(), 1)
self.assertTrue(
np.allclose(
output.head().resultVec.toArray(),
[10.96965511, -0.70710678, -2.04124145],
atol=1e-4,
)
)

# save & load
with tempfile.TemporaryDirectory(prefix="dct") as d:
dct.write().overwrite().save(d)
dct2 = DCT.load(d)
self.assertEqual(str(dct), str(dct2))

def test_string_indexer(self):
df = (
self.spark.createDataFrame(
Expand Down

0 comments on commit 6e9793c

Please sign in to comment.