Skip to content

Commit

Permalink
[wip][jvm-packages] Add java class for ExtMemQdm.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 21, 2025
1 parent c2cfd82 commit 57d0108
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2025, XGBoost Contributors
*/
package ml.dmlc.xgboost4j.java;

import java.util.Map;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.core.JsonProcessingException;

import java.util.Iterator;

/**
 * A QuantileDMatrix whose native handle is created through
 * {@code XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback}.
 *
 * <p>The configuration always sets {@code on_host} to true, as we only support GPU at the
 * moment. {@code cache_prefix} is not used yet since {@code on_host} is true.
 */
public class ExtMemQuantileDMatrix extends QuantileDMatrix {
  /**
   * Create the matrix from an iterator of column batches.
   *
   * @param iter                 iterator handed to the native constructor
   * @param missing              value written to the "missing" configuration entry
   * @param maxBin               value written to the "max_bin" configuration entry
   * @param ref                  optional reference DMatrix; may be null
   * @param nthread              value written to the "nthread" configuration entry
   * @param max_num_device_pages value written to the "max_num_device_pages" entry
   * @param max_quantile_batches value written to the "max_quantile_batches" entry
   * @param min_cache_page_bytes value written to the "min_cache_page_bytes" entry
   * @throws XGBoostError declared for failures reported through XGBoostJNI.checkCall
   */
  public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
                               float missing,
                               int maxBin,
                               DMatrix ref,
                               int nthread,
                               int max_num_device_pages,
                               int max_quantile_batches,
                               int min_cache_page_bytes) throws XGBoostError {
    // The native side receives either null or a one-element array with the reference handle.
    long[] refHandle = (ref == null) ? null : new long[] {ref.getHandle()};
    String jsonConfig = this.getConfig(missing, maxBin, nthread, max_num_device_pages,
        max_quantile_batches, min_cache_page_bytes);

    // The native call writes the new handle into slot 0.
    long[] created = new long[1];
    XGBoostJNI.checkCall(XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback(
        iter, refHandle, jsonConfig, created));
    handle = created[0];
  }

  /**
   * Serialize the construction parameters into the JSON string passed to the native call.
   *
   * @return the JSON configuration
   * @throws RuntimeException if Jackson fails to serialize the configuration
   */
  private String getConfig(float missing, int maxBin, int nthread, int max_num_device_pages,
                           int max_quantile_batches,
                           int min_cache_page_bytes) {
    Map<String, Object> settings = new java.util.HashMap<>();
    settings.put("missing", missing);
    settings.put("max_bin", maxBin);
    settings.put("nthread", nthread);
    settings.put("max_num_device_pages", max_num_device_pages);
    settings.put("max_quantile_batches", max_quantile_batches);
    settings.put("min_cache_page_bytes", min_cache_page_bytes);
    settings.put("on_host", true);
    settings.put("cache_prefix", ".");

    // Handle NaN values. Jackson by default serializes NaN values into strings.
    SimpleModule nanModule = new SimpleModule();
    nanModule.addSerializer(Double.class, new F64NaNSerializer());
    nanModule.addSerializer(Float.class, new F32NaNSerializer());

    ObjectMapper mapper = new ObjectMapper();
    mapper.registerModule(nanModule);

    try {
      return mapper.writeValueAsString(settings);
    } catch (JsonProcessingException e) {
      throw new RuntimeException("Failed to serialize configuration", e);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ public void serialize(Float value, JsonGenerator gen,
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
// Protected no-argument constructor for the external-memory subclass (ExtMemQuantileDMatrix), which sets the native handle itself.
protected QuantileDMatrix() {
}

/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);

public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);

/**
 * Native entry point called by the ExtMemQuantileDMatrix constructor.
 *
 * @param iter   iterator of column batches
 * @param ref    null, or a one-element array holding the handle of a reference DMatrix
 * @param config JSON configuration string
 * @param out    one-element array; the caller reads the created handle from slot 0
 * @return status code; the caller passes it to XGBoostJNI.checkCall
 */
public final static native int XGExtMemQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, long[] ref, String config, long[] out);

public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);
Expand Down
4 changes: 1 addition & 3 deletions jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
#include "../../../../src/common/common.h"

