Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages] change DeviceQuantileDmatrix into QuantileDMatrix #8461

Merged
merged 4 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass j
common::AssertGPUSupport();
API_END();
}
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
char const *config, jlongArray jout) {
API_BEGIN();
common::AssertGPUSupport();
API_END();
}
} // namespace jni
} // namespace xgboost
#endif // XGBOOST_USE_CUDA
18 changes: 17 additions & 1 deletion jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ int Next(DataIterHandle self) {
}
} // anonymous namespace

XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
Expand All @@ -392,5 +392,21 @@ XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass
setHandle(jenv, jout, result);
return ret;
}

XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jdata_iter, jobject jref_iter,
char const *config, jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
DMatrixHandle result;

std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
if (jref_iter) {
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
}
auto ret = XGQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
setHandle(jenv, jout, result);
return ret;
}
} // namespace jni
} // namespace xgboost
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

Expand Down Expand Up @@ -107,7 +107,7 @@ public void testBooster() throws XGBoostError {

List<ColumnBatch> tables = new LinkedList<>();
tables.add(batch);
DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
//set watchList
HashMap<String, DMatrix> watches1 = new HashMap<>();
watches1.put("train", incrementalDMatrix);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;

Expand Down Expand Up @@ -117,7 +117,7 @@ public void testCreateFromColumnDataIterator() throws XGBoostError {
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));

DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);

float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import ai.rapids.cudf.Table
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch

