diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala index 7eff5794dc81..21a82a044556 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala @@ -23,7 +23,6 @@ import ai.rapids.cudf.Table import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils} import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext -import org.apache.spark.ml.functions.array_to_vector import org.apache.spark.ml.param.Param import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} @@ -31,7 +30,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.vectorized.ColumnarBatch -import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch} +import ml.dmlc.xgboost4j.java.CudfColumnBatch import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix} import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol @@ -121,9 +120,9 @@ class GpuXGBoostPlugin extends XGBoostPlugin { /** build QuantilDMatrix on the executor side */ def buildQuantileDMatrix(iter: Iterator[Table]): QuantileDMatrix = { val colBatchIter = iter.map { table => - withResource(new GpuColumnBatch(table, null)) { batch => + withResource(new GpuColumnBatch(table)) { batch => new CudfColumnBatch( - batch.select(indices.featureIds.get.map(Integer.valueOf).asJava), + batch.select(indices.featureIds.get), batch.select(indices.labelId), batch.select(indices.weightId.getOrElse(-1)), batch.select(indices.marginId.getOrElse(-1))); @@ -219,9 +218,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin { if (tableIters.hasNext) { val dataTypes = originalSchema.fields.map(x => x.dataType) iter = withResource(tableIters.next()) { table => - val gpuColumnBatch = new GpuColumnBatch(table, originalSchema) // Create DMatrix - val featureTable = gpuColumnBatch.select(featureIds.map(Integer.valueOf).asJava) + val featureTable = new GpuColumnBatch(table).select(featureIds) if (featureTable == null) { throw new RuntimeException("Something wrong for feature indices") } @@ -279,3 +277,23 @@ class GpuXGBoostPlugin extends XGBoostPlugin { model.postTransform(output, pred).toDF() } } + +private class GpuColumnBatch(table: Table) extends AutoCloseable { + + def select(index: Int): Table = { + select(Seq(index)) + } + + def select(indices: Seq[Int]): Table = { + if (!indices.forall(index => index < table.getNumberOfColumns && index >= 0)) { + return null; + } + new Table(indices.map(table.getColumn): _*) + } + + override def close(): Unit = { + if (Option(table).isDefined) { + table.close() + } + } +} diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index 0b1422236bab..4293486a97b2 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -140,52 +140,43 @@ private float[] convertFloatTofloat(Float[]... datas) { @Test public void testMakingDMatrixViaArray() { - // ColumnVector child1 = ColumnVector.fromFloats(1, 2, 3, 4, 5, 6); - // ColumnVector child2 = ColumnVector.fromFloats(11, 12, 13, 14, 15, 16); - // ColumnVector list = ColumnVector.makeList(child1, child2); - // child2.close(); - // child1.close(); - - Float[][] features1 = { - {1.0f, 12.0f}, - {2.0f, 13.0f}, - null, - {4.0f, null}, - {5.0f, 16.0f} - }; - - Float[] label1 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; - - Table X1 = new Table.TestBuilder().column(features1).build(); - Table y1 = new Table.TestBuilder().column(label1).build(); - - // HostColumnVector hcv = X1.getColumn(0).copyToHost(); - - // - ColumnVector t = X1.getColumn(0); - ColumnView cv = t.getChildColumnView(0); - // - System.out.println("----"); - - Float[][] features2 = { - {6.0f, 17.0f}, - {7.0f, 18.0f}, - }; - Float[] label2 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; - Table X2 = new Table.TestBuilder().column(features2).build(); - Table y2 = new Table.TestBuilder().column(label2).build(); - - List tables = new LinkedList<>(); - tables.add(new CudfColumnBatch(X1, y1, null, null)); - tables.add(new CudfColumnBatch(X2, y2, null, null)); - - try { - DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1); - } catch (XGBoostError e) { - throw new RuntimeException(e); - } - - System.out.println("--------------"); +// Float[][] features1 = { +// {1.0f, 12.0f}, +// {2.0f, 13.0f}, +// null, +// {4.0f, null}, +// {5.0f, 16.0f} +// }; +// +// Float[] label1 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; +// +// Table X1 = new Table.TestBuilder().column(features1).build(); +// Table y1 = new Table.TestBuilder().column(label1).build(); +// +// ColumnVector t = X1.getColumn(0); +// ColumnView cv = t.getChildColumnView(0); +// // +// System.out.println("----"); +// +// Float[][] features2 = { +// {6.0f, 17.0f}, +// {7.0f, 18.0f}, +// }; +// Float[] label2 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; +// Table X2 = new Table.TestBuilder().column(features2).build(); +// Table y2 = new Table.TestBuilder().column(label2).build(); +// +// List tables = new LinkedList<>(); +// tables.add(new CudfColumnBatch(X1, y1, null, null)); +// tables.add(new CudfColumnBatch(X2, y2, null, null)); +// +// try { +// DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1); +// } catch (XGBoostError e) { +// throw new RuntimeException(e); +// } +// +// System.out.println("--------------"); } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala index a8ba1c1b225a..f98c9614ab68 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XXXXXSuite.scala @@ -83,7 +83,7 @@ class XXXXXSuite extends AnyFunSuite with GpuTestSuite { val out = model.transform(df) out.printSchema() - out.show(150, false) + out.show(150) // model.write.overwrite().save("/tmp/model/") // val loadedModel = XGBoostClassificationModel.load("/tmp/model") // println(loadedModel.getNumRound)