From 7d6d31a33ea3b7202cef2f145583f6dd4f996817 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 26 Jun 2024 15:24:48 +0800 Subject: [PATCH] Rewrite cudf column and cudf column batch. (#21) --- jvm-packages/checkstyle.xml | 2 +- .../example/spark/SparkMLlibPipeline.scala | 4 +- .../scala/example/spark/SparkTraining.scala | 8 +- jvm-packages/xgboost4j-spark-gpu/pom.xml | 6 + .../ml/dmlc/xgboost4j/java/CudfColumn.java | 113 ++++++++-------- .../dmlc/xgboost4j/java/CudfColumnBatch.java | 121 ++++++++++++------ .../ml/dmlc/xgboost4j/java/CudfUtils.java | 98 -------------- .../dmlc/xgboost4j/java/GpuColumnBatch.java | 76 ----------- .../scala/spark/GpuXGBoostPlugin.scala | 32 ++++- .../ml/dmlc/xgboost4j/java/BoosterTest.java | 57 +++------ .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 53 +++++++- .../scala/QuantileDMatrixSuite.scala | 3 +- .../xgboost4j/scala/spark/XXXXXSuite.scala | 2 +- .../ml/dmlc/xgboost4j/scala/spark/Utils.scala | 3 +- .../scala/spark/XGBoostClassifier.scala | 4 +- .../scala/spark/XGBoostEstimator.scala | 13 -- .../xgboost4j/scala/spark/XGBoostPlugin.scala | 4 +- .../spark/params/DartBoosterParams.scala | 1 - .../spark/params/LearningTaskParams.scala | 4 - .../spark/params/ParamMapConversion.scala | 3 - .../scala/spark/params/RabitParams.scala | 2 - .../spark/params/TreeBoosterParams.scala | 1 - .../scala/spark/params/XGBoostParams.scala | 6 +- .../apache/spark/ml/xgboost/SparkUtils.scala | 2 - .../scala/spark/XGBoostEstimatorSuite.scala | 9 +- .../xgboost4j/scala/spark/XGBoostSuite.scala | 2 +- .../java/ml/dmlc/xgboost4j/java/Column.java | 14 +- .../ml/dmlc/xgboost4j/java/ColumnBatch.java | 75 +---------- .../java/ml/dmlc/xgboost4j/java/DMatrix.java | 8 +- .../xgboost4j/src/native/xgboost4j-gpu.cu | 20 +-- 30 files changed, 281 insertions(+), 465 deletions(-) delete mode 100644 jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfUtils.java delete mode 100644 jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/GpuColumnBatch.java diff --git a/jvm-packages/checkstyle.xml b/jvm-packages/checkstyle.xml index 88ae2122e279..57566da71dbe 100644 --- a/jvm-packages/checkstyle.xml +++ b/jvm-packages/checkstyle.xml @@ -133,7 +133,7 @@ - + diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala index 1b46d2f050bb..26a68f085fbb 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala @@ -88,11 +88,9 @@ object SparkMLlibPipeline { "max_depth" -> 2, "objective" -> "multi:softprob", "num_class" -> 3, - "num_round" -> 100, - "num_workers" -> numWorkers, "device" -> device ) - ) + ).setNumRound(10).setNumWorkers(numWorkers) booster.setFeaturesCol("features") booster.setLabelCol("classIndex") val labelConverter = new IndexToString() diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala index 173c8f432982..5be641fef773 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala @@ -80,13 +80,13 @@ private[spark] def run(spark: SparkSession, inputPath: String, "max_depth" -> 2, "objective" -> "multi:softprob", "num_class" -> 3, - "num_round" -> 100, - "num_workers" -> numWorkers, - "device" -> device, - "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)) + "device" -> device) val xgbClassifier = new XGBoostClassifier(xgbParam). setFeaturesCol("features"). setLabelCol("classIndex") + .setNumWorkers(numWorkers) + .setNumRound(10) + .setEvalDataset(eval1) val xgbClassificationModel = xgbClassifier.fit(train) xgbClassificationModel.transform(test) } diff --git a/jvm-packages/xgboost4j-spark-gpu/pom.xml b/jvm-packages/xgboost4j-spark-gpu/pom.xml index 911782b8d86a..cf8648a8d33b 100644 --- a/jvm-packages/xgboost4j-spark-gpu/pom.xml +++ b/jvm-packages/xgboost4j-spark-gpu/pom.xml @@ -56,6 +56,12 @@ ${spark.rapids.version} provided + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + provided + junit junit diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java index 32c64eadc360..14b149bd1091 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java @@ -16,93 +16,102 @@ package ml.dmlc.xgboost4j.java; +import java.util.ArrayList; +import java.util.List; + import ai.rapids.cudf.BaseDeviceMemoryBuffer; -import ai.rapids.cudf.BufferType; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; /** - * This class is composing of base data with Apache Arrow format from Cudf ColumnVector. - * It will be used to generate the cuda array interface. + * CudfColumn is the CUDF column representing, providing the cuda array interface */ +@JsonInclude(JsonInclude.Include.NON_NULL) public class CudfColumn extends Column { + private List shape = new ArrayList<>(); // row count + private List data = new ArrayList<>(); // gpu data buffer address + private String typestr; + private int version = 1; + private CudfColumn mask = null; + + public CudfColumn(long shape, long data, String typestr, int version) { + this.shape.add(shape); + this.data.add(data); + this.data.add(false); + this.typestr = typestr; + this.version = version; + } - private final long dataPtr; // gpu data buffer address - private final long shape; // row count - private final long validPtr; // gpu valid buffer address - private final int typeSize; // type size in bytes - private final String typeStr; // follow array interface spec - private final long nullCount; // null count - - private String arrayInterface = null; // the cuda array interface - + /** + * Create CudfColumn according to ColumnVector + */ public static CudfColumn from(ColumnVector cv) { - BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA); - BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY); - long validPtr = 0; - if (validBuffer != null) { - validPtr = validBuffer.getAddress(); - } + BaseDeviceMemoryBuffer dataBuffer = cv.getData(); + assert dataBuffer != null; + DType dType = cv.getType(); String typeStr = ""; if (dType == DType.FLOAT32 || dType == DType.FLOAT64 || - dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS || - dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS || - dType == DType.TIMESTAMP_SECONDS) { + dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS || + dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS || + dType == DType.TIMESTAMP_SECONDS) { typeStr = " getShape() { + return shape; } - public long getDataPtr() { - return dataPtr; + public List getData() { + return data; } - public long getShape() { - return shape; + public String getTypestr() { + return typestr; } - public long getValidPtr() { - return validPtr; + public int getVersion() { + return version; } - public int getTypeSize() { - return typeSize; + public CudfColumn getMask() { + return mask; } - public String getTypeStr() { - return typeStr; + public void setMask(CudfColumn mask) { + this.mask = mask; } - public long getNullCount() { - return nullCount; + @Override + public String toJson() { + ObjectMapper mapper = new ObjectMapper(); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + try { + List objects = new ArrayList<>(1); + objects.add(this); + return mapper.writeValueAsString(objects); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java index bba0072166dd..90b394e5a1c5 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java @@ -16,71 +16,108 @@ package ml.dmlc.xgboost4j.java; +import java.util.List; +import java.util.stream.Collectors; import java.util.stream.IntStream; import ai.rapids.cudf.Table; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; /** - * Class to wrap CUDF Table to generate the cuda array interface. + * CudfColumnBatch wraps multiple CudfColumns to provide the cuda + * array interface json string for all columns. */ public class CudfColumnBatch extends ColumnBatch { - private final Table feature; - private final Table label; - private final Table weight; - private final Table baseMargin; - - public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) { - this.feature = feature; - this.label = labels; - this.weight = weights; - this.baseMargin = baseMargins; + @JsonIgnore + private final Table featureTable; + @JsonIgnore + private final Table labelTable; + @JsonIgnore + private final Table weightTable; + @JsonIgnore + private final Table baseMarginTable; + + private List features; + private List label; + private List weight; + private List baseMargin; + + public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable, + Table baseMarginTable) { + this.featureTable = featureTable; + this.labelTable = labelTable; + this.weightTable = weightTable; + this.baseMarginTable = baseMarginTable; + + features = initializeCudfColumns(featureTable); + if (labelTable != null) { + assert labelTable.getNumberOfColumns() == 1; + label = initializeCudfColumns(labelTable); + } + + if (weightTable != null) { + assert weightTable.getNumberOfColumns() == 1; + weight = initializeCudfColumns(weightTable); + } + + if (baseMarginTable != null) { + baseMargin = initializeCudfColumns(baseMarginTable); + } } - @Override - public String getFeatureArrayInterface() { - return getArrayInterface(this.feature); + private List initializeCudfColumns(Table table) { + assert table != null && table.getNumberOfColumns() > 0; + + return IntStream.range(0, table.getNumberOfColumns()) + .mapToObj(table::getColumn) + .map(CudfColumn::from) + .collect(Collectors.toList()); } - @Override - public String getLabelsArrayInterface() { - return getArrayInterface(this.label); + public List getFeatures() { + return features; } - @Override - public String getWeightsArrayInterface() { - return getArrayInterface(this.weight); + public List getLabel() { + return label; } - @Override - public String getBaseMarginsArrayInterface() { - return getArrayInterface(this.baseMargin); + public List getWeight() { + return weight; } - @Override - public void close() { - if (feature != null) feature.close(); - if (label != null) label.close(); - if (weight != null) weight.close(); - if (baseMargin != null) baseMargin.close(); + public List getBaseMargin() { + return baseMargin; } - private String getArrayInterface(Table table) { - if (table == null || table.getNumberOfColumns() == 0) { - return ""; + public String toJson() { + ObjectMapper mapper = new ObjectMapper(); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + try { + return mapper.writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); } - return CudfUtils.buildArrayInterface(getAsCudfColumn(table)); } - private CudfColumn[] getAsCudfColumn(Table table) { - if (table == null || table.getNumberOfColumns() == 0) { - // This will never happen. - return new CudfColumn[]{}; + @Override + public String toFeaturesJson() { + ObjectMapper mapper = new ObjectMapper(); + try { + return mapper.writeValueAsString(features); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); } - - return IntStream.range(0, table.getNumberOfColumns()) - .mapToObj((i) -> table.getColumn(i)) - .map(CudfColumn::from) - .toArray(CudfColumn[]::new); } + @Override + public void close() { + if (featureTable != null) featureTable.close(); + if (labelTable != null) labelTable.close(); + if (weightTable != null) weightTable.close(); + if (baseMarginTable != null) baseMarginTable.close(); + } } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfUtils.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfUtils.java deleted file mode 100644 index 5863e4f0b3c7..000000000000 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfUtils.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - Copyright (c) 2021-2024 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.java; - -import java.util.ArrayList; - -/** - * Cudf utilities to build cuda array interface against {@link CudfColumn} - */ -class CudfUtils { - - /** - * Build the cuda array interface based on CudfColumn(s) - * @param cudfColumns the CudfColumn(s) to be built - * @return the json format of cuda array interface - */ - public static String buildArrayInterface(CudfColumn... cudfColumns) { - return new Builder().add(cudfColumns).build(); - } - - // Helper class to build array interface string - private static class Builder { - private ArrayList colArrayInterfaces = new ArrayList(); - - private Builder add(CudfColumn... columns) { - if (columns == null || columns.length <= 0) { - throw new IllegalArgumentException("At least one ColumnData is required."); - } - for (CudfColumn cd : columns) { - colArrayInterfaces.add(buildColumnObject(cd)); - } - return this; - } - - private String build() { - StringBuilder builder = new StringBuilder(); - builder.append("["); - for (int i = 0; i < colArrayInterfaces.size(); i++) { - builder.append(colArrayInterfaces.get(i)); - if (i != colArrayInterfaces.size() - 1) { - builder.append(","); - } - } - builder.append("]"); - return builder.toString(); - } - - /** build the whole column information including data and valid info */ - private String buildColumnObject(CudfColumn column) { - if (column.getDataPtr() == 0) { - throw new IllegalArgumentException("Empty column data is NOT accepted!"); - } - if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) { - throw new IllegalArgumentException("Empty type string is NOT accepted!"); - } - - StringBuilder builder = new StringBuilder(); - String colData = buildMetaObject(column.getDataPtr(), column.getShape(), - column.getTypeStr()); - builder.append("{"); - builder.append(colData); - if (column.getValidPtr() != 0 && column.getNullCount() != 0) { - String validString = buildMetaObject(column.getValidPtr(), column.getShape(), " indices) { - if (indices == null || indices.size() == 0) { - return null; - } - - int len = indices.size(); - ColumnVector[] cv = new ColumnVector[len]; - for (int i = 0; i < len; i++) { - int index = indices.get(i); - if (index >= table.getNumberOfColumns()) { - throw new RuntimeException("Wrong index"); - } - cv[i] = table.getColumn(index); - } - - return new Table(cv); - } - - public StructType getSchema() { - return schema; - } - -} diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala index 7eff5794dc81..56fd1675526f 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala @@ -23,7 +23,6 @@ import ai.rapids.cudf.Table import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils} import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext -import org.apache.spark.ml.functions.array_to_vector import org.apache.spark.ml.param.Param import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} @@ -31,12 +30,12 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.vectorized.ColumnarBatch -import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch} +import ml.dmlc.xgboost4j.java.CudfColumnBatch import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix} import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol /** - * GpuXGBoostPlugin is the XGBoost plugin which leverage spark-rapids + * GpuXGBoostPlugin is the XGBoost plugin which leverages spark-rapids * to accelerate the XGBoost from ETL to train. */ class GpuXGBoostPlugin extends XGBoostPlugin { @@ -121,9 +120,9 @@ class GpuXGBoostPlugin extends XGBoostPlugin { /** build QuantilDMatrix on the executor side */ def buildQuantileDMatrix(iter: Iterator[Table]): QuantileDMatrix = { val colBatchIter = iter.map { table => - withResource(new GpuColumnBatch(table, null)) { batch => + withResource(new GpuColumnBatch(table)) { batch => new CudfColumnBatch( - batch.select(indices.featureIds.get.map(Integer.valueOf).asJava), + batch.select(indices.featureIds.get), batch.select(indices.labelId), batch.select(indices.weightId.getOrElse(-1)), batch.select(indices.marginId.getOrElse(-1))); @@ -219,9 +218,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin { if (tableIters.hasNext) { val dataTypes = originalSchema.fields.map(x => x.dataType) iter = withResource(tableIters.next()) { table => - val gpuColumnBatch = new GpuColumnBatch(table, originalSchema) // Create DMatrix - val featureTable = gpuColumnBatch.select(featureIds.map(Integer.valueOf).asJava) + val featureTable = new GpuColumnBatch(table).select(featureIds) if (featureTable == null) { throw new RuntimeException("Something wrong for feature indices") } @@ -279,3 +277,23 @@ class GpuXGBoostPlugin extends XGBoostPlugin { model.postTransform(output, pred).toDF() } } + +private class GpuColumnBatch(table: Table) extends AutoCloseable { + + def select(index: Int): Table = { + select(Seq(index)) + } + + def select(indices: Seq[Int]): Table = { + if (!indices.forall(index => index < table.getNumberOfColumns && index >= 0)) { + return null; + } + new Table(indices.map(table.getColumn): _*) + } + + override def close(): Unit = { + if (Option(table).isDefined) { + table.close() + } + } +} diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java index cee43f1beda1..50d25765edb2 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java @@ -16,16 +16,17 @@ package ml.dmlc.xgboost4j.java; -import ai.rapids.cudf.*; -import junit.framework.TestCase; -import org.junit.Test; - import java.io.File; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.Table; +import junit.framework.TestCase; +import org.junit.Test; + /** * Tests the BoosterTest trained by DMatrix * @@ -35,27 +36,7 @@ public class BoosterTest { @Test public void testBooster() throws XGBoostError { - String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv"; - Schema schema = Schema.builder() - .column(DType.FLOAT32, "A") - .column(DType.FLOAT32, "B") - .column(DType.FLOAT32, "C") - .column(DType.FLOAT32, "D") - - .column(DType.FLOAT32, "E") - .column(DType.FLOAT32, "F") - .column(DType.FLOAT32, "G") - .column(DType.FLOAT32, "H") - - .column(DType.FLOAT32, "I") - .column(DType.FLOAT32, "J") - .column(DType.FLOAT32, "K") - .column(DType.FLOAT32, "L") - - .column(DType.FLOAT32, "label") - .build(); - CSVOptions opts = CSVOptions.builder() - .hasHeader().build(); + String resourcePath = getClass().getResource("/binary.train.parquet").getFile(); int maxBin = 16; int round = 10; @@ -72,33 +53,32 @@ public void testBooster() throws XGBoostError { } }; - try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) { - ColumnVector[] df = new ColumnVector[10]; - // exclude the first two columns, they are label bounds and contain inf. - for (int i = 2; i < 12; ++i) { - df[i - 2] = tmpTable.getColumn(i); + try (Table table = Table.readParquet(new File(resourcePath))) { + ColumnVector[] features = new ColumnVector[6]; + for (int i = 0; i < 6; i++) { + features[i] = table.getColumn(i); } - try (Table X = new Table(df);) { + + try (Table X = new Table(features)) { ColumnVector[] labels = new ColumnVector[1]; - labels[0] = tmpTable.getColumn(12); + labels[0] = table.getColumn(6); - try (Table y = new Table(labels);) { + try (Table y = new Table(labels)) { CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null); - CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12)); + CudfColumn labelColumn = CudfColumn.from(y.getColumn(0)); - //set watchList + // train XGBoost Booster base on DMatrix HashMap watches = new HashMap<>(); - DMatrix dMatrix1 = new DMatrix(batch, Float.NaN, 1); dMatrix1.setLabel(labelColumn); watches.put("train", dMatrix1); Booster model1 = XGBoost.train(dMatrix1, paramMap, round, watches, null, null); + // train XGBoost Booster base on QuantileDMatrix List tables = new LinkedList<>(); tables.add(batch); DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1); - //set watchList HashMap watches1 = new HashMap<>(); watches1.put("train", incrementalDMatrix); Booster model2 = XGBoost.train(incrementalDMatrix, paramMap, round, watches1, null, null); @@ -106,12 +86,11 @@ public void testBooster() throws XGBoostError { float[][] predicat1 = model1.predict(dMatrix1); float[][] predicat2 = model2.predict(dMatrix1); - for (int i = 0; i < tmpTable.getRowCount(); i++) { + for (int i = 0; i < table.getRowCount(); i++) { TestCase.assertTrue(predicat1[i][0] - predicat2[i][0] < 1e-6); } } } } } - } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index 905b448a6892..4293486a97b2 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -16,14 +16,16 @@ package ml.dmlc.xgboost4j.java; -import ai.rapids.cudf.Table; -import junit.framework.TestCase; -import org.junit.Test; - import java.util.Arrays; import java.util.LinkedList; import java.util.List; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.Table; +import junit.framework.TestCase; +import org.junit.Test; + import static org.junit.Assert.assertArrayEquals; /** @@ -135,4 +137,47 @@ private float[] convertFloatTofloat(Float[]... datas) { } return floatArray; } + + @Test + public void testMakingDMatrixViaArray() { +// Float[][] features1 = { +// {1.0f, 12.0f}, +// {2.0f, 13.0f}, +// null, +// {4.0f, null}, +// {5.0f, 16.0f} +// }; +// +// Float[] label1 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; +// +// Table X1 = new Table.TestBuilder().column(features1).build(); +// Table y1 = new Table.TestBuilder().column(label1).build(); +// +// ColumnVector t = X1.getColumn(0); +// ColumnView cv = t.getChildColumnView(0); +// // +// System.out.println("----"); +// +// Float[][] features2 = { +// {6.0f, 17.0f}, +// {7.0f, 18.0f}, +// }; +// Float[] label2 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; +// Table X2 = new Table.TestBuilder().column(features2).build(); +// Table y2 = new Table.TestBuilder().column(label2).build(); +// +// List tables = new LinkedList<>(); +// tables.add(new CudfColumnBatch(X1, y1, null, null)); +// tables.add(new CudfColumnBatch(X2, y2, null, null)); +// +// try { +// DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1); +// } catch (XGBoostError e) { +// throw new RuntimeException(e); +// } +// +// System.out.println("--------------"); + + } + } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala index 574ac4cc6c14..1c8b36af299d 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/QuantileDMatrixSuite.scala @@ -44,8 +44,7 @@ class QuantileDMatrixSuite extends AnyFunSuite { withResource(new Table.TestBuilder().column(baseMargin1: _*).build) { m_0 => withResource(new Table.TestBuilder() .column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f.asInstanceOf[java.lang.Float]) - .column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build) - { X_1 => + .column(1.2f, 1.4f, null.asInstanceOf[java.lang.Float], 12.6f, 10.10f).build) { X_1 => withResource(new Table.TestBuilder().column(label2: _*).build) { y_1 => withResource(new Table.TestBuilder().column(weight2: _*).build) { w_1 => withResource(new Table.TestBuilder().column(baseMargin2: _*).build) { m_1 => diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala index a8ba1c1b225a..f98c9614ab68 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala @@ -83,7 +83,7 @@ class XXXXXSuite extends AnyFunSuite with GpuTestSuite { val out = model.transform(df) out.printSchema() - out.show(150, false) + out.show(150) // model.write.overwrite().save("/tmp/model/") // val loadedModel = XGBoostClassificationModel.load("/tmp/model") // println(loadedModel.getNumRound) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala index 9180281a44a8..6c9716089419 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, T import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} -// based on org.apache.spark.util copy /paste object Utils { private[spark] implicit class XGBLabeledPointFeatures( diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index cb734d32a8f6..408a10011f92 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -24,14 +24,13 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader} import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams} -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions.{col, udf} import org.json4s.DefaultFormats import ml.dmlc.xgboost4j.scala.Booster import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{binaryClassificationObjs, multiClassificationObjs} - class XGBoostClassifier(override val uid: String, private[spark] val xgboostParams: Map[String, Any]) extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel] @@ -109,6 +108,7 @@ class XGBoostClassifier(override val uid: String, XGBoostClassificationModel = { new XGBoostClassificationModel(uid, numberClasses, booster, Some(summary)) } + } object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index 810be139f2c7..7f0d26370c86 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -41,7 +41,6 @@ import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} import ml.dmlc.xgboost4j.scala.spark.Utils.MLVectorToXGBLabeledPoint import ml.dmlc.xgboost4j.scala.spark.params.{ParamUtils, _} - /** * Hold the column index */ @@ -378,10 +377,8 @@ private[spark] trait XGBoostEstimator[ } else { setNthread(taskCpus) } - } - def train(dataset: Dataset[_]): M = { validate(dataset) @@ -403,11 +400,6 @@ private[spark] trait XGBoostEstimator[ } override def copy(extra: ParamMap): Learner = defaultCopy(extra).asInstanceOf[Learner] - - // Not used in XGBoost - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema, true) - } } /** Indicate what to be predicted */ @@ -441,11 +433,6 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML def summary: Option[XGBoostTrainingSummary] - // Not used in XGBoost - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema, false) - } - protected[spark] def postTransform(dataset: Dataset[_], pred: PredictedColumns): Dataset[_] = { var output = dataset // Convert leaf/contrib to the vector from array diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala index 3e18b6439988..dda82f97968b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala @@ -41,7 +41,9 @@ trait XGBoostPlugin extends Serializable { estimator: XGBoostEstimator[T, M], dataset: Dataset[_]): RDD[Watches] - + /** + * Transform the dataset + */ def transform[M <: XGBoostModel[M]](model: XGBoostModel[M], dataset: Dataset[_]): DataFrame } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DartBoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DartBoosterParams.scala index eb29da5712be..e9707999a1a1 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DartBoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DartBoosterParams.scala @@ -18,7 +18,6 @@ package ml.dmlc.xgboost4j.scala.spark.params import org.apache.spark.ml.param._ - /** * Dart booster parameters, more details can be found at * https://xgboost.readthedocs.io/en/stable/parameter.html# diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index 5385875d29c7..8892100a1334 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -20,7 +20,6 @@ import scala.collection.immutable.HashSet import org.apache.spark.ml.param._ - /** * Specify the learning task and the corresponding learning objective. * More details can be found at @@ -28,7 +27,6 @@ import org.apache.spark.ml.param._ */ private[spark] trait LearningTaskParams extends Params { - final val objective = new Param[String](this, "objective", "Objective function used for training", ParamValidators.inArray(LearningTaskParams.supportedObjectives.toArray)) @@ -122,7 +120,6 @@ private[spark] trait LearningTaskParams extends Params { setDefault(objective -> "reg:squarederror", numClass -> 0, seed -> 0, seedPerIteration -> false, tweedieVariancePower -> 1.5, huberSlope -> 1, lambdarankPairMethod -> "mean", lambdarankUnbiased -> false, lambdarankBiasNorm -> 2, ndcgExpGain -> true) - } private[spark] object LearningTaskParams { @@ -141,5 +138,4 @@ private[spark] object LearningTaskParams { "pre@n", "ndcg-", "map-", "ndcg@n-", "map@n-", "poisson-nloglik", "gamma-nloglik", "cox-nloglik", "gamma-deviance", "tweedie-nloglik", "aft-nloglik", "interval-regression-accuracy") - } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala index fca272c78c8e..787cd753ba11 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala @@ -20,7 +20,6 @@ import scala.collection.mutable import org.apache.spark.ml.param._ - private[spark] trait ParamMapConversion extends NonXGBoostParams { /** @@ -64,5 +63,3 @@ private[spark] trait ParamMapConversion extends NonXGBoostParams { xgboostParams.toMap } } - - diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala index 26d94594e6b8..7a527fb37fc8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala @@ -42,5 +42,3 @@ private[spark] trait RabitParams extends Params with NonXGBoostParams { addNonXGBoostParam(rabitTrackerPort, rabitTrackerHostIp, rabitTrackerPort) } - - diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TreeBoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TreeBoosterParams.scala index 6b70f4187b95..7ea5966d459a 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TreeBoosterParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TreeBoosterParams.scala @@ -20,7 +20,6 @@ import scala.collection.immutable.HashSet import org.apache.spark.ml.param._ - /** * TreeBoosterParams defines the XGBoost TreeBooster parameters for Spark * diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala index 604a9022ad6e..8345cab35149 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala @@ -20,12 +20,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.xgboost.SparkUtils -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types.StructType import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} - trait HasLeafPredictionCol extends Params { /** * Param for leaf prediction column name. @@ -232,8 +230,6 @@ private[spark] trait SchemaValidationTrait { fitting: Boolean): StructType = schema } - - /** * XGBoost ranking spark-specific parameters * diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/xgboost/SparkUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/xgboost/SparkUtils.scala index bcafa2cf7065..a49166972dd5 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/xgboost/SparkUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/xgboost/SparkUtils.scala @@ -90,6 +90,4 @@ object SparkUtils { nullable: Boolean = false): StructType = { SchemaUtils.appendColumn(schema, colName, dataType, nullable) } - - } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala index 484673922d74..c6c8c7f2e73c 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala @@ -39,23 +39,22 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu test("RuntimeParameter") { var runtimeParams = new XGBoostClassifier( - Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cpu")) .getRuntimeParameters(true) assert(!runtimeParams.runOnGpu) runtimeParams = new XGBoostClassifier( - Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu) runtimeParams = new XGBoostClassifier( - Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cpu", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu) runtimeParams = new XGBoostClassifier( - Map("device" -> "cuda", "tree_method" -> "gpu_hist", - "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cuda", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index fcfa53c07582..3a45cf4448c0 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -108,7 +108,7 @@ class XGBoostSuite extends AnyFunSuite with PerTest { val rdd = df.rdd val runtimeParams = new XGBoostClassifier( - Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Column.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Column.java index 5afb5ed5bf63..7555159dbdb1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Column.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Column.java @@ -17,23 +17,17 @@ package ml.dmlc.xgboost4j.java; /** - * The abstracted XGBoost Column to get the cuda array interface which is used to - * set the information for DMatrix. + * This Column abstraction provides an array interface JSON string, which is + * used to reconstruct columnar data within the XGBoost library. */ public abstract class Column implements AutoCloseable { /** - * Get the cuda array interface json string for the Column - *

- * This API will be called by - * {@link DMatrix#setLabel(Column)} - * {@link DMatrix#setWeight(Column)} - * {@link DMatrix#setBaseMargin(Column)} + * Return array interface json string for this Column */ - public abstract String getArrayInterfaceJson(); + public abstract String toJson(); @Override public void close() throws Exception { } - } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java index 798cfeac85c4..9bb48490b4f6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java @@ -17,76 +17,11 @@ package ml.dmlc.xgboost4j.java; /** - * The abstracted XGBoost ColumnBatch to get array interface from columnar data format. - * For example, the cuDF dataframe which employs apache arrow specification. + * This class wraps multiple Column and provides the array interface json + * for all columns. */ -public abstract class ColumnBatch implements AutoCloseable { - /** - * Get the cuda array interface json string for the whole ColumnBatch including - * the must-have feature, label columns and the optional weight, base margin columns. - *

- * This function is be called by native code during iteration and can be made as private - * method. We keep it as public simply to silent the linter. - */ - public final String getArrayInterfaceJson() { - - StringBuilder builder = new StringBuilder(); - builder.append("{"); - String featureStr = this.getFeatureArrayInterface(); - if (featureStr == null || featureStr.isEmpty()) { - throw new RuntimeException("Feature array interface must not be empty"); - } else { - builder.append("\"features_str\":" + featureStr); - } - - String labelStr = this.getLabelsArrayInterface(); - if (labelStr == null || labelStr.isEmpty()) { - throw new RuntimeException("Label array interface must not be empty"); - } else { - builder.append(",\"label_str\":" + labelStr); - } - - String weightStr = getWeightsArrayInterface(); - if (weightStr != null && !weightStr.isEmpty()) { - builder.append(",\"weight_str\":" + weightStr); - } - - String baseMarginStr = getBaseMarginsArrayInterface(); - if (baseMarginStr != null && !baseMarginStr.isEmpty()) { - builder.append(",\"basemargin_str\":" + baseMarginStr); - } - - builder.append("}"); - return builder.toString(); - } - - /** - * Get the cuda array interface of the feature columns. - * The returned value must not be null or empty - */ - public abstract String getFeatureArrayInterface(); - - /** - * Get the cuda array interface of the label columns. - * The returned value must not be null or empty if we're creating - * QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int) - */ - public abstract String getLabelsArrayInterface(); - - /** - * Get the cuda array interface of the weight columns. - * The returned value can be null or empty - */ - public abstract String getWeightsArrayInterface(); - - /** - * Get the cuda array interface of the base margin columns. - * The returned value can be null or empty - */ - public abstract String getBaseMarginsArrayInterface(); - - @Override - public void close() throws Exception { - } +public abstract class ColumnBatch extends Column { + /** Get features cuda array interface json string */ + public abstract String toFeaturesJson(); } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 2e7540bd2b30..0e88c25d3fda 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -184,7 +184,7 @@ protected DMatrix(long handle) { */ public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError { long[] out = new long[1]; - String json = columnBatch.getFeatureArrayInterface(); + String json = columnBatch.toFeaturesJson(); if (json == null || json.isEmpty()) { throw new XGBoostError("Expecting non-empty feature columns' array interface"); } @@ -201,7 +201,7 @@ public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoo * @throws XGBoostError native error */ public void setLabel(Column column) throws XGBoostError { - setXGBDMatrixInfo("label", column.getArrayInterfaceJson()); + setXGBDMatrixInfo("label", column.toJson()); } /** @@ -212,7 +212,7 @@ public void setLabel(Column column) throws XGBoostError { * @throws XGBoostError native error */ public void setWeight(Column column) throws XGBoostError { - setXGBDMatrixInfo("weight", column.getArrayInterfaceJson()); + setXGBDMatrixInfo("weight", column.toJson()); } /** @@ -223,7 +223,7 @@ public void setWeight(Column column) throws XGBoostError { * @throws XGBoostError native error */ public void setBaseMargin(Column column) throws XGBoostError { - setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson()); + setXGBDMatrixInfo("base_margin", column.toJson()); } private void setXGBDMatrixInfo(String type, String json) throws XGBoostError { diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu b/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu index 317be01adf9c..b784b21ec5f6 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu +++ b/jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu @@ -186,11 +186,11 @@ class DataIteratorProxy { void StageMetaInfo(Json json_interface) { CHECK(!IsA(json_interface)); auto json_map = get(json_interface); - if (json_map.find("label_str") == json_map.cend()) { + if (json_map.find("label") == json_map.cend()) { LOG(FATAL) << "Must have a label field."; } - Json label = json_interface["label_str"]; + Json label = json_interface["label"]; CHECK(!IsA(label)); labels_.emplace_back(new dh::device_vector); CopyMetaInfo(&label, labels_.back().get(), copy_stream_); @@ -200,8 +200,8 @@ class DataIteratorProxy { Json::Dump(label, &str); XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str()); - if (json_map.find("weight_str") != json_map.cend()) { - Json weight = json_interface["weight_str"]; + if (json_map.find("weight") != json_map.cend()) { + Json weight = json_interface["weight"]; CHECK(!IsA(weight)); weights_.emplace_back(new dh::device_vector); CopyMetaInfo(&weight, weights_.back().get(), copy_stream_); @@ -211,8 +211,8 @@ class DataIteratorProxy { XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str()); } - if (json_map.find("basemargin_str") != json_map.cend()) { - Json basemargin = json_interface["basemargin_str"]; + if (json_map.find("baseMargin") != json_map.cend()) { + Json basemargin = json_interface["baseMargin"]; base_margins_.emplace_back(new dh::device_vector); CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_); margin_interfaces_.emplace_back(basemargin); @@ -249,11 +249,11 @@ class DataIteratorProxy { // batch should be ColumnBatch from jvm jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_); jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_); - jmethodID getArrayInterfaceJson = CheckJvmCall(jenv_->GetMethodID( - batch_class, "getArrayInterfaceJson", "()Ljava/lang/String;"), jenv_); + jmethodID toJson = CheckJvmCall(jenv_->GetMethodID( + batch_class, "toJson", "()Ljava/lang/String;"), jenv_); auto jinterface = - static_cast(jenv_->CallObjectMethod(batch, getArrayInterfaceJson)); + static_cast(jenv_->CallObjectMethod(batch, toJson)); CheckJvmCall(jinterface, jenv_); char const *c_interface_str = CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_); @@ -281,7 +281,7 @@ class DataIteratorProxy { CHECK(!IsA(json_interface)); StageMetaInfo(json_interface); - Json features = json_interface["features_str"]; + Json features = json_interface["features"]; auto json_columns = get(features); std::vector> interfaces;