Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 26, 2024
1 parent 591d204 commit d6d1d78
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ public String toJson() {
ObjectMapper mapper = new ObjectMapper();
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
try {
return mapper.writeValueAsString(this);
List<CudfColumn> objects = new ArrayList<>(1);
objects.add(this);
return mapper.writeValueAsString(objects);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;
Expand All @@ -41,9 +43,18 @@ public class CudfColumnBatch extends ColumnBatch {
private final Table baseMarginTable;

private List<CudfColumn> features;
private CudfColumn label;
private CudfColumn weight;
private CudfColumn baseMargin;
private List<CudfColumn> label;
private List<CudfColumn> weight;
private List<CudfColumn> baseMargin;

private List<CudfColumn> initializeCudfColumns(Table table) {
assert table != null && table.getNumberOfColumns() > 0;

return IntStream.range(0, table.getNumberOfColumns())
.mapToObj(table::getColumn)
.map(CudfColumn::from)
.collect(Collectors.toList());
}

public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
Table baseMarginTable) {
Expand All @@ -52,42 +63,35 @@ public CudfColumnBatch(Table featureTable, Table labelTable, Table weightTable,
this.weightTable = weightTable;
this.baseMarginTable = baseMarginTable;

features = new ArrayList<>();
for (int index = 0; index < featureTable.getNumberOfColumns(); index++) {
ColumnVector cv = featureTable.getColumn(index);
features.add(CudfColumn.from(cv));
}

features = initializeCudfColumns(featureTable);
if (labelTable != null) {
assert labelTable.getNumberOfColumns() == 1;
label = CudfColumn.from(labelTable.getColumn(0));
label = initializeCudfColumns(labelTable);
}

if (weightTable != null) {
assert weightTable.getNumberOfColumns() == 1;
weight = CudfColumn.from(weightTable.getColumn(0));
weight = initializeCudfColumns(weightTable);
}

// TODO baseMargin should be an array for multi classification
if (baseMarginTable != null) {
assert baseMarginTable.getNumberOfColumns() == 1;
baseMargin = CudfColumn.from(baseMarginTable.getColumn(0));
baseMargin = initializeCudfColumns(baseMarginTable);
}
}

public List<CudfColumn> getFeatures() {
return features;
}

public CudfColumn getLabel() {
public List<CudfColumn> getLabel() {
return label;
}

public CudfColumn getWeight() {
public List<CudfColumn> getWeight() {
return weight;
}

public CudfColumn getBaseMargin() {
public List<CudfColumn> getBaseMargin() {
return baseMargin;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ public void testBooster() throws XGBoostError {
try (Table y = new Table(labels);) {

CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
System.out.println(batch.toJson());
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
System.out.println(labelColumn.toJson());

//set watchList
HashMap<String, DMatrix> watches = new HashMap<>();
Expand Down

0 comments on commit d6d1d78

Please sign in to comment.