From d6d1d785e66313e119f17d1a599fb332f60eecb5 Mon Sep 17 00:00:00 2001 From: Bobby Date: Wed, 26 Jun 2024 11:31:38 +0800 Subject: [PATCH] update --- .../ml/dmlc/xgboost4j/java/CudfColumn.java | 4 +- .../dmlc/xgboost4j/java/CudfColumnBatch.java | 38 ++++++++++--------- .../ml/dmlc/xgboost4j/java/BoosterTest.java | 2 + 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java index 8c73111226bf..051ac98e0080 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumn.java @@ -105,7 +105,9 @@ public String toJson() { ObjectMapper mapper = new ObjectMapper(); mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); try { - return mapper.writeValueAsString(this); + List objects = new ArrayList<>(1); + objects.add(this); + return mapper.writeValueAsString(objects); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java index 73d0b37ac9ca..8d2ee06aacbb 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/CudfColumnBatch.java @@ -18,6 +18,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.Table; @@ -41,9 +43,18 @@ public class CudfColumnBatch extends ColumnBatch { private final Table baseMarginTable; private List features; - private CudfColumn label; - private CudfColumn weight; - private CudfColumn baseMargin; + private List label; + private List weight; + private List baseMargin; + + private List initializeCudfColumns(Table table) { + assert table != null && table.getNumberOfColumns() > 0; + + return IntStream.range(0, table.getNumberOfColumns()) + .mapToObj(table::getColumn) + .map(CudfColumn::from) + .collect(Collectors.toList()); + } public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable, Table baseMarginTable) { @@ -52,26 +63,19 @@ public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable, this.weightTable = weightTable; this.baseMarginTable = baseMarginTable; - features = new ArrayList<>(); - for (int index = 0; index < featureTable.getNumberOfColumns(); index++) { - ColumnVector cv = featureTable.getColumn(index); - features.add(CudfColumn.from(cv)); - } - + features = initializeCudfColumns(featureTable); if (labelTable != null) { assert labelTable.getNumberOfColumns() == 1; - label = CudfColumn.from(labelTable.getColumn(0)); + label = initializeCudfColumns(labelTable); } if (weightTable != null) { assert weightTable.getNumberOfColumns() == 1; - weight = CudfColumn.from(weightTable.getColumn(0)); + weight = initializeCudfColumns(weightTable); } - // TODO baseMargin should be an array for multi classification if (baseMarginTable != null) { - assert baseMarginTable.getNumberOfColumns() == 1; - baseMargin = CudfColumn.from(baseMarginTable.getColumn(0)); + baseMargin = initializeCudfColumns(baseMarginTable); } } @@ -79,15 +83,15 @@ public List getFeatures() { return features; } - public CudfColumn getLabel() { + public List getLabel() { return label; } - public CudfColumn getWeight() { + public List getWeight() { return weight; } - public CudfColumn getBaseMargin() { + public List getBaseMargin() { return baseMargin; } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java index cee43f1beda1..ee2737b3f88a 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/java/ml/dmlc/xgboost4j/java/BoosterTest.java @@ -85,7 +85,9 @@ public void testBooster() throws XGBoostError { try (Table y = new Table(labels);) { CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null); + System.out.println(batch.toJson()); CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12)); + System.out.println(labelColumn.toJson()); //set watchList HashMap watches = new HashMap<>();