extend concat_and_free to handle sparse vectors
revise per comments

Replaced sklearn CountVectorizer with Spark CountVectorizer and RegexTokenizer.
lijinf2 committed Jan 3, 2024
1 parent 9b0b20c commit d4b98c6
Showing 5 changed files with 105 additions and 59 deletions.
11 changes: 3 additions & 8 deletions python/src/spark_rapids_ml/classification.py
@@ -841,6 +841,7 @@ class LogisticRegression(
If None, uses a dense array if the first VectorUDT of a dataframe is DenseVector, and a sparse array if it is SparseVector.
If False, always uses dense arrays. This is favorable if the majority of the VectorUDT vectors are DenseVector.
If True, always uses sparse arrays. This is favorable if the majority of the VectorUDT vectors are SparseVector.
Note this is only supported with Spark >= 3.4.
fitIntercept:
Whether to fit an intercept term.
num_workers:
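
For illustration only (not part of this commit): a minimal sketch of the three-valued setting documented above. The toy DataFrame, the column names, and the active `spark` session are assumptions.

from pyspark.ml.linalg import Vectors
from spark_rapids_ml.classification import LogisticRegression

# Hypothetical toy data: mostly sparse feature vectors.
df = spark.createDataFrame(
    [
        (1.0, Vectors.sparse(4, [0, 2], [1.0, 3.0])),
        (0.0, Vectors.sparse(4, [1], [2.0])),
    ],
    ["label", "features"],
)

# None (default): follow the type of the first vector seen.
# False: always convert features to dense arrays before calling cuml.
# True: always build sparse (CSR) arrays; needs pyspark >= 3.4 for unwrap_udt.
gpu_lr = LogisticRegression(
    enable_sparse_data_optim=True, featuresCol="features", labelCol="label"
)
gpu_model = gpu_lr.fit(df)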
@@ -940,14 +941,8 @@ def _logistic_regression_fit(
concated = pd.concat(X_list)
concated_y = pd.concat(y_list)
else:
if isinstance(X_list[0], scipy.sparse.csr_matrix):
concated = scipy.sparse.vstack(X_list)
elif isinstance(X_list[0], cupyx.scipy.sparse.csr_matrix):
concated = cupyx.scipy.sparse.vstack(X_list)
else:
# features are either cp or np arrays here
concated = _concat_and_free(X_list, order=array_order)

# features are either cp, np, scipy csr or cupyx csr arrays here
concated = _concat_and_free(X_list, order=array_order)
concated_y = _concat_and_free(y_list, order=array_order)

pdesc = PartitionDescriptor.build(
10 changes: 8 additions & 2 deletions python/src/spark_rapids_ml/core.py
@@ -167,7 +167,7 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
# - `size`: int
# - `indices`: array<int>
# - `values`: array<double>
# For sparse vector, `type` field is 0, `size` field means vector length,
# For sparse vector, `type` field is 0, `size` field means vector dimension,
# `indices` field is the array of active element indices, `values` field
# is the array of active element values.
# For dense vector, `type` field is 1, `size` and `indices` fields are None,
@@ -213,12 +213,16 @@ def _read_csr_matrix_from_unwrapped_spark_vec(part: pd.DataFrame) -> csr_matrix:

if n_features == 0:
n_features = vec_size
assert n_features == vec_size
assert n_features == vec_size, "all vectors must be of the same dimension"

csr_indices_list.append(csr_indices)
csr_indptr_list.append(csr_indptr_list[-1] + len(csr_indices))
assert len(csr_indptr_list) == 1 + len(csr_indices_list)

csr_values_list.append(csr_values)

assert len(csr_indptr_list) == 1 + len(part)

csr_indptr_arr = np.array(csr_indptr_list)
csr_indices_arr = np.concatenate(csr_indices_list)
csr_values_arr = np.concatenate(csr_values_list)
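
As a side note (not part of the diff), the unwrapped struct layout consumed above can be inspected directly with `unwrap_udt`. The toy DataFrame below is an assumption and the commented output is approximate.

# Requires pyspark >= 3.4; `spark` is an assumed active SparkSession.
from pyspark.ml.linalg import Vectors
from pyspark.sql.functions import unwrap_udt

vec_df = spark.createDataFrame(
    [(Vectors.sparse(3, [0, 2], [1.0, 5.0]),), (Vectors.dense([1.0, 0.0, 5.0]),)],
    ["features"],
)
vec_df.select(unwrap_udt("features").alias("raw")).show(truncate=False)
# sparse row unwraps to roughly (type=0, size=3, indices=[0, 2], values=[1.0, 5.0])
# dense row unwraps to roughly  (type=1, size=None, indices=None, values=[1.0, 0.0, 5.0])
# _read_csr_matrix_from_unwrapped_spark_vec stacks the sparse rows of a partition
# into one csr_matrix via the indptr/indices/values lists built above.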
@@ -667,6 +671,8 @@ def _call_cuml_fit_func(
cuml_verbose = self.cuml_params.get("verbose", False)
use_sparse_array = (
alias.featureVectorType in dataset.schema.fieldNames()
and alias.featureVectorSize in dataset.schema.fieldNames()
and alias.featureVectorIndices in dataset.schema.fieldNames()
) # use sparse array in cuml only if features vectorudt column was unwrapped

(enable_nccl, require_ucx) = self._require_nccl_ucx()
13 changes: 8 additions & 5 deletions python/src/spark_rapids_ml/params.py
@@ -43,16 +43,19 @@ class HasEnableSparseDataOptim(Params):

"""
This Params-based class is adapted from XGBoost: https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/spark/params.py.
It holds the variable to store the boolean config of enabling sparse data optimization.
It holds the variable to store the boolean config for enabling sparse data optimization.
"""

enable_sparse_data_optim = Param(
Params._dummy(),
"enable_sparse_data_optim",
"This stores the boolean config of enabling sparse data optimization, if enabled, "
"Spark rapids ml will construct a sparse matrix in the the format of csr_matrix, "
"then calls cuml with the sparse matrix. This config is disabled by default. If most of "
"examples in your training dataset are sparse vectors, we suggest to enable this config.",
"This param activates sparse data optimization for VectorUDT features column. "
"If the param is not included in an Estimator class, "
"Spark rapids ml always converts VectorUDT features column into dense arrays when calling cuml backend. "
"If included, Spark rapids ml will determine whether to create sparse arrays based on the param value: "
"(1) If None, create dense arrays if the first VectorUDT of a dataframe is DenseVector. Create sparse arrays if it is SparseVector."
"(2) If False, create dense arrays. This is favorable if the majority of vectors are DenseVector."
"(3) If True, create sparse arrays. This is favorable if the majority of the VectorUDT vectors are SparseVector.",
typeConverter=TypeConverters.toBoolean,
)
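
A simplified sketch of the three-way decision the param description implies. The helper name and signature are hypothetical and not the library's internals.

from typing import Optional

from pyspark.ml.linalg import SparseVector
from pyspark.sql import DataFrame


def use_sparse_arrays(
    enable_sparse_data_optim: Optional[bool], df: DataFrame, features_col: str
) -> bool:
    # Hypothetical helper mirroring the param description above.
    if enable_sparse_data_optim is None:
        # None: follow the type of the first VectorUDT value in the dataframe.
        first = df.select(features_col).first()
        return first is not None and isinstance(first[0], SparseVector)
    # True -> always sparse, False -> always dense.
    return enable_sparse_data_optim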

46 changes: 31 additions & 15 deletions python/src/spark_rapids_ml/utils.py
@@ -22,12 +22,13 @@
import cudf
import cupy as cp

import cupyx
import numpy as np
import pandas as pd
import scipy
from pyspark import BarrierTaskContext, SparkConf, SparkContext, TaskContext
from pyspark.sql import Column, SparkSession
from pyspark.sql.types import ArrayType, FloatType
from scipy.sparse import csr_matrix

_ArrayOrder = Literal["C", "F"]

@@ -200,26 +201,41 @@ def build(cls, partition_rows: List[int], total_cols: int) -> "PartitionDescriptor":


def _concat_and_free(
array_list: Union[List["cp.ndarray"], List[np.ndarray]], order: _ArrayOrder = "F"
) -> Union["cp.ndarray", np.ndarray]:
array_list: Union[
List["cp.ndarray"],
List[np.ndarray],
List[scipy.sparse.csr_matrix],
List["cupyx.scipy.sparse.csr_matrix"],
],
order: _ArrayOrder = "F",
) -> Union[
"cp.ndarray", np.ndarray, scipy.sparse.csr_matrix, "cupyx.scipy.sparse.csr_matrix"
]:
"""
concatenates a list of compatible arrays into an 'order'-ordered output array,
in a memory-efficient way.
Note: frees the list elements, so do not reuse them after calling.
If the input arrays are scipy or cupyx csr_matrix, the 'order' parameter is not used.
"""
import cupy as cp
if isinstance(array_list[0], scipy.sparse.csr_matrix):
concated = scipy.sparse.vstack(array_list)
elif isinstance(array_list[0], cupyx.scipy.sparse.csr_matrix):
concated = cupyx.scipy.sparse.vstack(array_list)
else:
import cupy as cp

array_module = cp if isinstance(array_list[0], cp.ndarray) else np
array_module = cp if isinstance(array_list[0], cp.ndarray) else np

rows = sum(arr.shape[0] for arr in array_list)
if len(array_list[0].shape) > 1:
cols = array_list[0].shape[1]
concat_shape: Tuple[int, ...] = (rows, cols)
else:
concat_shape = (rows,)
d_type = array_list[0].dtype
concated = array_module.empty(shape=concat_shape, order=order, dtype=d_type)
array_module.concatenate(array_list, out=concated)
rows = sum(arr.shape[0] for arr in array_list)
if len(array_list[0].shape) > 1:
cols = array_list[0].shape[1]
concat_shape: Tuple[int, ...] = (rows, cols)
else:
concat_shape = (rows,)
d_type = array_list[0].dtype
concated = array_module.empty(shape=concat_shape, order=order, dtype=d_type)
array_module.concatenate(array_list, out=concated)
del array_list[:]
return concated
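
A small usage sketch of the extended helper (illustrative only; note that the input list is emptied in place).

import numpy as np
import scipy.sparse

from spark_rapids_ml.utils import _concat_and_free

# Dense inputs: concatenated into a single Fortran-ordered ndarray.
dense_parts = [np.ones((2, 3), dtype=np.float32), np.zeros((1, 3), dtype=np.float32)]
dense_all = _concat_and_free(dense_parts, order="F")  # shape (3, 3); dense_parts is now []

# Sparse inputs: stacked with scipy.sparse.vstack; the 'order' argument is ignored.
sparse_parts = [scipy.sparse.csr_matrix(np.eye(2)), scipy.sparse.csr_matrix(np.eye(2))]
sparse_all = _concat_and_free(sparse_parts)  # 4x2 csr_matrix; sparse_parts is now []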

@@ -452,7 +468,7 @@ def translate_trees(sc: SparkContext, impurity: str, model: Dict[str, Any]): #
# to the XGBOOST _get_unwrap_udt_fn in https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/spark/core.py
def _get_unwrap_udt_fn() -> Callable[[Union[Column, str]], Column]:
try:
from pyspark.sql.functions import unwrap_udt
from pyspark.sql.functions import unwrap_udt # type: ignore

return unwrap_udt
except ImportError:
84 changes: 55 additions & 29 deletions python/tests/test_logistic_regression.py
@@ -1372,6 +1372,14 @@ def test_compat_sparse_binomial(
gpu_lr = LogisticRegression(**params)
assert gpu_lr.hasParam("enable_sparse_data_optim") is True
assert gpu_lr.getOrDefault("enable_sparse_data_optim") == None

if version.parse(pyspark.__version__) < version.parse("3.4.0"):
err_msg = "Cannot import pyspark `unwrap_udt` function. Please install pyspark>=3.4 "
"or run on Databricks Runtime."
with pytest.raises(RuntimeError, match=err_msg):
gpu_lr.fit(bdf)
return

gpu_model = gpu_lr.fit(bdf)

cpu_lr = SparkLogisticRegression(**params)
@@ -1413,6 +1421,14 @@ def test_compat_sparse_multinomial(
gpu_lr = LogisticRegression(**params)
assert gpu_lr.hasParam("enable_sparse_data_optim") is True
assert gpu_lr.getOrDefault("enable_sparse_data_optim") == None

if version.parse(pyspark.__version__) < version.parse("3.4.0"):
err_msg = "Cannot import pyspark `unwrap_udt` function. Please install pyspark>=3.4 "
"or run on Databricks Runtime."
with pytest.raises(RuntimeError, match=err_msg):
gpu_lr.fit(mdf)
return

gpu_model = gpu_lr.fit(mdf)

cpu_lr = SparkLogisticRegression(**params)
@@ -1432,51 +1448,51 @@ def test_sparse_nlp20news(
fit_intercept: bool,
caplog: LogCaptureFixture,
) -> None:
datatype = np.float32
if version.parse(pyspark.__version__) < version.parse("3.4.0"):
import logging

err_msg = (
    "pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` function. "
    "The test case will be skipped. Please install pyspark>=3.4."
)
logging.info(err_msg)
return

tolerance = 0.001
reg_param = 1e-6
reg_param = 1e-2

from scipy.sparse import csr_matrix
from pyspark.ml.feature import CountVectorizer, RegexTokenizer
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split

try:
twenty_train = fetch_20newsgroups(subset="train", shuffle=True, random_state=42)
except:
pytest.xfail(reason="Error fetching 20 newsgroup dataset")

count_vect = CountVectorizer()
X = count_vect.fit_transform(twenty_train.data)
y = twenty_train.target

X = X.astype(datatype)
y = y.astype(datatype).tolist()

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
X = twenty_train.data
y = twenty_train.target.tolist()

conf = {
"spark.rapids.ml.uvm.enabled": True
} # enable memory management to run the test case on GPU with small memory (e.g. 2G)
with CleanSparkSession(conf) as spark:
data = [
Row(
label=y[i],
weight=1.0,
text=X[i],
)
for i in range(len(X))
]
df = spark.createDataFrame(data)
tokenizer = RegexTokenizer(inputCol="text", outputCol="tokens")
df = tokenizer.transform(df)

def to_df(X_csr: csr_matrix, y_ary: List[float]) -> DataFrame:
assert X_csr.shape[0] == len(y_ary)
dimension = X_csr.shape[1]
data = [
Row(
label=y_ary[i],
weight=1.0,
features=Vectors.sparse(dimension, X_csr[i].indices, X_csr[i].data),
)
for i in range(X_csr.shape[0])
]

df = spark.createDataFrame(data)
return df
cv = CountVectorizer(inputCol="tokens", outputCol="features")
cv_model = cv.fit(df)
df = cv_model.transform(df)

df_train = to_df(X_train, y_train)
df_test = to_df(X_test, y_test)
df_train, df_test = df.randomSplit([0.8, 0.2])

gpu_lr = LogisticRegression(
enable_sparse_data_optim=True,
@@ -1528,6 +1544,16 @@ def test_quick_sparse(
n_classes: int,
gpu_number: int,
) -> None:
if version.parse(pyspark.__version__) < version.parse("3.4.0"):
import logging

err_msg = (
    "pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` function. "
    "The test case will be skipped. Please install pyspark>=3.4."
)
logging.info(err_msg)
return

convert_to_sparse = True
tolerance = 0.005
reg_param = reg_factors[0]
