Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ui): create annotation specs #3003

Merged
merged 10 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import {Box} from '@material-ui/core';
import React, {FC, useCallback, useState} from 'react';
import {z} from 'zod';

import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery';
import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {sanitizeObjectId} from '../wfReactInterface/traceServerDirectClient';
import {projectIdFromParts} from '../wfReactInterface/tsDataModelHooks';
import {ScorerFormProps} from './ScorerForms';
import {ZSForm} from './ZodSchemaForm';

const AnnotationScorerFormSchema = z.object({
Name: z.string().min(1),
Description: z.string().min(1),
Type: z.discriminatedUnion('type', [
z.object({
type: z.literal('boolean'),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might want to capitalize Boolean (and other properties / literals) so that they render on the schema editor as such... depends on the display layer that you want.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I honestly think it looks fine, because its not a label, it just gets rendered in the selector dropdown.

}),
z.object({
type: z.literal('number'),
min: z.number().optional().describe('Optional minimum value'),
max: z.number().optional().describe('Optional maximum value'),
}),
z.object({
type: z.literal('string'),
max_length: z
.number()
.optional()
.describe('Optional maximum length of the string'),
}),
z.object({
type: z.literal('enum'),
enum: z.array(z.string()).describe('List of options to choose from'),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here :)

}),
]),
});

export const AnnotationScorerForm: FC<
ScorerFormProps<z.infer<typeof AnnotationScorerFormSchema>>
> = ({data, onDataChange}) => {
const [config, setConfig] = useState(data);
const [isValid, setIsValid] = useState(false);

const handleConfigChange = useCallback(
(newConfig: any) => {
setConfig(newConfig);
onDataChange(isValid, newConfig);
},
[isValid, onDataChange]
);

const handleValidChange = useCallback(
(newIsValid: boolean) => {
setIsValid(newIsValid);
onDataChange(newIsValid, config);
},
[config, onDataChange]
);

return (
<Box>
<ZSForm
configSchema={AnnotationScorerFormSchema}
config={config ?? {}}
setConfig={handleConfigChange}
onValidChange={handleValidChange}
/>
</Box>
);
};

