diff --git a/cpp/include/cuml/linear_model/qn_mg.hpp b/cpp/include/cuml/linear_model/qn_mg.hpp
index f70fd833e9..21d35584be 100644
--- a/cpp/include/cuml/linear_model/qn_mg.hpp
+++ b/cpp/include/cuml/linear_model/qn_mg.hpp
@@ -63,6 +63,37 @@ void qnFit(raft::handle_t& handle,
            float* f,
            int* num_iters);
 
+/**
+ * @brief Fit an MNMG logistic regression model on sparse input (Compressed Sparse Row format)
+ * using quasi-Newton methods.
+ * @param[in] handle: the internal cuml handle object
+ * @param[in] input_values: vector holding non-zero values of all partitions for that rank
+ * @param[in] input_cols: vector holding column indices of non-zero values of all partitions for
+ * that rank
+ * @param[in] input_row_ids: vector holding row pointers of non-zero values of all partitions for
+ * that rank
+ * @param[in] X_nnz: the number of non-zero values of that rank
+ * @param[in] input_desc: PartDescriptor object for the input
+ * @param[in] labels: labels data
+ * @param[out] coef: learned coefficients
+ * @param[in] pams: model parameters
+ * @param[in] n_classes: number of outputs (number of classes or `1` for regression)
+ * @param[out] f: host pointer holding the final objective value
+ * @param[out] num_iters: host pointer holding the actual number of iterations taken
+ */
+void qnFitSparse(raft::handle_t& handle,
+                 std::vector<Matrix::Data<float>*>& input_values,
+                 int* input_cols,
+                 int* input_row_ids,
+                 int X_nnz,
+                 Matrix::PartDescriptor& input_desc,
+                 std::vector<Matrix::Data<float>*>& labels,
+                 float* coef,
+                 const qn_params& pams,
+                 int n_classes,
+                 float* f,
+                 int* num_iters);
+
 };  // namespace opg
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/mg/glm_base_mg.cuh b/cpp/src/glm/qn/mg/glm_base_mg.cuh
index 977e79f0f4..094d7197b6 100644
--- a/cpp/src/glm/qn/mg/glm_base_mg.cuh
+++ b/cpp/src/glm/qn/mg/glm_base_mg.cuh
@@ -16,6 +16,7 @@
 
 #include
 #include
+#include
 
 #include
 #include
 
@@ -112,34 +113,42 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
                  T* dev_scalar,
                  cudaStream_t stream)
   {
+    raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));
     SimpleDenseMat<T> W(wFlat.data, this->C, this->dims);
     SimpleDenseMat<T> G(gradFlat.data, this->C, this->dims);
     SimpleVec<T> lossVal(dev_scalar, 1);
 
+    // Ensure the same coefficients on all GPUs
+    communicator.bcast(wFlat.data, this->C * this->dims, 0, stream);
+    communicator.sync_stream(stream);
+
     // apply regularization
     auto regularizer_obj = this->objective;
     auto lossFunc        = regularizer_obj->loss;
     auto reg             = regularizer_obj->reg;
     G.fill(0, stream);
-    float reg_host = 0;
+    T reg_host = 0;
     if (reg->l2_penalty != 0) {
       reg->reg_grad(dev_scalar, G, W, lossFunc->fit_intercept, stream);
       raft::update_host(&reg_host, dev_scalar, 1, stream);
-      // note: avoid syncing here because there's a sync before reg_host is used.
+      raft::resource::sync_stream(*(this->handle_p));
     }
 
     // apply linearFwd, getLossAndDz, linearBwd
     ML::GLM::detail::linearFwd(
       lossFunc->handle, *(this->Z), *(this->X), W);  // linear part: forward pass
 
-    raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));
-
     lossFunc->getLossAndDZ(dev_scalar, *(this->Z), *(this->y), stream);  // loss specific part
 
     // normalize local loss before allreduce sum
     T factor = 1.0 * (*this->y).len / this->n_samples;
     raft::linalg::multiplyScalar(dev_scalar, dev_scalar, factor, 1, stream);
 
+    // Each GPU computes reg_host independently, so the values may diverge by tiny amounts.
+    // Take the average of reg_host across ranks to avoid that divergence.
+    T reg_factor = reg_host / this->n_ranks;
+    raft::linalg::addScalar(dev_scalar, dev_scalar, reg_factor, 1, stream);
+
     communicator.allreduce(dev_scalar, dev_scalar, 1, raft::comms::op_t::SUM, stream);
     communicator.sync_stream(stream);
 
@@ -154,11 +163,9 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
     communicator.allreduce(G.data, G.data, this->C * this->dims, raft::comms::op_t::SUM, stream);
     communicator.sync_stream(stream);
 
-    float loss_host;
+    T loss_host;
     raft::update_host(&loss_host, dev_scalar, 1, stream);
     raft::resource::sync_stream(*(this->handle_p));
-    loss_host += reg_host;
-    lossVal.fill(loss_host + reg_host, stream);
 
     return loss_host;
   }
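Note (not part of the patch): the glm_base_mg.cuh change above folds the regularization term into the allreduce. Each rank scales its local loss by `n_local / n_samples` and adds `reg_host / n_ranks`, so the summed result is the global mean loss plus a single, averaged regularization term even when the per-rank `reg_host` values differ in the last bits. A minimal host-side sketch of that bookkeeping, assuming the local loss is the mean over the rank's own rows (which is what the `len / n_samples` rescaling implies):

```python
# Host-side sketch of the aggregation performed around allreduce(SUM) above.
import numpy as np

n_samples, n_ranks = 10, 2
local_counts = np.array([6, 4])              # rows owned by each rank
local_mean_loss = np.array([0.70, 0.40])     # mean data loss over local rows
reg_host = np.array([0.1000001, 0.0999999])  # tiny per-rank divergence

# what each rank holds in dev_scalar just before the allreduce
per_rank = local_mean_loss * local_counts / n_samples + reg_host / n_ranks

total = per_rank.sum()  # the value every rank sees after allreduce(SUM)
expected = (local_mean_loss * local_counts).sum() / n_samples + reg_host.mean()
assert np.isclose(total, expected)
```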
diff --git a/cpp/src/glm/qn_mg.cu b/cpp/src/glm/qn_mg.cu
index 5a60c01f79..ee75316a18 100644
--- a/cpp/src/glm/qn_mg.cu
+++ b/cpp/src/glm/qn_mg.cu
@@ -29,6 +29,8 @@
 #include
 using namespace MLCommon;
 
+#include
+
 namespace ML {
 namespace GLM {
 namespace opg {
@@ -172,6 +174,77 @@ void qnFit(raft::handle_t& handle,
     handle, input_data, input_desc, labels, coef, pams, X_col_major, n_classes, f, num_iters);
 }
 
+template <typename T, typename I>
+void qnFitSparse_impl(const raft::handle_t& handle,
+                      const qn_params& pams,
+                      T* X_values,
+                      I* X_cols,
+                      I* X_row_ids,
+                      I X_nnz,
+                      T* y,
+                      size_t N,
+                      size_t D,
+                      size_t C,
+                      T* w0,
+                      T* f,
+                      int* num_iters,
+                      size_t n_samples,
+                      int rank,
+                      int n_ranks)
+{
+  auto X_simple = SimpleSparseMat<T, I>(X_values, X_cols, X_row_ids, X_nnz, N, D);
+
+  ML::GLM::opg::qn_fit_x_mg(handle,
+                            pams,
+                            X_simple,
+                            y,
+                            C,
+                            w0,
+                            f,
+                            num_iters,
+                            n_samples,
+                            rank,
+                            n_ranks);  // ignore sample_weight, svr_eps
+  return;
+}
+
+void qnFitSparse(raft::handle_t& handle,
+                 std::vector<Matrix::Data<float>*>& input_values,
+                 int* input_cols,
+                 int* input_row_ids,
+                 int X_nnz,
+                 Matrix::PartDescriptor& input_desc,
+                 std::vector<Matrix::Data<float>*>& labels,
+                 float* coef,
+                 const qn_params& pams,
+                 int n_classes,
+                 float* f,
+                 int* num_iters)
+{
+  RAFT_EXPECTS(input_values.size() == 1,
+               "qn_mg.cu currently does not accept more than one input matrix");
+
+  auto data_input_values = input_values[0];
+  auto data_y            = labels[0];
+
+  qnFitSparse_impl<float, int>(handle,
+                               pams,
+                               data_input_values->ptr,
+                               input_cols,
+                               input_row_ids,
+                               X_nnz,
+                               data_y->ptr,
+                               input_desc.totalElementsOwnedBy(input_desc.rank),
+                               input_desc.N,
+                               n_classes,
+                               coef,
+                               f,
+                               num_iters,
+                               input_desc.M,
+                               input_desc.rank,
+                               input_desc.uniqueRanks().size());
+}
+
 };  // namespace opg
 };  // namespace GLM
 };  // namespace ML
diff --git a/python/cuml/dask/linear_model/logistic_regression.py b/python/cuml/dask/linear_model/logistic_regression.py
index 38366a1b50..af53d509b1 100644
--- a/python/cuml/dask/linear_model/logistic_regression.py
+++ b/python/cuml/dask/linear_model/logistic_regression.py
@@ -21,6 +21,7 @@
 from raft_dask.common.comms import get_raft_comm_state
 from dask.distributed import get_worker
 
+from cuml.common.sparse_utils import is_sparse, has_scipy
 from cuml.dask.common import parts_to_ranks
 from cuml.dask.common.input_utils import DistributedDataHandler, concatenate
 from raft_dask.common.comms import Comms
@@ -29,7 +30,9 @@
 from cuml.internals.safe_imports import gpu_only_import
 
 cp = gpu_only_import("cupy")
+cupyx = gpu_only_import("cupyx")
 np = cpu_only_import("numpy")
+scipy = cpu_only_import("scipy")
 
 
 class LogisticRegression(LinearRegression):
@@ -172,7 +175,20 @@ def _create_model(sessionId, datatype, **kwargs):
 
     @staticmethod
     def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
-        inp_X = concatenate([X for X, _ in data])
+        if is_sparse(data[0][0]) is False:
+            inp_X = concatenate([X for X, _ in data])
+
+        elif has_scipy() and scipy.sparse.isspmatrix(data[0][0]):
+            inp_X = scipy.sparse.vstack([X for X, _ in data])
+
+        elif cupyx.scipy.sparse.isspmatrix(data[0][0]):
+            inp_X = cupyx.scipy.sparse.vstack([X for X, _ in data])
+
+        else:
+            raise ValueError(
+                "input matrix must be dense, scipy sparse, or cupy sparse"
+            )
+
         inp_y = concatenate([y for _, y in data])
 
         n_ranks = max([p[0] for p in partsToSizes]) + 1
         aggregated_partsToSizes = [[i, 0] for i in range(n_ranks)]
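Note (not part of the patch): the three arrays that qnFitSparse receives are the standard CSR decomposition of the rank's local rows. A minimal sketch of that layout, using scipy purely for illustration; the C++ API receives raw device pointers with the same meaning:

```python
# Illustrative only: CSR layout of a small matrix and how its pieces map onto
# the qnFitSparse arguments declared above.
import numpy as np
from scipy.sparse import csr_matrix

X = csr_matrix(np.array([[1.0, 0.0, 2.0],
                         [0.0, 3.0, 0.0]], dtype=np.float32))

print(X.data)     # [1. 2. 3.]  -> input_values  (non-zero values)
print(X.indices)  # [0 2 1]     -> input_cols    (column index of each value)
print(X.indptr)   # [0 2 3]     -> input_row_ids (row pointers, length n_rows + 1)
print(X.nnz)      # 3           -> X_nnz
```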
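Note (not part of the patch): base_mg.pyx above assumes a single sparse matrix per rank, which is why _func_fit earlier in this diff stacks all partitions owned by a worker into one CSR matrix first. A standalone sketch of that stacking with scipy (cupyx.scipy.sparse.vstack is the device-side equivalent used for cupy inputs):

```python
# Illustrative only: collapse a worker's (X, y) partitions into one CSR matrix
# plus one label array, mirroring _func_fit and the single-matrix assert above.
import numpy as np
from scipy.sparse import csr_matrix, vstack

data = [
    (csr_matrix(np.eye(2, 3, dtype=np.float32)), np.array([0.0, 1.0], dtype=np.float32)),
    (csr_matrix(np.ones((2, 3), dtype=np.float32)), np.array([1.0, 0.0], dtype=np.float32)),
]

inp_X = vstack([X for X, _ in data])          # one (4 x 3) CSR matrix for this worker
inp_y = np.concatenate([y for _, y in data])  # matching labels
print(inp_X.shape, inp_X.nnz, inp_y.shape)
```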
diff --git a/python/cuml/linear_model/logistic_regression_mg.pyx b/python/cuml/linear_model/logistic_regression_mg.pyx
index 3330541b32..2e96851dfa 100644
--- a/python/cuml/linear_model/logistic_regression_mg.pyx
+++ b/python/cuml/linear_model/logistic_regression_mg.pyx
@@ -84,6 +84,20 @@ cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:
         PartDescriptor &input_desc,
         vector[floatData_t*] labels) except+
 
+    cdef void qnFitSparse(
+        handle_t& handle,
+        vector[floatData_t *] input_values,
+        int *input_cols,
+        int *input_row_ids,
+        int X_nnz,
+        PartDescriptor &input_desc,
+        vector[floatData_t *] labels,
+        float *coef,
+        const qn_params& pams,
+        int n_classes,
+        float *f,
+        int *num_iters) except +
+
 
 class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
@@ -171,6 +185,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
     def fit(self, input_data, n_rows, n_cols, parts_rank_size, rank, convert_dtype=False):
+        self.rank = rank
         assert len(input_data) == 1, f"Currently support only one (X, y) pair in the list. Received {len(input_data)} pairs."
         self.is_col_major = False
         order = 'F' if self.is_col_major else 'C'
@@ -196,18 +211,42 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
         cdef qn_params qnpams = self.solver_model.qnparams.params
 
+        sparse_input = True if isinstance(X, list) else False
+
         if self.dtype == np.float32:
-            qnFit(
-                handle_[0],
-                deref(X),
-                deref(input_desc),
-                deref(y),
-                mat_coef_ptr,
-                qnpams,
-                self.is_col_major,
-                self._num_classes,
-                &objective32,
-                &num_iters)
+            if sparse_input is False:
+                qnFit(
+                    handle_[0],
+                    deref(X),
+                    deref(input_desc),
+                    deref(y),
+                    mat_coef_ptr,
+                    qnpams,
+                    self.is_col_major,
+                    self._num_classes,
+                    &objective32,
+                    &num_iters)
+
+            else:
+                assert len(X) == 4
+                X_values = X[0]
+                X_cols = X[1]
+                X_row_ids = X[2]
+                X_nnz = X[3]
+
+                qnFitSparse(
+                    handle_[0],
+                    deref(X_values),
+                    X_cols,
+                    X_row_ids,
+                    X_nnz,
+                    deref(input_desc),
+                    deref(y),
+                    mat_coef_ptr,
+                    qnpams,
+                    self._num_classes,
+                    &objective32,
+                    &num_iters)
 
         self.solver_model.objective = objective32
diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py
index 4f0cd7408b..d9e27f63af 100644
--- a/python/cuml/tests/dask/test_dask_logistic_regression.py
+++ b/python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -21,6 +21,7 @@
 from sklearn.linear_model import LogisticRegression as skLR
 from cuml.internals.safe_imports import cpu_only_import
 from cuml.testing.utils import array_equal
+from scipy.sparse import csr_matrix
 
 pd = cpu_only_import("pandas")
 np = cpu_only_import("numpy")
@@ -48,6 +49,38 @@ def _prep_training_data(c, X_train, y_train, partitions_per_worker):
     return X_train_df, y_train_df
 
 
+def _prep_training_data_sparse(c, X_train, y_train, partitions_per_worker):
+    "The implementation follows test_dask_tfidf.create_cp_sparse_dask_array"
+    import dask.array as da
+
+    workers = c.has_what().keys()
+    target_n_partitions = partitions_per_worker * len(workers)
+
+    def cal_chunks(dataset, n_partitions):
+
+        n_samples = dataset.shape[0]
+        n_samples_per_part = int(n_samples / n_partitions)
+        chunk_sizes = [n_samples_per_part] * n_partitions
+        samples_last_row = n_samples - (
+            (n_partitions - 1) * n_samples_per_part
+        )
+        chunk_sizes[-1] = samples_last_row
+        return tuple(chunk_sizes)
+
+    assert (
+        X_train.shape[0] == y_train.shape[0]
+    ), "the number of data records is not equal to the number of labels"
+    target_chunk_sizes = cal_chunks(X_train, target_n_partitions)
+
+    X_da = da.from_array(X_train, chunks=(target_chunk_sizes, -1))
+    y_da = da.from_array(y_train, chunks=target_chunk_sizes)
+
+    X_da, y_da = dask_utils.persist_across_workers(
+        c, [X_da, y_da], workers=workers
+    )
+    return X_da, y_da
+
+
 def make_classification_dataset(datatype, nrows, ncols, n_info, n_classes=2):
     X, y = make_classification(
         n_samples=nrows,
@@ -285,6 +318,7 @@ def test_lbfgs(
     l1_ratio=None,
     C=1.0,
     n_classes=2,
+    convert_to_sparse=False,
 ):
     tolerance = 0.005
@@ -305,7 +339,12 @@ def imp():
         datatype, nrows, ncols, n_info, n_classes=n_classes
     )
 
-    X_df, y_df = _prep_training_data(client, X, y, n_parts)
+    if convert_to_sparse is False:
+        # X_dask and y_dask are dask cuDF objects
+        X_dask, y_dask = _prep_training_data(client, X, y, n_parts)
+    else:
+        # X_dask and y_dask are dask arrays
+        X_dask, y_dask = _prep_training_data_sparse(client, X, y, n_parts)
 
     lr = cumlLBFGS_dask(
         solver="qn",
@@ -315,9 +354,19 @@ def imp():
         C=C,
         verbose=True,
     )
-    lr.fit(X_df, y_df)
-    lr_coef = lr.coef_.to_numpy()
-    lr_intercept = lr.intercept_.to_numpy()
+    lr.fit(X_dask, y_dask)
+
+    def array_to_numpy(ary):
+        if isinstance(ary, cp.ndarray):
+            return cp.asnumpy(ary)
+        elif isinstance(ary, cudf.DataFrame) or isinstance(ary, cudf.Series):
+            return ary.to_numpy()
+        else:
+            assert isinstance(ary, np.ndarray)
+            return ary
+
+    lr_coef = array_to_numpy(lr.coef_)
+    lr_intercept = array_to_numpy(lr.intercept_)
 
     if penalty == "l2" or penalty == "none":
         sk_solver = "lbfgs"
@@ -345,7 +394,11 @@ def imp():
     )
 
     # test predict
-    cu_preds = lr.predict(X_df, delayed=delayed).compute().to_numpy()
+    cu_preds = lr.predict(X_dask, delayed=delayed).compute()
+    if isinstance(cu_preds, cp.ndarray):
+        cu_preds = cp.asnumpy(cu_preds)
+    if not isinstance(cu_preds, np.ndarray):
+        cu_preds = cu_preds.to_numpy()
 
     accuracy_cuml = accuracy_score(y, cu_preds)
 
     sk_preds = sk_model.predict(X)
@@ -491,3 +544,80 @@ def test_elasticnet(
         strength = 1.0 / lr.C
         assert l1_strength == lr.l1_ratio * strength
         assert l2_strength == (1.0 - lr.l1_ratio) * strength
+
+
+@pytest.mark.mg
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize(
+    "regularization",
+    [
+        ("none", 1.0, None),
+        ("l2", 2.0, None),
+        ("l1", 2.0, None),
+        ("elasticnet", 2.0, 0.2),
+    ],
+)
+@pytest.mark.parametrize("datatype", [np.float32])
+@pytest.mark.parametrize("delayed", [True])
+@pytest.mark.parametrize("n_classes", [2, 8])
+def test_sparse_from_dense(
+    fit_intercept, regularization, datatype, delayed, n_classes, client
+):
+    penalty = regularization[0]
+    C = regularization[1]
+    l1_ratio = regularization[2]
+
+    test_lbfgs(
+        nrows=1e5,
+        ncols=20,
+        n_parts=2,
+        fit_intercept=fit_intercept,
+        datatype=datatype,
+        delayed=delayed,
+        client=client,
+        penalty=penalty,
+        n_classes=n_classes,
+        C=C,
+        l1_ratio=l1_ratio,
+        convert_to_sparse=True,
+    )
+
+
+@pytest.mark.parametrize("dtype", [np.float32])
+def test_sparse_nlp20news(dtype, nlp_20news, client):
+
+    X, y = nlp_20news
+    n_parts = 2  # partitions_per_worker
+
+    from scipy.sparse import csr_matrix
+    from sklearn.model_selection import train_test_split
+
+    X = X.astype(dtype)
+
+    X = csr_matrix(X)
+    y = y.get().astype(dtype)
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
+
+    from cuml.dask.linear_model import LogisticRegression as MG
+
+    X_train_da, y_train_da = _prep_training_data_sparse(
+        client, X_train, y_train, partitions_per_worker=n_parts
+    )
+    X_test_da, _ = _prep_training_data_sparse(
+        client, X_test, y_test, partitions_per_worker=n_parts
+    )
+
+    cumg = MG(verbose=6, C=20.0)
+    cumg.fit(X_train_da, y_train_da)
+
+    preds = cumg.predict(X_test_da).compute()
+    cuml_score = accuracy_score(y_test, preds.tolist())
+
+    from sklearn.linear_model import LogisticRegression as CPULR
+
+    cpu = CPULR(C=20.0)
+    cpu.fit(X_train, y_train)
+    cpu_preds = cpu.predict(X_test)
+    cpu_score = accuracy_score(y_test, cpu_preds.tolist())
+    assert cuml_score >= cpu_score or np.abs(cuml_score - cpu_score) < 1e-3
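For reference, a hedged end-to-end sketch of the new sparse path, following _prep_training_data_sparse and test_sparse_nlp20news above. It assumes a running Dask CUDA cluster with a default client; the data, chunk sizes, and hyperparameters are illustrative only:

```python
import dask.array as da
import numpy as np
from scipy.sparse import random as sparse_random
from cuml.dask.linear_model import LogisticRegression

# toy CSR training data; any scipy or cupyx CSR matrix is handled the same way
X = sparse_random(1000, 20, density=0.1, format="csr", dtype=np.float32)
y = (np.random.rand(1000) > 0.5).astype(np.float32)

# one dask chunk per intended partition; -1 keeps all columns in a single chunk
X_da = da.from_array(X, chunks=((500, 500), -1))
y_da = da.from_array(y, chunks=(500, 500))

model = LogisticRegression(C=1.0, penalty="l2")
model.fit(X_da, y_da)
preds = model.predict(X_da).compute()
```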