namespace xgboost::jni {
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, jobject jdata_iter,
jlongArray jref, char const *config,
jlongArray jout) {
// Stub of QdmFromCallback: it only asserts GPU support between the API guards.
// The fourth parameter must be `char const *` (the config string) so that this definition
// matches the declaration used by the JNI entry points; with plain `char const` it would
// define an unrelated overload and leave the declared function without a definition.
int QdmFromCallback(JNIEnv *, jobject, jlongArray, char const *, bool, jlongArray) {
  API_BEGIN();
  common::AssertGPUSupport();
  API_END();
Expand Down
105 changes: 93 additions & 12 deletions jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
#include <jni.h>
#include <xgboost/c_api.h>

#include "../../../../src/common/common.h"
#include "../../../../src/common/cuda_pinned_allocator.h"
#include "../../../../src/common/device_vector.cuh" // for device_vector
#include "../../../../src/common/json_utils.h"
#include "../../../../src/data/array_interface.h"
#include "jvm_utils.h" // for CheckJvmCall

Expand Down Expand Up @@ -395,26 +397,96 @@ class DataIteratorProxy {
}
return NextSecondLoop();
}
};
}
};

// An iterator proxy for external memory.
//
// Pulls batches from a JVM-side iterator (JvmIter) and forwards the array interface and
// meta info of each batch to a DMatrixProxy.
class ExtMemIteratorProxy {
  JvmIter jiter_;
  DMatrixProxy proxy_;

 public:
  explicit ExtMemIteratorProxy(jobject jiter) : jiter_(jiter) {}

  ~ExtMemIteratorProxy() = default;

  // Handle of the underlying DMatrix proxy.
  DMatrixHandle GetDMatrixHandle() const { return proxy_.GetDMatrixHandle(); }

  // Parse the JSON description of one batch and push its contents into the proxy.
  //
  // The "features" entry becomes the proxy data. A label entry is mandatory; weight,
  // base margin and qid are forwarded only when present.
  void SetArrayInterface(std::string interface_str) {
    auto json_interface = Json::Load({interface_str.c_str(), interface_str.size()});
    CHECK(!IsA<Null>(json_interface));

    Json features = json_interface["features"];
    proxy_.SetData(features);

    // set the meta info.
    auto json_map = get<Object const>(json_interface);
    if (json_map.find(Symbols::kLabel) == json_map.cend()) {
      LOG(FATAL) << "Must have a label field.";
    }
    Json label = json_interface[Symbols::kLabel.c_str()];
    CHECK(!IsA<Null>(label));
    proxy_.SetInfo(Symbols::kLabel, label);

    if (json_map.find(Symbols::kWeight) != json_map.cend()) {
      Json weight = json_interface[Symbols::kWeight.c_str()];
      CHECK(!IsA<Null>(weight));
      proxy_.SetInfo(Symbols::kWeight, weight);
    }

    if (json_map.find(Symbols::kBaseMargin) != json_map.cend()) {
      Json basemargin = json_interface[Symbols::kBaseMargin.c_str()];
      // NOTE(review): the info field is the literal "base_margin", not Symbols::kBaseMargin;
      // presumably the JSON key and the info name differ - confirm before unifying them.
      proxy_.SetInfo("base_margin", basemargin);
    }

    if (json_map.find(Symbols::kQid) != json_map.cend()) {
      Json qid = json_interface[Symbols::kQid.c_str()];
      proxy_.SetInfo(Symbols::kQid, qid);
    }
  }

  // Pull the next batch from the JVM iterator.
  //
  // Returns 1 when PullIterFromJVM reports a batch (its array interface has been handed to
  // SetArrayInterface), 0 otherwise.
  int Next() {
    try {
      if (this->jiter_.PullIterFromJVM(
              [this](char const *cjaif) { this->SetArrayInterface(cjaif); })) {
        return 1;
      } else {
        return 0;
      }
    } catch (dmlc::Error const &e) {
      // Detach the current thread when the iterator reports JNI_EDETACHED, then re-raise
      // the error message through the logger.
      if (jiter_.Status() == JNI_EDETACHED) {
        GlobalJvm()->DetachCurrentThread();
      }
      LOG(FATAL) << e.what();
    }
    // Fallback return after the catch block.
    return 0;
  }

  // Close the current JVM batch.
  void Reset() { this->jiter_.CloseJvmBatch(); }
};

