Skip to content

Commit

Permalink
feat(ui): Evaluation-Scoped Leaderboard Tab (#2692)
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney authored Oct 16, 2024
1 parent b97d64d commit 1da0173
Show file tree
Hide file tree
Showing 14 changed files with 1,181 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ export const CompareEvaluationsPage: React.FC<
> = props => {
return (
<SimplePageLayout
title="Compare Evaluations"
title={
props.evaluationCallIds.length === 1
? 'Evaluation Results'
: 'Compare Evaluations'
}
hideTabsIfSingle
tabs={[
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {Box} from '@material-ui/core';
import React, {useMemo, useState} from 'react';
import {useDeepMemo} from '@wandb/weave/hookUtils';
import React, {useEffect, useMemo, useState} from 'react';

import {WeaveLoader} from '../../../../../../common/components/WeaveLoader';
import {LinearProgress} from '../../../../../LinearProgress';
Expand Down Expand Up @@ -64,9 +65,14 @@ export const CompareEvaluationsProvider: React.FC<{
selectedInputDigest,
children,
}) => {
const initialEvaluationCallIdsMemo = useDeepMemo(initialEvaluationCallIds);
const [evaluationCallIds, setEvaluationCallIds] = useState(
initialEvaluationCallIds
initialEvaluationCallIdsMemo
);
useEffect(() => {
setEvaluationCallIds(initialEvaluationCallIdsMemo);
}, [initialEvaluationCallIdsMemo]);

const initialState = useEvaluationComparisonState(
entity,
project,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import {Box} from '@mui/material';
import React from 'react';

import {LeaderboardGrid} from '../views/Leaderboard/LeaderboardGrid';
import {useLeaderboardData} from '../views/Leaderboard/query/hookAdapters';

type EvaluationLeaderboardTabProps = {
entity: string;
project: string;
evaluationObjectName: string;
evaluationObjectVersion: string;
};

export const EvaluationLeaderboardTab: React.FC<
EvaluationLeaderboardTabProps
> = props => {
const {entity, project, evaluationObjectName, evaluationObjectVersion} =
props;

const {loading, data} = useLeaderboardData(entity, project, {
sourceEvaluations: [
{
name: evaluationObjectName,
version: evaluationObjectVersion,
},
],
});

return (
<Box
display="flex"
flexDirection="row"
height="100%"
flexGrow={1}
overflow="hidden">
<LeaderboardGrid
entity={entity}
project={project}
loading={loading}
data={data}
/>
</Box>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
SimpleKeyValueTable,
SimplePageLayoutWithHeader,
} from './common/SimplePageLayout';
import {EvaluationLeaderboardTab} from './LeaderboardTab';
import {TabUseDataset} from './TabUseDataset';
import {TabUseModel} from './TabUseModel';
import {TabUseObject} from './TabUseObject';
Expand All @@ -45,6 +46,7 @@ type ObjectIconProps = {
const OBJECT_ICONS: Record<KnownBaseObjectClassType, IconName> = {
Model: 'model',
Dataset: 'table',
Evaluation: 'benchmark-square',
};
const ObjectIcon = ({baseObjectClass}: ObjectIconProps) => {
if (baseObjectClass in OBJECT_ICONS) {
Expand Down Expand Up @@ -121,6 +123,9 @@ const ObjectVersionPageInner: React.FC<{
if (objectVersion.baseObjectClass === 'Model') {
return 'Model';
}
if (objectVersion.baseObjectClass === 'Evaluation') {
return 'Evaluation';
}
return null;
}, [objectVersion.baseObjectClass]);
const refUri = objectVersionKeyToRefUri(objectVersion);
Expand Down Expand Up @@ -181,6 +186,13 @@ const ObjectVersionPageInner: React.FC<{
}, [viewerData]);

const isDataset = baseObjectClass === 'Dataset' && refExtra == null;
const isEvaluation = baseObjectClass === 'Evaluation' && refExtra == null;
const evalHasCalls = (consumingCalls.result?.length ?? 0) > 0;
const evalHasCallsLoading = consumingCalls.loading;

if (isEvaluation && evalHasCallsLoading) {
return <CenteredAnimatedLoader />;
}

return (
<SimplePageLayoutWithHeader
Expand Down Expand Up @@ -264,6 +276,21 @@ const ObjectVersionPageInner: React.FC<{
// },
// ]}
tabs={[
...(isEvaluation && evalHasCalls
? [
{
label: 'Leaderboard',
content: (
<EvaluationLeaderboardTab
entity={entityName}
project={projectName}
evaluationObjectName={objectName}
evaluationObjectVersion={objectVersion.versionHash}
/>
),
},
]
: []),
{
label: isDataset ? 'Rows' : 'Values',
content: (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ const ObjectVersionsTable: React.FC<{
},
renderCell: cellParams => {
const category = cellParams.value;
if (category === 'Model' || category === 'Dataset') {
if (
category === 'Model' ||
category === 'Dataset' ||
category === 'Evaluation'
) {
return <TypeVersionCategoryChip baseObjectClass={category} />;
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ export const SimplePageLayout: FC<{
backgroundColor: 'white',
pb: 0,
height: 44,

width: '100%',
borderBottom: '1px solid #e0e0e0',
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
Expand All @@ -78,7 +79,7 @@ export const SimplePageLayout: FC<{
<Box
sx={{
height: 44,
flex: '0 0 44px',
flex: '1 0 44px',
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
Expand Down Expand Up @@ -198,6 +199,7 @@ export const SimplePageLayoutWithHeader: FC<{
<Box
sx={{
height: 44,
width: '100%',
flex: '0 0 44px',
display: 'flex',
flexDirection: 'row',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {KnownBaseObjectClassType} from '../wfReactInterface/wfDataModelHooksInte
const colorMap: Record<KnownBaseObjectClassType, TagColorName> = {
Model: 'blue',
Dataset: 'green',
Evaluation: 'cactus',
};

export const TypeVersionCategoryChip: React.FC<{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ export const OP_CATEGORIES = [
'evaluate',
'tune',
] as const;
export const KNOWN_BASE_OBJECT_CLASSES = ['Model', 'Dataset'] as const;
export const KNOWN_BASE_OBJECT_CLASSES = [
'Model',
'Dataset',
'Evaluation',
] as const;
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ interface TraceObjectsFilter {
export type TraceObjQueryReq = {
project_id: string;
filter?: TraceObjectsFilter;
limit?: number;
offset?: number;
sort_by?: SortBy[];
metadata_only?: boolean;
};

export interface TraceObjSchema {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@

import _ from 'lodash';
import {sum} from 'lodash';
import {useEffect, useMemo, useState} from 'react';
import {useEffect, useMemo, useRef, useState} from 'react';

import {WB_RUN_COLORS} from '../../../../../../common/css/color.styles';
import {useDeepMemo} from '../../../../../../hookUtils';
Expand Down Expand Up @@ -107,6 +107,8 @@ export const useEvaluationComparisonData = (
const getTraceServerClient = useGetTraceServerClientContext();
const [data, setData] = useState<EvaluationComparisonData | null>(null);
const evaluationCallIdsMemo = useDeepMemo(evaluationCallIds);
const evaluationCallIdsRef = useRef(evaluationCallIdsMemo);

useEffect(() => {
setData(null);
let mounted = true;
Expand All @@ -117,6 +119,7 @@ export const useEvaluationComparisonData = (
evaluationCallIdsMemo
).then(dataRes => {
if (mounted) {
evaluationCallIdsRef.current = evaluationCallIdsMemo;
setData(dataRes);
}
});
Expand All @@ -126,11 +129,14 @@ export const useEvaluationComparisonData = (
}, [entity, evaluationCallIdsMemo, project, getTraceServerClient]);

return useMemo(() => {
if (data == null) {
if (
data == null ||
evaluationCallIdsRef.current !== evaluationCallIdsMemo
) {
return {loading: true, result: null};
}
return {loading: false, result: data};
}, [data]);
}, [data, evaluationCallIdsMemo]);
};

/**
Expand Down
Loading

0 comments on commit 1da0173

Please sign in to comment.