Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ui): Evaluations -- normalize radar data, show real values for bar charts #3199

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export const SummaryPlots: React.FC<{
state: EvaluationComparisonState;
setSelectedMetrics: (newModel: Record<string, boolean>) => void;
}> = ({state, setSelectedMetrics}) => {
const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics(state);
const {radarData, allMetricNames} = usePlotDataFromMetrics(state);
const {selectedMetrics} = state;

// Initialize selectedMetrics if null
Expand Down Expand Up @@ -237,11 +237,10 @@ const useFilteredData = (
return data;
}, [radarData, selectedMetrics]);

function getMetricValuesFromRadarData(radarData: RadarPlotData): {
function getMetricValuesMap(radarData: RadarPlotData): {
[metric: string]: number[];
} {
const metricValues: {[metric: string]: number[]} = {};
// Gather all values for each metric
Object.values(radarData).forEach(callData => {
Object.entries(callData.metrics).forEach(([metric, value]) => {
if (!metricValues[metric]) {
Expand All @@ -253,37 +252,54 @@ function getMetricValuesFromRadarData(radarData: RadarPlotData): {
return metricValues;
}

function getMetricMinsFromRadarData(radarData: RadarPlotData): {
[metric: string]: number;
/**
 * Rescales one metric's values so they can share a radar axis with other
 * metrics of very different magnitude.
 *
 * - Empty input yields an empty result with a neutral normalizer of 1.
 * - If every value is identical there is no spread to show, so each value is
 *   placed at the midpoint (0.5) and the normalizer is 1.
 * - If any value is negative, all values are first shifted up by `-min` so
 *   the smallest becomes 0.
 * - The (possibly shifted) values are then divided by the smallest power of
 *   two that is greater than or equal to the (possibly shifted) maximum, which
 *   places every result in the range [0, 1].
 *
 * @param values Raw values of a single metric, one per evaluation call.
 * @returns `normalizedValues` in the same order as `values`, and the
 *   power-of-two `normalizer` they were divided by.
 */
function normalizeMetricValues(values: number[]): {
  normalizedValues: number[];
  normalizer: number;
} {
  // Math.min()/Math.max() of nothing are +/-Infinity, which would turn the
  // normalizer into NaN further down; bail out early instead.
  if (values.length === 0) {
    return {normalizedValues: [], normalizer: 1};
  }

  const min = Math.min(...values);
  const max = Math.max(...values);

  if (min === max) {
    return {
      normalizedValues: values.map(() => 0.5),
      normalizer: 1,
    };
  }

  // Handle negative values by shifting so the minimum lands on 0.
  const shiftedValues = min < 0 ? values.map(v => v - min) : values;
  const maxValue = min < 0 ? max - min : max;

  // min !== max and the values are non-negative after shifting, so maxValue
  // is strictly positive here and log2 is well defined.
  const maxPower = Math.ceil(Math.log2(maxValue));
  const normalizer = Math.pow(2, maxPower);

  return {
    normalizedValues: shiftedValues.map(v => v / normalizer),
    normalizer,
  };
}

/**
 * Produces a copy of the radar data in which every metric has been rescaled
 * independently via `normalizeMetricValues`, so metrics with different
 * magnitudes can be drawn on the same radar plot.
 *
 * The input is not mutated: each call entry and its `metrics` object are
 * shallow-copied before any value is overwritten.
 *
 * @param radarDataOriginal Raw per-call metric values keyed by call id.
 * @returns A new object with the same calls and metric keys, holding
 *   normalized metric values.
 */
function normalizeDataForRadarPlot(
  radarDataOriginal: RadarPlotData
): RadarPlotData {
  // Copy one level deeper than the top-level object so writes to
  // `metrics[...]` below never leak into the caller's data.
  const radarData: RadarPlotData = Object.fromEntries(
    Object.entries(radarDataOriginal).map(([callId, callData]) => [
      callId,
      {...callData, metrics: {...callData.metrics}},
    ])
  );

  const metricValues = getMetricValuesMap(radarData);

  // Normalize each metric independently
  Object.entries(metricValues).forEach(([metric, values]) => {
    const {normalizedValues} = normalizeMetricValues(values);

    // NOTE(review): assumes getMetricValuesMap collects one value per call
    // that reports the metric, in Object.values(radarData) order - confirm
    // against the helper. Advancing the cursor only for calls that actually
    // have the metric keeps values aligned when some calls lack it, and avoids
    // adding an `undefined` entry to those calls.
    let cursor = 0;
    Object.values(radarData).forEach(callData => {
      if (Object.prototype.hasOwnProperty.call(callData.metrics, metric)) {
        callData.metrics[metric] = normalizedValues[cursor];
        cursor += 1;
      }
    });
  });

  return radarData;
}

const useBarPlotData = (filteredData: RadarPlotData) =>
Expand Down Expand Up @@ -317,7 +333,9 @@ const useBarPlotData = (filteredData: RadarPlotData) =>
type: 'bar',
y: metricBin.values,
x: metricBin.callIds,
text: metricBin.values.map(value => value.toFixed(3)),
text: metricBin.values.map(value =>
Number.isInteger(value) ? value.toString() : value.toFixed(3)
),
textposition: 'outside',
textfont: {size: 14, color: 'black'},
name: metric,
Expand Down Expand Up @@ -408,16 +426,7 @@ const usePaginatedPlots = (
return {plotsToShow, totalPlots, startIndex, endIndex, totalPages};
};

/**
 * Scales values by a power of two so the largest defined value is at most 1.
 *
 * The divisor is the smallest power of two that is greater than or equal to
 * the largest defined value. Missing (`undefined`) entries, and entries equal
 * to 0, map to 0.
 *
 * When there is no positive, finite maximum (empty input, all entries
 * `undefined`, or every defined value <= 0) no meaningful power of two exists,
 * so values are passed through with a divisor of 1 rather than turning into
 * NaN or +/-Infinity.
 *
 * @param values Metric values, one per evaluation call; may contain gaps.
 * @returns Scaled values in the same order as `values`.
 */
function normalizeValues(values: Array<number | undefined>): number[] {
  const definedValues = values.filter(
    (v): v is number => v !== undefined
  );
  const maxVal = Math.max(...definedValues);

  // Math.log2 is NaN for negatives and -Infinity for 0, and Math.max() of
  // nothing is -Infinity; only derive a power of two from a usable maximum.
  const divisor =
    Number.isFinite(maxVal) && maxVal > 0
      ? 2 ** Math.ceil(Math.log2(maxVal))
      : 1;

  return values.map(val => (val ? val / divisor : 0));
}

const useNormalizedPlotDataFromMetrics = (
const usePlotDataFromMetrics = (
state: EvaluationComparisonState
): {radarData: RadarPlotData; allMetricNames: Set<string>} => {
const compositeMetrics = useMemo(() => {
Expand All @@ -428,7 +437,7 @@ const useNormalizedPlotDataFromMetrics = (
}, [state]);

return useMemo(() => {
const normalizedMetrics = Object.values(compositeMetrics)
const metrics = Object.values(compositeMetrics)
.map(scoreGroup => Object.values(scoreGroup.metrics))
.flat()
.map(metric => {
Expand All @@ -449,11 +458,8 @@ const useNormalizedPlotDataFromMetrics = (
return val;
}
});
const normalizedValues = normalizeValues(values);
const evalScores: {[evalCallId: string]: number | undefined} =
Object.fromEntries(
callIds.map((key, i) => [key, normalizedValues[i]])
);
Object.fromEntries(callIds.map((key, i) => [key, values[i]]));

const metricLabel = flattenedDimensionPath(
Object.values(metric.scorerRefs)[0].metric
Expand All @@ -472,7 +478,7 @@ const useNormalizedPlotDataFromMetrics = (
name: evalCall.name,
color: evalCall.color,
metrics: Object.fromEntries(
normalizedMetrics.map(metric => {
metrics.map(metric => {
return [
metric.metricLabel,
metric.evalScores[evalCall.callId] ?? 0,
Expand All @@ -483,7 +489,7 @@ const useNormalizedPlotDataFromMetrics = (
];
})
);
const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel));
const allMetricNames = new Set(metrics.map(m => m.metricLabel));
return {radarData, allMetricNames};
}, [callIds, compositeMetrics, state.data.evaluationCalls]);
};
Loading