export const onAnnotationScorerSave = async (
entity: string,
project: string,
data: z.infer<typeof AnnotationScorerFormSchema>,
client: TraceServerClient
) => {
let type = data.Type.type;
if (type === 'enum') {
type = 'string';
}
return createBaseObjectInstance(client, 'AnnotationSpec', {
obj: {
project_id: projectIdFromParts({entity, project}),
object_id: sanitizeObjectId(data.Name),
val: {
name: data.Name,
description: data.Description,
json_schema: {
...data.Type,
type,
},
},
},
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@ import React, {FC, ReactNode, useCallback, useEffect, useState} from 'react';

import {TraceServerClient} from '../wfReactInterface/traceServerClient';
import {useGetTraceServerClientContext} from '../wfReactInterface/traceServerClientContext';
import * as AnnotationScorerForm from './AnnotationScorerForm';
import {AutocompleteWithLabel} from './FormComponents';
import * as LLMJudgeScorerForm from './LLMJudgeScorerForm';
import {
AnnotationScorerForm,
ProgrammaticScorerForm,
ScorerFormProps,
} from './ScorerForms';
import {ProgrammaticScorerForm, ScorerFormProps} from './ScorerForms';

const HUMAN_ANNOTATION_LABEL = 'Human annotation';
export const HUMAN_ANNOTATION_VALUE = 'ANNOTATION';
Expand Down Expand Up @@ -41,11 +38,8 @@ export const scorerTypeRecord: Record<ScorerType, ScorerTypeConfig<any>> = {
label: HUMAN_ANNOTATION_LABEL,
value: HUMAN_ANNOTATION_VALUE,
icon: IconNames.UsersTeam,
Component: AnnotationScorerForm,
onSave: async (entity, project, data, client) => {
// Implementation for saving annotation scorer
console.log('TODO: save annotation scorer', data);
},
Component: AnnotationScorerForm.AnnotationScorerForm,
onSave: AnnotationScorerForm.onAnnotationScorerSave,
},
LLM_JUDGE: {
label: LLM_JUDGE_LABEL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@ export interface ScorerFormProps<T> {
onDataChange: (isValid: boolean, data?: T) => void;
}

export const AnnotationScorerForm: FC<ScorerFormProps<any>> = ({
data,
onDataChange,
}) => {
// Implementation for annotation scorer form
return <div>Annotation Scorer Form</div>;
};

export const ProgrammaticScorerForm: FC<ScorerFormProps<any>> = ({
data,
onDataChange,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
InputLabel,
Tooltip,
} from '@material-ui/core';
import {Delete, Help} from '@mui/icons-material';
import {Help} from '@mui/icons-material';
import {Button} from '@wandb/weave/components/Button';
import React, {useEffect, useMemo, useState} from 'react';
import {z} from 'zod';
Expand Down Expand Up @@ -156,12 +156,15 @@ const NestedForm: React.FC<{
config: Record<string, any>;
setConfig: (config: Record<string, any>) => void;
path: string[];
}> = ({keyName, fieldSchema, config, setConfig, path}) => {
hideLabel?: boolean;
}> = ({keyName, fieldSchema, config, setConfig, path, hideLabel}) => {
const currentPath = [...path, keyName];
const currentValue = getNestedValue(config, currentPath);

const unwrappedSchema = unwrapSchema(fieldSchema);

console.log(typeof fieldSchema, fieldSchema);

if (unwrappedSchema instanceof z.ZodDiscriminatedUnion) {
return (
<DiscriminatedUnionField
Expand All @@ -180,7 +183,7 @@ const NestedForm: React.FC<{
<FormControl
fullWidth
style={{marginBottom: GAP_BETWEEN_ITEMS_PX + 'px'}}>
<Label label={keyName} />
{!hideLabel && <Label label={keyName} />}
<Box ml={2}>
<ZSForm
configSchema={unwrappedSchema as z.ZodObject<any>}
Expand Down Expand Up @@ -286,7 +289,7 @@ const NestedForm: React.FC<{

return (
<TextFieldWithLabel
label={keyName}
label={!hideLabel ? keyName : undefined}
type={fieldType}
value={currentValue ?? ''}
onChange={value => updateConfig(currentPath, value, config, setConfig)}
Expand Down Expand Up @@ -317,6 +320,7 @@ const ArrayField: React.FC<{
);
const minItems = unwrappedSchema._def.minLength?.value ?? 0;
const elementSchema = unwrappedSchema.element;
const fieldDescription = getFieldDescription(fieldSchema);

// Ensure the minimum number of items is always present
React.useEffect(() => {
Expand All @@ -331,47 +335,61 @@ const ArrayField: React.FC<{

return (
<FormControl fullWidth style={{marginBottom: GAP_BETWEEN_ITEMS_PX + 'px'}}>
<Label label={keyName} />
<Box display="flex" alignItems="center" justifyContent="space-between">
<Label label={keyName} />
{fieldDescription && (
<DescriptionTooltip description={fieldDescription} />
)}
</Box>
{arrayValue.map((item, index) => (
<Box
key={index}
display="flex"
flexDirection="column"
alignItems="flex-start"
mb={2}
sx={{
borderBottom: '1px solid',
p: 2,
style={{
width: '100%',
gap: 4,
alignItems: 'center',
height: '35px',
marginBottom: '4px',
}}>
<Box flexGrow={1} width="100%">
<NestedForm
keyName={`${index}`}
fieldSchema={elementSchema}
config={{[`${index}`]: item}}
setConfig={newItemConfig => {
const newArray = [...arrayValue];
newArray[index] = newItemConfig[`${index}`];
updateConfig(targetPath, newArray, config, setConfig);
}}
path={[]}
/>
</Box>
<Box mt={1}>
<IconButton
onClick={() =>
removeArrayItem(targetPath, index, config, setConfig)
}
disabled={arrayValue.length <= minItems}>
<Delete />
</IconButton>
<Box flexGrow={1} width="100%" display="flex" alignItems="center">
<Box flexGrow={1}>
<NestedForm
keyName={`${index}`}
fieldSchema={elementSchema}
config={{[`${index}`]: item}}
setConfig={newItemConfig => {
const newArray = [...arrayValue];
newArray[index] = newItemConfig[`${index}`];
updateConfig(targetPath, newArray, config, setConfig);
}}
path={[]}
hideLabel
/>
</Box>
<Box mb={2} ml={1}>
<Button
size="small"
variant="ghost"
icon="delete"
tooltip="Remove this entry"
disabled={arrayValue.length <= minItems}
onClick={() =>
removeArrayItem(targetPath, index, config, setConfig)
}
/>
</Box>
</Box>
</Box>
))}
<Button
variant="secondary"
onClick={() =>
addArrayItem(targetPath, elementSchema, config, setConfig)
}>
Add Item
Add item
</Button>
</FormControl>
);
Expand Down Expand Up @@ -619,23 +637,7 @@ const updateConfig = (
}
current = current[targetPath[i]];
}

// Convert OrderedRecord to plain object if necessary
if (
value &&
typeof value === 'object' &&
'keys' in value &&
'values' in value
) {
const plainObject: Record<string, any> = {};
value.keys.forEach((key: string) => {
plainObject[key] = value.values[key];
});
current[targetPath[targetPath.length - 1]] = plainObject;
} else {
current[targetPath[targetPath.length - 1]] = value;
}

current[targetPath[targetPath.length - 1]] = value;
setConfig(newConfig);
};

Expand Down Expand Up @@ -698,23 +700,32 @@ const NumberField: React.FC<{
const max =
(unwrappedSchema._def.checks.find(check => check.kind === 'max') as any)
?.value ?? undefined;
const fieldDescription = getFieldDescription(fieldSchema);

return (
<TextFieldWithLabel
label={keyName}
type="number"
value={(value ?? '').toString()}
onChange={newValue => {
const finalValue = newValue === '' ? undefined : Number(newValue);
if (
finalValue !== undefined &&
(finalValue < min || finalValue > max)
) {
return;
}
updateConfig(targetPath, finalValue, config, setConfig);
}}
/>
<Box display="flex" alignContent="center" justifyContent="space-between">
<TextFieldWithLabel
label={keyName}
type="number"
value={(value ?? '').toString()}
style={{width: '100%'}}
onChange={newValue => {
const finalValue = newValue === '' ? undefined : Number(newValue);
if (
finalValue !== undefined &&
(finalValue < min || finalValue > max)
) {
return;
}
updateConfig(targetPath, finalValue, config, setConfig);
}}
/>
{fieldDescription && (
<Box display="flex" alignItems="center" sx={{marginTop: '14px'}}>
<DescriptionTooltip description={fieldDescription} />
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this styling seems odd to me

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah because the icon gets centered in the whole div which is tall because of the TextFieldWithLabel being multi-line. We could add the description in TextFieldWithLabel but that breaks the pattern.

</Box>
)}
</Box>
);
};

Expand Down
Loading