Skip to content

Commit

Permalink
fix test (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 26, 2024
1 parent f1df971 commit 291f2cd
Showing 1 changed file with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,12 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
}
}

// TODO .... why rowNum is 5, and non missing = 9
test("build RDD Watches") {
withGpuSparkSession() { spark =>
import spark.implicits._

// dataPoint -> (missing, rowNum, nonMissing)
Map(0.0f -> (0.0f, 4, 8), Float.NaN -> (0.0f, 5, 10)).foreach {
Map(0.0f -> (0.0f, 5, 9), Float.NaN -> (0.0f, 5, 9)).foreach {
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
val df = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
Expand Down Expand Up @@ -174,13 +173,12 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
// assert(rowNumber.sum === expectedRowNum)
assert(rowNumber.sum === expectedRowNum)
assert(nonMissing.sum === expectedNonMissing)
}
}
}

// TODO .... why rowNum is 5, and non missing = 9
test("build RDD Watches with Eval") {
withGpuSparkSession() { spark =>
import spark.implicits._
Expand All @@ -191,7 +189,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
).toDF("c1", "c2", "weight", "margin", "label", "other")

// dataPoint -> (missing, rowNum, nonMissing)
Map(0.0f -> (0.0f, 4, 8), Float.NaN -> (0.0f, 5, 10)).foreach {
Map(0.0f -> (0.0f, 5, 9), Float.NaN -> (0.0f, 5, 9)).foreach {
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
val eval = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
Expand Down Expand Up @@ -240,7 +238,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
// assert(rowNumber.sum === expectedRowNum)
assert(rowNumber.sum === expectedRowNum)
assert(nonMissing.sum === expectedNonMissing)
}
}
Expand Down

0 comments on commit 291f2cd

Please sign in to comment.