namespace {
void Reset(DataIterHandle self) {
static_cast<xgboost::jni::DataIteratorProxy *>(self)->Reset();
}
void Reset(DataIterHandle self) { static_cast<xgboost::jni::DataIteratorProxy *>(self)->Reset(); }

// Callback adapter: treats the opaque handle as a DataIteratorProxy and advances it.
int Next(DataIterHandle self) {
  auto *iter_proxy = static_cast<xgboost::jni::DataIteratorProxy *>(self);
  return iter_proxy->Next();
}

// Callback adapter: treats the opaque handle as an ExtMemIteratorProxy and resets it.
void ExternalMemoryReset(DataIterHandle self) {
  auto *iter_proxy = static_cast<xgboost::jni::ExtMemIteratorProxy *>(self);
  iter_proxy->Reset();
}

// Callback adapter: treats the opaque handle as an ExtMemIteratorProxy and advances it.
int ExternalMemoryNext(DataIterHandle self) {
  auto *iter_proxy = static_cast<xgboost::jni::ExtMemIteratorProxy *>(self);
  return iter_proxy->Next();
}

// Type-erased callable taking a T* and returning nothing.
// NOTE(review): presumably meant as the deleter type of a std::unique_ptr - confirm at the use sites.
template <typename T>
using Deleter = std::function<void(T *)>;
} // anonymous namespace
} // anonymous namespace

XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass, jobject jdata_iter,
jlongArray jref, char const *config,
jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
int QdmFromCallback(JNIEnv *jenv, jobject jdata_iter, jlongArray jref, char const *config,
bool is_extmem, jlongArray jout) {
DMatrixHandle result;
DMatrixHandle ref{nullptr};

Expand All @@ -427,9 +499,18 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass, jobjec
ref = reinterpret_cast<DMatrixHandle>(refptr.get()[0]);
}

auto ret = XGQuantileDMatrixCreateFromCallback(&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next,
config, &result);
int ret = 0;
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
if (is_extmem) {
ret = XGQuantileDMatrixCreateFromCallback(&proxy, proxy.GetDMatrixHandle(), ref, Reset, Next,
config, &result);
} else {
ret = XGExtMemQuantileDMatrixCreateFromCallback(&proxy, proxy.GetDMatrixHandle(), ref, Reset,
Next, config, &result);
}

JVM_CHECK_CALL(ret);
setHandle(jenv, jout, result);
return ret;
}
} // namespace xgboost::jni
} // namespace xgboost::jni
24 changes: 19 additions & 5 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1298,9 +1298,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllred
}

namespace xgboost::jni {
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, jobject jdata_iter,
jobject jref_iter, char const *config,
jlongArray jout);
int QdmFromCallback(JNIEnv *jenv, jobject jdata_iter, jlongArray jref, char const *config,
bool is_extmem, jlongArray jout);
} // namespace xgboost::jni

/*
Expand All @@ -1309,14 +1308,29 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, j
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback(
    JNIEnv *jenv, jclass, jobject jdata_iter, jlongArray jref, jstring jconf, jlongArray jout) {
  // Hand the UTF-8 view of the Java string back to the JVM when `conf` goes out of scope.
  auto release_chars = [&](char const *chars) { jenv->ReleaseStringUTFChars(jconf, chars); };
  std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
                                                        release_chars};
  // `false` is the is_extmem argument of QdmFromCallback.
  return xgboost::jni::QdmFromCallback(jenv, jdata_iter, jref, conf.get(), false, jout);
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGExtMemQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;[JLjava/lang/String;[J)I
*/
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGExtMemQuantileDMatrixCreateFromCallback(
    JNIEnv *jenv, jclass, jobject jdata_iter, jlongArray jref, jstring jconf, jlongArray jout) {
  // The UTF-8 view of the Java string is released when `conf` goes out of scope.
  std::unique_ptr<char const, Deleter<char const>> conf{jenv->GetStringUTFChars(jconf, nullptr),
                                                        [&](char const *ptr) {
                                                          jenv->ReleaseStringUTFChars(jconf, ptr);
                                                        }};
  // `true` is the is_extmem argument of QdmFromCallback. The earlier return that called
  // XGQuantileDMatrixCreateFromCallbackImpl is dropped: that function is no longer declared
  // and the return made this call unreachable.
  return xgboost::jni::QdmFromCallback(jenv, jdata_iter, jref, conf.get(), true, jout);
}

/*
Expand Down
8 changes: 8 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 57d0108

Please sign in to comment.