Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 26, 2024
1 parent 6440848 commit f230660
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils}
import org.apache.commons.logging.LogFactory
import org.apache.spark.TaskContext
import org.apache.spark.ml.functions.array_to_vector
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.vectorized.ColumnarBatch

import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch}
import ml.dmlc.xgboost4j.java.CudfColumnBatch
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix}
import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol

Expand Down Expand Up @@ -121,9 +120,9 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
/** Build a QuantileDMatrix on the executor side. */
def buildQuantileDMatrix(iter: Iterator[Table]): QuantileDMatrix = {
val colBatchIter = iter.map { table =>
withResource(new GpuColumnBatch(table, null)) { batch =>
withResource(new GpuColumnBatch(table)) { batch =>
new CudfColumnBatch(
batch.select(indices.featureIds.get.map(Integer.valueOf).asJava),
batch.select(indices.featureIds.get),
batch.select(indices.labelId),
batch.select(indices.weightId.getOrElse(-1)),
batch.select(indices.marginId.getOrElse(-1)));
Expand Down Expand Up @@ -219,9 +218,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
if (tableIters.hasNext) {
val dataTypes = originalSchema.fields.map(x => x.dataType)
iter = withResource(tableIters.next()) { table =>
val gpuColumnBatch = new GpuColumnBatch(table, originalSchema)
// Create DMatrix
val featureTable = gpuColumnBatch.select(featureIds.map(Integer.valueOf).asJava)
val featureTable = new GpuColumnBatch(table).select(featureIds)
if (featureTable == null) {
throw new RuntimeException("Something wrong for feature indices")
}
Expand Down Expand Up @@ -279,3 +277,23 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
model.postTransform(output, pred).toDF()
}
}

/**
 * Thin wrapper around a cudf [[Table]] that builds new tables from a subset of
 * its columns. Closing the wrapper closes the wrapped table.
 *
 * @param table the table whose columns are selected; may be null, in which case
 *              `close()` is a no-op
 */
private class GpuColumnBatch(table: Table) extends AutoCloseable {

  /**
   * Single-column convenience overload; delegates to the `Seq` variant.
   *
   * @param index position of the column to select
   * @return a one-column table, or null when `index` is out of range
   */
  def select(index: Int): Table = select(Seq(index))

  /**
   * Builds a new table from the columns at the given positions.
   *
   * @param indices positions of the columns to select, in output order
   * @return a table made of the requested columns, or null when any position is
   *         negative or not smaller than the number of columns in the wrapped table
   */
  def select(indices: Seq[Int]): Table = {
    // Callers rely on null (not an exception) to signal "no such column(s)".
    val anyOutOfRange =
      indices.exists(pos => pos >= table.getNumberOfColumns || pos < 0)
    if (anyOutOfRange) {
      null
    } else {
      val picked = indices.map(table.getColumn)
      new Table(picked: _*)
    }
  }

  /** Closes the wrapped table when there is one. */
  override def close(): Unit = {
    if (table != null) {
      table.close()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,52 +140,43 @@ private float[] convertFloatTofloat(Float[]... datas) {

@Test
public void testMakingDMatrixViaArray() {
// ColumnVector child1 = ColumnVector.fromFloats(1, 2, 3, 4, 5, 6);
// ColumnVector child2 = ColumnVector.fromFloats(11, 12, 13, 14, 15, 16);
// ColumnVector list = ColumnVector.makeList(child1, child2);
// child2.close();
// child1.close();

Float[][] features1 = {
{1.0f, 12.0f},
{2.0f, 13.0f},
null,
{4.0f, null},
{5.0f, 16.0f}
};

Float[] label1 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f};

Table X1 = new Table.TestBuilder().column(features1).build();
Table y1 = new Table.TestBuilder().column(label1).build();

// HostColumnVector hcv = X1.getColumn(0).copyToHost();

//
ColumnVector t = X1.getColumn(0);
ColumnView cv = t.getChildColumnView(0);
//
System.out.println("----");

Float[][] features2 = {
{6.0f, 17.0f},
{7.0f, 18.0f},
};
Float[] label2 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
Table X2 = new Table.TestBuilder().column(features2).build();
Table y2 = new Table.TestBuilder().column(label2).build();

List<ColumnBatch> tables = new LinkedList<>();
tables.add(new CudfColumnBatch(X1, y1, null, null));
tables.add(new CudfColumnBatch(X2, y2, null, null));

try {
DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
} catch (XGBoostError e) {
throw new RuntimeException(e);
}

System.out.println("--------------");
// Float[][] features1 = {
// {1.0f, 12.0f},
// {2.0f, 13.0f},
// null,
// {4.0f, null},
// {5.0f, 16.0f}
// };
//
// Float[] label1 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
//
// Table X1 = new Table.TestBuilder().column(features1).build();
// Table y1 = new Table.TestBuilder().column(label1).build();
//
// ColumnVector t = X1.getColumn(0);
// ColumnView cv = t.getChildColumnView(0);
// //
// System.out.println("----");
//
// Float[][] features2 = {
// {6.0f, 17.0f},
// {7.0f, 18.0f},
// };
// Float[] label2 = {0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
// Table X2 = new Table.TestBuilder().column(features2).build();
// Table y2 = new Table.TestBuilder().column(label2).build();
//
// List<ColumnBatch> tables = new LinkedList<>();
// tables.add(new CudfColumnBatch(X1, y1, null, null));
// tables.add(new CudfColumnBatch(X2, y2, null, null));
//
// try {
// DMatrix dmat = new QuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
// } catch (XGBoostError e) {
// throw new RuntimeException(e);
// }
//
// System.out.println("--------------");

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class XXXXXSuite extends AnyFunSuite with GpuTestSuite {

val out = model.transform(df)
out.printSchema()
out.show(150, false)
out.show(150)
// model.write.overwrite().save("/tmp/model/")
// val loadedModel = XGBoostClassificationModel.load("/tmp/model")
// println(loadedModel.getNumRound)
Expand Down

0 comments on commit f230660

Please sign in to comment.