class DeviceQuantileDMatrixSuite extends FunSuite {
class QuantileDMatrixSuite extends FunSuite {

test("DeviceQuantileDMatrix test") {
test("QuantileDMatrix test") {

val label1 = Array[java.lang.Float](25f, 21f, 22f, 20f, 24f)
val weight1 = Array[java.lang.Float](1.3f, 2.31f, 0.32f, 3.3f, 1.34f)
Expand All @@ -51,8 +51,7 @@ class DeviceQuantileDMatrixSuite extends FunSuite {
val batches = new ArrayBuffer[CudfColumnBatch]()
batches += new CudfColumnBatch(X_0, y_0, w_0, m_0)
batches += new CudfColumnBatch(X_1, y_1, w_1, m_1)
val dmatrix = new DeviceQuantileDMatrix(batches.toIterator, 0.0f, 8, 1)

val dmatrix = new QuantileDMatrix(batches.toIterator, 0.0f, 8, 1)
assert(dmatrix.getLabel.sameElements(label1 ++ label2))
assert(dmatrix.getWeight.sameElements(weight1 ++ weight2))
assert(dmatrix.getBaseMargin.sameElements(baseMargin1 ++ baseMargin2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.gpu.java.CudfColumnBatch
import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, DeviceQuantileDMatrix}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix}
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor}
import org.apache.commons.logging.LogFactory
Expand Down Expand Up @@ -532,7 +532,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
}

/**
* Build DeviceQuantileDMatrix based on GpuColumnBatches
* Build QuantileDMatrix based on GpuColumnBatches
*
* @param iter a sequence of GpuColumnBatch
* @param indices indicate the feature, label, weight, base margin column ids.
Expand All @@ -546,7 +546,7 @@ object GpuPreXGBoost extends PreXGBoostProvider {
missing: Float,
maxBin: Int): DMatrix = {
val rapidsIterator = new RapidsIterator(iter, indices)
new DeviceQuantileDMatrix(rapidsIterator, missing, maxBin, 1)
new QuantileDMatrix(rapidsIterator, missing, maxBin, 1)
}

// zip all the Columnar RDDs into one RDD containing named column data batch.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public final String getArrayInterfaceJson() {
/**
* Get the cuda array interface of the label columns.
* The returned value must not be null or empty if we're creating
* {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
* {@link QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int)}
*/
public abstract String getLabelsArrayInterface();

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package ml.dmlc.xgboost4j.java;

import java.util.Iterator;

/**
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
public QuantileDMatrix(
Iterator<ColumnBatch> iter,
float missing,
int maxBin,
int nthread) throws XGBoostError {
super(0);
long[] out = new long[1];
String conf = getConfig(missing, maxBin, nthread);
XGBoostJNI.checkCall(XGBoostJNI.XGQuantileDMatrixCreateFromCallback(
iter, (java.util.Iterator<ColumnBatch>)null, conf, out));
handle = out[0];
}

@Override
public void setLabel(Column column) throws XGBoostError {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please remind me what's the reason behind setters not being supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this file is renamed by DeviceQuantileDMatrix which is not supporting setter before (when the PR was merged). Will file another PR to support setter if it indeed has supported that.

throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}

@Override
public void setWeight(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}

@Override
public void setBaseMargin(Column column) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}

@Override
public void setLabel(float[] labels) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setLabel.");
}

@Override
public void setWeight(float[] weights) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setWeight.");
}

@Override
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}

@Override
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.");
}

@Override
public void setGroup(int[] group) throws XGBoostError {
throw new XGBoostError("QuantileDMatrix does not support setGroup.");
}

private String getConfig(float missing, int maxBin, int nthread) {
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
missing, maxBin, nthread);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);

@Deprecated
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);

public final static native int XGQuantileDMatrixCreateFromCallback(
java.util.Iterator<ColumnBatch> iter, java.util.Iterator<ColumnBatch> ref, String config, long[] out);

public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ package ml.dmlc.xgboost4j.scala

import _root_.scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, DeviceQuantileDMatrix => JDeviceQuantileDMatrix}
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, XGBoostError, QuantileDMatrix => JQuantileDMatrix}

class DeviceQuantileDMatrix private[scala](
private[scala] override val jDMatrix: JDeviceQuantileDMatrix) extends DMatrix(jDMatrix) {
class QuantileDMatrix private[scala](
private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) {

/**
* Create DeviceQuantileDMatrix from iterator based on the cuda array interface
* Create QuantileDMatrix from iterator based on the cuda array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param missing the missing value
Expand All @@ -33,7 +33,7 @@ class DeviceQuantileDMatrix private[scala](
* @throws XGBoostError
*/
def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int, nthread: Int) {
this(new JDeviceQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
this(new JQuantileDMatrix(iter.asJava, missing, maxBin, nthread))
}

/**
Expand All @@ -43,7 +43,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setLabel(labels: Array[Float]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.")
throw new XGBoostError("QuantileDMatrix does not support setLabel.")

/**
* set weight of each instance
Expand All @@ -52,7 +52,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setWeight(weights: Array[Float]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.")
throw new XGBoostError("QuantileDMatrix does not support setWeight.")

/**
* if specified, xgboost will start from this init margin
Expand All @@ -62,7 +62,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setBaseMargin(baseMargin: Array[Float]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")

/**
* if specified, xgboost will start from this init margin
Expand All @@ -72,7 +72,7 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setBaseMargin(baseMargin: Array[Array[Float]]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")

/**
* Set group sizes of DMatrix (used for ranking)
Expand All @@ -81,27 +81,27 @@ class DeviceQuantileDMatrix private[scala](
*/
@throws(classOf[XGBoostError])
override def setGroup(group: Array[Int]): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.")
throw new XGBoostError("QuantileDMatrix does not support setGroup.")

/**
* Set label of DMatrix from cuda array interface
*/
@throws(classOf[XGBoostError])
override def setLabel(column: Column): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.")
throw new XGBoostError("QuantileDMatrix does not support setLabel.")

/**
* set weight of dmatrix from column array interface
*/
@throws(classOf[XGBoostError])
override def setWeight(column: Column): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.")
throw new XGBoostError("QuantileDMatrix does not support setWeight.")

/**
* set base margin of dmatrix from column array interface
*/
@throws(classOf[XGBoostError])
override def setBaseMargin(column: Column): Unit =
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.")
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")

}
Loading