From 6098f20553b6a52ee47e812247defb94d28b67c0 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Sun, 26 Jan 2025 11:22:41 +0800 Subject: [PATCH] [JVM-Packages] Allow XGBoost JVM package to run on GPU without RAPIDS --- .../scala/spark/GpuXGBoostPlugin.scala | 2 +- .../scala/spark/GpuXGBoostPluginSuite.scala | 33 ++++++++++--------- .../scala/spark/XGBoostEstimator.scala | 21 ++++++------ 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala index 6ab9f679d706..36e5b7da4299 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2024 by Contributors + Copyright (c) 2024-2025 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala index a5ff2ba0f589..d4a24f7745c5 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala @@ -90,24 +90,23 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { val df = Seq((1.0f, 2.0f, 0.0f), (2.0f, 3.0f, 1.0f) ).toDF("c1", "c2", "label") - val classifier = new XGBoostClassifier() - assert(classifier.getPlugin.isDefined) - assert(classifier.getPlugin.get.isEnabled(df) === expected) + assert(PluginUtils.getPlugin.isDefined) + assert(PluginUtils.getPlugin.get.isEnabled(df) === expected) } // spark.rapids.sql.enabled is not set explicitly, default to true withSparkSession(new SparkConf(), spark => { - checkIsEnabled(spark, true) + checkIsEnabled(spark, expected = true) }) // set spark.rapids.sql.enabled to false withCpuSparkSession() { spark => - checkIsEnabled(spark, false) + checkIsEnabled(spark, expected = false) } // set spark.rapids.sql.enabled to true withGpuSparkSession() { spark => - checkIsEnabled(spark, true) + checkIsEnabled(spark, expected = true) } } @@ -122,7 +121,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { ).toDF("c1", "c2", "weight", "margin", "label", "other") val classifier = new XGBoostClassifier() - val plugin = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] + val plugin = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] intercept[IllegalArgumentException] { plugin.validate(classifier, df) } @@ -156,9 +155,9 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { var classifier = new XGBoostClassifier() .setNumWorkers(3) .setFeaturesCol(features) - assert(classifier.getPlugin.isDefined) - 
assert(classifier.getPlugin.get.isInstanceOf[GpuXGBoostPlugin]) - var out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] + assert(PluginUtils.getPlugin.isDefined) + assert(PluginUtils.getPlugin.get.isInstanceOf[GpuXGBoostPlugin]) + var out = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] .preprocess(classifier, df) assert(out.schema.names.contains("c1") && out.schema.names.contains("c2")) @@ -172,7 +171,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { .setWeightCol("weight") .setBaseMarginCol("margin") .setDevice("cuda") - out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] + out = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] .preprocess(classifier, df) assert(out.schema.names.contains("c1") && out.schema.names.contains("c2")) @@ -207,7 +206,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { .setDevice("cuda") .setMissing(missing) - val rdd = classifier.getPlugin.get.buildRddWatches(classifier, df) + val rdd = PluginUtils.getPlugin.get.buildRddWatches(classifier, df) val result = rdd.mapPartitions { iter => val watches = iter.next() val size = watches.size @@ -271,7 +270,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { .setMissing(missing) .setEvalDataset(eval) - val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train) + val rdd = PluginUtils.getPlugin.get.buildRddWatches(classifier, train) val result = rdd.mapPartitions { iter => val watches = iter.next() val size = watches.size @@ -324,7 +323,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { .setLabelCol("label") .setDevice("cuda") - assert(estimator.getPlugin.isDefined && estimator.getPlugin.get.isEnabled(df)) + assert(PluginUtils.getPlugin.isDefined && PluginUtils.getPlugin.get.isEnabled(df)) val out = estimator.fit(df).transform(df) // Transform should not discard the other columns of the transforming dataframe @@ -528,7 +527,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { .setGroupCol(group) .setDevice("cuda") - val 
processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df) + val processedDf = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] + .preprocess(ranker, df) processedDf.rdd.foreachPartition { iter => { var prevGroup = Int.MinValue while (iter.hasNext) { @@ -575,7 +575,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { // The fix has replaced repartition with repartitionByRange which will put the // instances with same group into the same partition val ranker = new XGBoostRanker().setGroupCol("group").setNumWorkers(num_workers) - val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df) + val processedDf = PluginUtils.getPlugin.get.asInstanceOf[GpuXGBoostPlugin] + .preprocess(ranker, df) val rows = processedDf .select("group") .mapPartitions { case iter => diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index 0bfbf5ad2599..9f9fc22755fd 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2024 by Contributors + Copyright (c) 2024-2025 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -66,7 +66,7 @@ private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoost } } -private[spark] trait PluginMixin { +private[spark] object PluginUtils { // Find the XGBoostPlugin by ServiceLoader private val plugin: Option[XGBoostPlugin] = { val classLoader = Option(Thread.currentThread().getContextClassLoader) @@ -85,9 +85,9 @@ private[spark] trait PluginMixin { } /** Visible for testing */ - protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin + def getPlugin: Option[XGBoostPlugin] = plugin - protected def isPluginEnabled(dataset: Dataset[_]): Boolean = { + def isPluginEnabled(dataset: Dataset[_]): Boolean = { plugin.map(_.isEnabled(dataset)).getOrElse(false) } } @@ -95,8 +95,7 @@ private[spark] trait PluginMixin { private[spark] trait XGBoostEstimator[ Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M] with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner] - with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable - with PluginMixin { + with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable { protected val logger = LogFactory.getLog("XGBoostSpark") @@ -428,8 +427,8 @@ private[spark] trait XGBoostEstimator[ protected def train(dataset: Dataset[_]): M = { validate(dataset) - val rdd = if (isPluginEnabled(dataset)) { - getPlugin.get.buildRddWatches(this, dataset) + val rdd = if (PluginUtils.isPluginEnabled(dataset)) { + PluginUtils.getPlugin.get.buildRddWatches(this, dataset) } else { val (input, columnIndexes) = preprocess(dataset) toRdd(input, columnIndexes) @@ -466,7 +465,7 @@ private[spark] case class PredictedColumns( * XGBoost base model */ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with MLWritable - with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] with PluginMixin { + with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] { protected val TMP_TRANSFORMED_COL = 
"_tmp_xgb_transformed_col" @@ -597,8 +596,8 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML } override def transform(dataset: Dataset[_]): DataFrame = { - if (getPlugin.isDefined) { - return getPlugin.get.transform(this, dataset) + if (PluginUtils.isPluginEnabled(dataset)) { + return PluginUtils.getPlugin.get.transform(this, dataset) } validateFeatureType(dataset.schema) val (schema, pred) = preprocess(dataset)