Skip to content

Commit

Permalink
fixup! discojs/validation: argmax for multiclass preds
Browse files Browse the repository at this point in the history
  • Loading branch information
s314cy committed Nov 21, 2022
1 parent ee982e6 commit 8e3c573
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions discojs/discojs-core/src/validation/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ export class Validator {
this.size += xs.shape[0]

// Get labels from one hot encoding
const preds = List(oneHotPreds.reshape([-1, this.classes]).arraySync() as number[][])
.map(List)
.map((ps) => ps.indexOf(ps.max() ?? 0))
const labels = List(ys.reshape([-1, this.classes]).arraySync() as number[][])
.map(List)
.map((ps) => ps.indexOf(1))
const preds = List(oneHotPreds.reshape([-1, this.classes]).argMax(1).arraySync() as number[])
const labels = List(ys.reshape([-1, this.classes]).argMax(1).arraySync() as number[])

// Keep track of prediction results for the confusion matrix
this.preds = this.preds.push(...preds)
Expand Down

0 comments on commit 8e3c573

Please sign in to comment.