Skip to content

Commit

Permalink
fixup! discojs/validation: argmax for multiclass preds
Browse files Browse the repository at this point in the history
  • Loading branch information
s314cy committed Nov 21, 2022
1 parent ee982e6 commit 2280e8b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
8 changes: 2 additions & 6 deletions discojs/discojs-core/src/validation/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ export class Validator {
this.size += xs.shape[0]

// Get labels from one hot encoding
const preds = List(oneHotPreds.reshape([-1, this.classes]).arraySync() as number[][])
.map(List)
.map((ps) => ps.indexOf(ps.max() ?? 0))
const labels = List(ys.reshape([-1, this.classes]).arraySync() as number[][])
.map(List)
.map((ps) => ps.indexOf(1))
const preds = List(oneHotPreds.reshape([-1, this.classes]).argMax(1).arraySync() as number[])
const labels = List(ys.reshape([-1, this.classes]).argMax(1).arraySync() as number[])

// Keep track of prediction results for the confusion matrix
this.preds = this.preds.push(...preds)
Expand Down
32 changes: 28 additions & 4 deletions web-client/src/components/validation/Validator.vue
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Test
</template>
</ButtonCard>

<!-- display the chart -->
<div class="p-4 mx-auto lg:w-1/2 h-full bg-white rounded-md">
<!-- header -->
Expand All @@ -42,6 +43,28 @@
:series="[{ data: accuracyData }]"
/>
</div>

<IconCard v-if="validator !== undefined">
<template #title>
Validation Metrics
</template>
<template #content>
<!-- html table here -->
<div
v-for="(predictions, i) in validator.confusionMatrix"
:key="i"
>
<div
v-for="(preds, j) in predictions"
:key="j"
>
{{ i }}, {{ j }}, {{ preds }},
{{ task.trainingInformation.LABEL_LIST[i] }},
{{ task.trainingInformation.LABEL_LIST[j] }}
</div>
</div>
</template>
</IconCard>
</div>
</template>
<script lang="ts" setup>
Expand All @@ -55,6 +78,7 @@ import { useValidationStore } from '@/store/validation'
import { chartOptions } from '@/charts'
import { useToaster } from '@/composables/toaster'
import ButtonCard from '@/components/containers/ButtonCard.vue'
import IconCard from '@/components/containers/IconCard.vue'
const { useIndexedDB } = storeToRefs(useMemoryStore())
const toaster = useToaster()
Expand All @@ -66,19 +90,19 @@ interface Props {
}
const props = defineProps<Props>()
const validator = ref<Validator>(undefined)
const validator = ref<Validator | undefined>(undefined)
const memory = computed<Memory>(() => useIndexedDB ? new browser.IndexedDB() : new EmptyMemory())
const accuracyData = computed<number[]>(() => {
const r = validator.value?.accuracyData()
const r = validator.value?.accuracyData
return r !== undefined ? r.toArray() : [0]
})
const currentAccuracy = computed<string>(() => {
const r = validator.value?.accuracy()
const r = validator.value?.accuracy
return r !== undefined ? (r * 100).toFixed(2) : '0'
})
const visitedSamples = computed<number>(() => {
const r = validator.value?.visitedSamples()
const r = validator.value?.visitedSamples
return r !== undefined ? r : 0
})
Expand Down

0 comments on commit 2280e8b

Please sign in to comment.