Gpu: Support rank and regressor (#10560)
wbo4958 authored Jul 17, 2024
1 parent e3ba9fc commit b1d0ef7
Showing 22 changed files with 1,191 additions and 724 deletions.
CudfColumnBatch.java
@@ -39,18 +39,22 @@ public class CudfColumnBatch extends ColumnBatch {
private final Table weightTable;
@JsonIgnore
private final Table baseMarginTable;
@JsonIgnore
private final Table qidTable;

private List<CudfColumn> features;
private List<CudfColumn> label;
private List<CudfColumn> weight;
private List<CudfColumn> baseMargin;
private List<CudfColumn> qid;

public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
Table baseMarginTable) {
Table baseMarginTable, Table qidTable) {
this.featureTable = featureTable;
this.labelTable = labelTable;
this.weightTable = weightTable;
this.baseMarginTable = baseMarginTable;
this.qidTable = qidTable;

features = initializeCudfColumns(featureTable);
if (labelTable != null) {
@@ -66,6 +70,11 @@ public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
if (baseMarginTable != null) {
baseMargin = initializeCudfColumns(baseMarginTable);
}

if (qidTable != null) {
qid = initializeCudfColumns(qidTable);
}

}

private List<CudfColumn> initializeCudfColumns(Table table) {
@@ -93,6 +102,10 @@ public List<CudfColumn> getBaseMargin() {
return baseMargin;
}

public List<CudfColumn> getQid() {
return qid;
}

public String toJson() {
ObjectMapper mapper = new ObjectMapper();
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
@@ -119,5 +132,6 @@ public void close() {
if (labelTable != null) labelTable.close();
if (weightTable != null) weightTable.close();
if (baseMarginTable != null) baseMarginTable.close();
if (qidTable != null) qidTable.close();
}
}
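
A minimal usage sketch of the widened constructor (not part of the commit; the toy values and the wrapper class name are assumptions, and running it needs the cuDF native libraries and a CUDA GPU): the fifth argument carries a query-id Table for ranking, weight and base margin stay optional, and getQid() exposes the new column once the batch is built.

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;

import ml.dmlc.xgboost4j.java.CudfColumnBatch;

public class QidBatchSketch {
  public static void main(String[] args) {
    // Toy columnar data: 4 rows, 1 feature, binary labels, 2 query groups (0 and 1).
    try (ColumnVector f0 = ColumnVector.fromFloats(0.1f, 0.2f, 0.3f, 0.4f);
         ColumnVector lbl = ColumnVector.fromFloats(0f, 1f, 0f, 1f);
         ColumnVector qid = ColumnVector.fromInts(0, 0, 1, 1);
         Table features = new Table(f0);
         Table labels = new Table(lbl);
         Table qids = new Table(qid)) {
      // Weight and base-margin tables are optional; pass null to omit them.
      // The tables are owned by this try block, so the batch itself is not closed here.
      CudfColumnBatch batch = new CudfColumnBatch(features, labels, null, null, qids);
      System.out.println(batch.getQid());  // CudfColumn wrappers for the qid column
      System.out.println(batch.toJson());  // qid appears in the JSON because inclusion is NON_NULL
    }
  }
}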
QuantileDMatrix.scala
@@ -24,9 +24,9 @@ class QuantileDMatrix private[scala](
private[scala] override val jDMatrix: JQuantileDMatrix) extends DMatrix(jDMatrix) {

/**
* Create QuantileDMatrix from iterator based on the cuda array interface
* Create QuantileDMatrix from iterator based on the array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
* @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
@@ -84,7 +84,7 @@ class QuantileDMatrix private[scala](
throw new XGBoostError("QuantileDMatrix does not support setGroup.")

/**
* Set label of DMatrix from cuda array interface
* Set label of DMatrix from array interface
*/
@throws(classOf[XGBoostError])
override def setLabel(column: Column): Unit =
@@ -104,4 +104,9 @@ class QuantileDMatrix private[scala](
override def setBaseMargin(column: Column): Unit =
throw new XGBoostError("QuantileDMatrix does not support setBaseMargin.")

@throws(classOf[XGBoostError])
override def setQueryId(column: Column): Unit = {
throw new XGBoostError("QuantileDMatrix does not support setQueryId.")
}

}
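
Because QuantileDMatrix rejects the label/weight/base-margin/query-id setters, query ids have to travel inside the ColumnBatch iterator instead of being set afterwards. A hedged end-to-end sketch of the GPU ranking path this commit enables (toy data, assumed class name, illustrative parameter values; it needs a CUDA GPU and the cuDF native libraries at runtime):

import java.util.Collections;
import java.util.HashMap;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.CudfColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.QuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class GpuRankerSketch {
  public static void main(String[] args) throws XGBoostError {
    try (ColumnVector f0 = ColumnVector.fromFloats(0.1f, 0.4f, 0.2f, 0.8f);
         ColumnVector lbl = ColumnVector.fromFloats(0f, 1f, 0f, 1f);
         ColumnVector qid = ColumnVector.fromInts(0, 0, 1, 1);
         Table X = new Table(f0);
         Table y = new Table(lbl);
         Table q = new Table(qid)) {
      // The qid table rides inside the batch; calling setQueryId on the
      // QuantileDMatrix below would throw XGBoostError.
      ColumnBatch batch = new CudfColumnBatch(X, y, null, null, q);
      DMatrix dtrain = new QuantileDMatrix(
          Collections.singletonList(batch).iterator(), Float.NaN, 16, 1);

      HashMap<String, Object> params = new HashMap<>();
      params.put("objective", "rank:ndcg");  // ranking objective groups rows by qid
      params.put("device", "cuda");
      params.put("tree_method", "hist");

      // With "train" in the watch list, per-round train metrics are logged during training.
      HashMap<String, DMatrix> watches = new HashMap<>();
      watches.put("train", dtrain);
      Booster booster = XGBoost.train(dtrain, params, 10, watches, null, null);

      // Build a regular DMatrix from the same batch for prediction, mirroring BoosterTest below.
      DMatrix dpred = new DMatrix(batch, Float.NaN, 1);
      float[][] preds = booster.predict(dpred);
      System.out.println("first prediction: " + preds[0][0]);
    }
  }
}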
GpuXGBoostPlugin.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix}
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol

/**
@@ -119,15 +120,16 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
val nthread = estimator.getNthread
val missing = estimator.getMissing

/** build QuantilDMatrix on the executor side */
/** build QuantileDMatrix on the executor side */
def buildQuantileDMatrix(iter: Iterator[Table]): QuantileDMatrix = {
val colBatchIter = iter.map { table =>
withResource(new GpuColumnBatch(table)) { batch =>
new CudfColumnBatch(
batch.select(indices.featureIds.get),
batch.select(indices.labelId),
batch.select(indices.weightId.getOrElse(-1)),
batch.select(indices.marginId.getOrElse(-1)));
batch.select(indices.marginId.getOrElse(-1)),
batch.select(indices.groupId.getOrElse(-1)));
}
}
new QuantileDMatrix(colBatchIter, missing, maxBin, nthread)
@@ -150,16 +152,6 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
)
}

/** Executes the provided code block and then closes the resource */
def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} finally {
r.close()
}
}


override def transform[M <: XGBoostModel[M]](model: XGBoostModel[M],
dataset: Dataset[_]): DataFrame = {
val sc = dataset.sparkSession.sparkContext
@@ -226,7 +218,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
throw new RuntimeException("Something wrong for feature indices")
}
try {
val cudfColumnBatch = new CudfColumnBatch(featureTable, null, null, null)
val cudfColumnBatch = new CudfColumnBatch(featureTable, null, null, null, null)
val dm = new DMatrix(cudfColumnBatch, missing, nThread)
if (dm == null) {
Iterator.empty
BoosterTest.java
@@ -22,8 +22,7 @@
import java.util.List;
import java.util.Map;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;
import ai.rapids.cudf.*;
import junit.framework.TestCase;
import org.junit.Test;

@@ -36,7 +35,27 @@ public class BoosterTest {

@Test
public void testBooster() throws XGBoostError {
String resourcePath = getClass().getResource("/binary.train.parquet").getFile();
String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv";
Schema schema = Schema.builder()
.column(DType.FLOAT32, "A")
.column(DType.FLOAT32, "B")
.column(DType.FLOAT32, "C")
.column(DType.FLOAT32, "D")

.column(DType.FLOAT32, "E")
.column(DType.FLOAT32, "F")
.column(DType.FLOAT32, "G")
.column(DType.FLOAT32, "H")

.column(DType.FLOAT32, "I")
.column(DType.FLOAT32, "J")
.column(DType.FLOAT32, "K")
.column(DType.FLOAT32, "L")

.column(DType.FLOAT32, "label")
.build();
CSVOptions opts = CSVOptions.builder()
.hasHeader().build();

int maxBin = 16;
int round = 10;
@@ -53,44 +72,46 @@ public void testBooster() throws XGBoostError {
}
};

try (Table table = Table.readParquet(new File(resourcePath))) {
ColumnVector[] features = new ColumnVector[6];
for (int i = 0; i < 6; i++) {
features[i] = table.getColumn(i);
try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) {
ColumnVector[] df = new ColumnVector[10];
// exclude the first two columns, they are label bounds and contain inf.
for (int i = 2; i < 12; ++i) {
df[i - 2] = tmpTable.getColumn(i);
}

try (Table X = new Table(features)) {
try (Table X = new Table(df);) {
ColumnVector[] labels = new ColumnVector[1];
labels[0] = table.getColumn(6);
labels[0] = tmpTable.getColumn(12);

try (Table y = new Table(labels)) {
try (Table y = new Table(labels);) {

CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null, null);
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));

// train XGBoost Booster base on DMatrix
//set watchList
HashMap<String, DMatrix> watches = new HashMap<>();

DMatrix dMatrix1 = new DMatrix(batch, Float.NaN, 1);
dMatrix1.setLabel(labelColumn);
watches.put("train", dMatrix1);
Booster model1 = XGBoost.train(dMatrix1, paramMap, round, watches, null, null);

// train XGBoost Booster base on QuantileDMatrix
List<ColumnBatch> tables = new LinkedList<>();
tables.add(batch);
DMatrix incrementalDMatrix = new QuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
//set watchList
HashMap<String, DMatrix> watches1 = new HashMap<>();
watches1.put("train", incrementalDMatrix);
Booster model2 = XGBoost.train(incrementalDMatrix, paramMap, round, watches1, null, null);

float[][] predicat1 = model1.predict(dMatrix1);
float[][] predicat2 = model2.predict(dMatrix1);

for (int i = 0; i < table.getRowCount(); i++) {
for (int i = 0; i < tmpTable.getRowCount(); i++) {
TestCase.assertTrue(predicat1[i][0] - predicat2[i][0] < 1e-6);
}
}
}
}
}

}