-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add option to train AI using Tasks and their bounding boxes #8310
Changes from 12 commits
293c98f
d45ea84
a271981
675d369
bc2b55b
1c44bf0
0a5a046
f63afb8
a329a16
8ab7b3e
b9e0946
52a80ef
c37fe7f
d526816
d53eb2a
aafd50b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -4,6 +4,7 @@ import { | |||||
getTracingForAnnotationType, | ||||||
runTraining, | ||||||
} from "admin/admin_rest_api"; | ||||||
import { getAnnotationsForTask } from "admin/api/tasks"; | ||||||
import { | ||||||
Alert, | ||||||
Button, | ||||||
|
@@ -55,10 +56,12 @@ const FormItem = Form.Item; | |||||
// only the APIAnnotations of the given annotations to train on are loaded from the backend. | ||||||
// Thus, the code needs to handle both HybridTracing | APIAnnotation where APIAnnotation is missing some information. | ||||||
// Therefore, volumeTracings with the matching volumeTracingMags are needed to get more details on each volume annotation layer and its magnifications. | ||||||
// The userBoundingBoxes are needed for checking for equal bounding box sizes. As training on fallback data is supported and an annotation is not required to have VolumeTracings, | ||||||
// As the userBoundingBoxes should have multiple sizes of the smallest one, a check with a warning should be included. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
// As training on fallback data is supported and an annotation is not required to have VolumeTracings, | ||||||
// it is necessary to save userBoundingBoxes separately and not load them from volumeTracings entries to support skeleton only annotations. | ||||||
// Moreover, in case an annotations is a task, its task bounding box should also be used for training. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
// Note that a copy of the userBoundingBoxes is included in each volume and skeleton tracing of an annotation. Thus, it doesn't matter from which the userBoundingBoxes are taken. | ||||||
export type AnnotationInfoForAIJob<GenericAnnotation> = { | ||||||
export type AnnotationInfoForAITrainingJob<GenericAnnotation> = { | ||||||
annotation: GenericAnnotation; | ||||||
dataset: APIDataset; | ||||||
volumeTracings: VolumeTracing[]; | ||||||
|
@@ -175,8 +178,8 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid | |||||
getMagsForSegmentationLayer: (annotationId: string, layerName: string) => MagInfo; | ||||||
onClose: () => void; | ||||||
ensureSavedState?: (() => Promise<void>) | null; | ||||||
annotationInfos: Array<AnnotationInfoForAIJob<GenericAnnotation>>; | ||||||
onAddAnnotationsInfos?: (newItems: Array<AnnotationInfoForAIJob<APIAnnotation>>) => void; | ||||||
annotationInfos: Array<AnnotationInfoForAITrainingJob<GenericAnnotation>>; | ||||||
onAddAnnotationsInfos?: (newItems: Array<AnnotationInfoForAITrainingJob<APIAnnotation>>) => void; | ||||||
}) { | ||||||
const [form] = Form.useForm(); | ||||||
|
||||||
|
@@ -205,6 +208,10 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid | |||||
const magInfoForLayer: Array<MagInfo> = Form.useWatch(() => { | ||||||
return watcherFunctionRef.current(); | ||||||
}, form); | ||||||
const trainingAnnotationsInfo = Form.useWatch("trainingAnnotations", form) as Array<{ | ||||||
annotationId: string; | ||||||
mag: Vector3; | ||||||
}>; | ||||||
|
||||||
const [useCustomWorkflow, setUseCustomWorkflow] = React.useState(false); | ||||||
|
||||||
|
@@ -220,7 +227,6 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid | |||||
annotationId, | ||||||
groundTruthLayerName, | ||||||
).getMagList(); | ||||||
console.log("getintersectingmaglist", dataLayerMags, groundTruthLayerMags); | ||||||
|
||||||
return groundTruthLayerMags?.filter((groundTruthMag) => | ||||||
dataLayerMags?.find((mag) => V3.equals(mag, groundTruthMag)), | ||||||
|
@@ -275,12 +281,16 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid | |||||
modelCategory: AiModelCategory.EM_NEURONS, | ||||||
}; | ||||||
|
||||||
const userBoundingBoxes = annotationInfos.flatMap(({ userBoundingBoxes, annotation }) => | ||||||
userBoundingBoxes.map((box) => ({ | ||||||
const userBoundingBoxes = annotationInfos.flatMap(({ userBoundingBoxes, annotation }) => { | ||||||
const annotationId = "id" in annotation ? annotation.id : annotation.annotationId; | ||||||
return userBoundingBoxes.map((box) => ({ | ||||||
...box, | ||||||
annotationId: "id" in annotation ? annotation.id : annotation.annotationId, | ||||||
})), | ||||||
); | ||||||
annotationId: annotationId, | ||||||
trainingMag: trainingAnnotationsInfo?.find( | ||||||
(formInfo) => formInfo.annotationId === annotationId, | ||||||
)?.mag, | ||||||
})); | ||||||
}); | ||||||
|
||||||
const bboxesVoxelCount = _.sum( | ||||||
(userBoundingBoxes || []).map((bbox) => new BoundingBox(bbox.boundingBox).getVolume()), | ||||||
|
@@ -495,7 +505,7 @@ export function CollapsibleWorkflowYamlEditor({ | |||||
} | ||||||
|
||||||
function checkAnnotationsForErrorsAndWarnings<T extends HybridTracing | APIAnnotation>( | ||||||
annotationsWithDatasets: Array<AnnotationInfoForAIJob<T>>, | ||||||
annotationsWithDatasets: Array<AnnotationInfoForAITrainingJob<T>>, | ||||||
): { | ||||||
hasAnnotationErrors: boolean; | ||||||
errors: string[]; | ||||||
|
@@ -525,8 +535,12 @@ function checkAnnotationsForErrorsAndWarnings<T extends HybridTracing | APIAnnot | |||||
return { hasAnnotationErrors: false, errors: [] }; | ||||||
} | ||||||
|
||||||
const MIN_BBOX_EXTENT_IN_EACH_DIM = 32; | ||||||
function checkBoundingBoxesForErrorsAndWarnings( | ||||||
userBoundingBoxes: (UserBoundingBox & { annotationId: string })[], | ||||||
userBoundingBoxes: (UserBoundingBox & { | ||||||
annotationId: string; | ||||||
trainingMag: Vector3 | undefined; | ||||||
})[], | ||||||
): { | ||||||
hasBBoxErrors: boolean; | ||||||
hasBBoxWarnings: boolean; | ||||||
|
@@ -543,22 +557,59 @@ function checkBoundingBoxesForErrorsAndWarnings( | |||||
} | ||||||
// Find smallest bounding box dimensions | ||||||
const minDimensions = userBoundingBoxes.reduce( | ||||||
(min, { boundingBox: box }) => ({ | ||||||
x: Math.min(min.x, box.max[0] - box.min[0]), | ||||||
y: Math.min(min.y, box.max[1] - box.min[1]), | ||||||
z: Math.min(min.z, box.max[2] - box.min[2]), | ||||||
}), | ||||||
{ x: Number.POSITIVE_INFINITY, y: Number.POSITIVE_INFINITY, z: Number.POSITIVE_INFINITY }, | ||||||
(min, { boundingBox: box, trainingMag }) => { | ||||||
let bbox = new BoundingBox(box); | ||||||
if (trainingMag) { | ||||||
bbox = bbox.alignWithMag(trainingMag, "shrink"); | ||||||
} | ||||||
const size = bbox.getSize(); | ||||||
return { | ||||||
x: Math.min(min.x, size[0]), | ||||||
y: Math.min(min.y, size[1]), | ||||||
z: Math.min(min.z, size[2]), | ||||||
}; | ||||||
}, | ||||||
{ | ||||||
x: Number.POSITIVE_INFINITY, | ||||||
y: Number.POSITIVE_INFINITY, | ||||||
z: Number.POSITIVE_INFINITY, | ||||||
}, | ||||||
); | ||||||
|
||||||
// Validate minimum size and multiple requirements | ||||||
type BoundingBoxWithAnnotationId = { boundingBox: Vector6; name: string; annotationId: string }; | ||||||
type BoundingBoxWithAnnotationId = { | ||||||
boundingBox: Vector6; | ||||||
name: string; | ||||||
annotationId: string; | ||||||
}; | ||||||
const tooSmallBoxes: BoundingBoxWithAnnotationId[] = []; | ||||||
const nonMultipleBoxes: BoundingBoxWithAnnotationId[] = []; | ||||||
userBoundingBoxes.forEach(({ boundingBox: box, name, annotationId }) => { | ||||||
const arrayBox = computeArrayFromBoundingBox(box); | ||||||
const notMagAlignedBoundingBoxes: (BoundingBoxWithAnnotationId & { | ||||||
alignedBoundingBox: Vector6; | ||||||
})[] = []; | ||||||
userBoundingBoxes.forEach(({ boundingBox: box, name, annotationId, trainingMag }) => { | ||||||
const boundingBox = new BoundingBox(box); | ||||||
let arrayBox = computeArrayFromBoundingBox(box); | ||||||
if (trainingMag) { | ||||||
const alignedBoundingBox = boundingBox.alignWithMag(trainingMag, "shrink"); | ||||||
if (!alignedBoundingBox.equals(boundingBox)) { | ||||||
const alignedArrayBox = computeArrayFromBoundingBox(alignedBoundingBox); | ||||||
notMagAlignedBoundingBoxes.push({ | ||||||
boundingBox: arrayBox, | ||||||
name, | ||||||
annotationId, | ||||||
alignedBoundingBox: alignedArrayBox, | ||||||
}); | ||||||
// Update the arrayBox as the aligned version of the bounding box will be used for training. | ||||||
arrayBox = alignedArrayBox; | ||||||
} | ||||||
} | ||||||
const [_x, _y, _z, width, height, depth] = arrayBox; | ||||||
if (width < 10 || height < 10 || depth < 10) { | ||||||
if ( | ||||||
width < MIN_BBOX_EXTENT_IN_EACH_DIM || | ||||||
height < MIN_BBOX_EXTENT_IN_EACH_DIM || | ||||||
depth < MIN_BBOX_EXTENT_IN_EACH_DIM | ||||||
) { | ||||||
tooSmallBoxes.push({ boundingBox: arrayBox, name, annotationId }); | ||||||
} | ||||||
|
||||||
|
@@ -571,14 +622,25 @@ function checkBoundingBoxesForErrorsAndWarnings( | |||||
} | ||||||
}); | ||||||
|
||||||
if (notMagAlignedBoundingBoxes.length > 0) { | ||||||
hasBBoxWarnings = true; | ||||||
const notMagAlignedBoundingBoxesStrings = notMagAlignedBoundingBoxes.map( | ||||||
({ boundingBox, name, annotationId, alignedBoundingBox }) => | ||||||
`'${name}' of annotation ${annotationId}: ${boundingBox.join(", ")} -> ${alignedBoundingBox.join(", ")}`, | ||||||
); | ||||||
warnings.push( | ||||||
`The following bounding boxes are not aligned with the selected magnification. They will be automatically shrunk to be aligned with the magnification:\n${notMagAlignedBoundingBoxesStrings.join("\n")}`, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be helpful to include the actual magnification here, for example at the end of the first sentence. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, the problem is that one can select different mags for different annotations. Therefore, I added the mag to the end of each bbox entry in the warning message. Here is what I currently have: And this is what I mean with an improved hierarchy structure: The following bounding boxes are not aligned with the selected magnification. They will be automatically shrunk to be aligned with the magnification:
- Annotation 677e8b3f5f0100a80b1dc6ae
- 'Bounding box 1' (3584, 3584, 1024, 64, 64, 64) will be 1792, 1792, 1024, 32, 32, 64 in mag 2, 2, 1
- 'Bounding box 2' (3648, 3648, 1088, 64, 64, 64) will be 1824, 1824, 1088, 32, 32, 64 in mag 2, 2, 1
- 'Bounding box 3' (3712, 3712, 1152, 64, 64, 64) will be 1856, 1856, 1152, 32, 32, 64 in mag 2, 2, 1 What do you think about this tree structured text? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like that a lot 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
); | ||||||
} | ||||||
|
||||||
const boxWithIdToString = ({ boundingBox, name, annotationId }: BoundingBoxWithAnnotationId) => | ||||||
`'${name}' of annotation ${annotationId}: ${boundingBox.join(", ")}`; | ||||||
|
||||||
if (tooSmallBoxes.length > 0) { | ||||||
hasBBoxWarnings = true; | ||||||
const tooSmallBoxesStrings = tooSmallBoxes.map(boxWithIdToString); | ||||||
warnings.push( | ||||||
`The following bounding boxes are not at least 10 Vx in each dimension which is suboptimal for the training:\n${tooSmallBoxesStrings.join("\n")}`, | ||||||
`The following bounding boxes are not at least ${MIN_BBOX_EXTENT_IN_EACH_DIM} Vx in each dimension which is suboptimal for the training:\n${tooSmallBoxesStrings.join("\n")}`, | ||||||
); | ||||||
} | ||||||
|
||||||
|
@@ -596,31 +658,44 @@ function checkBoundingBoxesForErrorsAndWarnings( | |||||
function AnnotationsCsvInput({ | ||||||
onAdd, | ||||||
}: { | ||||||
onAdd: (newItems: Array<AnnotationInfoForAIJob<APIAnnotation>>) => void; | ||||||
onAdd: (newItems: Array<AnnotationInfoForAITrainingJob<APIAnnotation>>) => void; | ||||||
}) { | ||||||
const [value, setValue] = useState(""); | ||||||
const onClickAdd = async () => { | ||||||
const newItems = []; | ||||||
const annotationIdsForTraining = []; | ||||||
const unfinishedTasks = []; | ||||||
|
||||||
const lines = value | ||||||
.split("\n") | ||||||
.map((line) => line.trim()) | ||||||
.filter((line) => line !== ""); | ||||||
for (const annotationUrlOrId of lines) { | ||||||
if (annotationUrlOrId.includes("/")) { | ||||||
newItems.push({ | ||||||
annotationId: annotationUrlOrId.split("/").at(-1) as string, | ||||||
}); | ||||||
for (const taskOrAnnotationIdOrUrl of lines) { | ||||||
if (taskOrAnnotationIdOrUrl.includes("/")) { | ||||||
annotationIdsForTraining.push(taskOrAnnotationIdOrUrl.split("/").at(-1) as string); | ||||||
} else { | ||||||
newItems.push({ | ||||||
annotationId: annotationUrlOrId, | ||||||
}); | ||||||
let isTask = true; | ||||||
try { | ||||||
const annotations = await getAnnotationsForTask(taskOrAnnotationIdOrUrl, { | ||||||
showErrorToast: false, | ||||||
}); | ||||||
const finishedAnnotations = annotations.filter(({ state }) => state === "Finished"); | ||||||
if (annotations.length > 0) { | ||||||
annotationIdsForTraining.push(...finishedAnnotations.map(({ id }) => id)); | ||||||
} else { | ||||||
unfinishedTasks.push(taskOrAnnotationIdOrUrl); | ||||||
} | ||||||
} catch (_e) { | ||||||
isTask = false; | ||||||
} | ||||||
if (!isTask) { | ||||||
annotationIdsForTraining.push(taskOrAnnotationIdOrUrl); | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
const newAnnotationsWithDatasets = await Promise.all( | ||||||
newItems.map(async (item) => { | ||||||
const annotation = await getAnnotationInformation(item.annotationId); | ||||||
annotationIdsForTraining.map(async (annotationId) => { | ||||||
const annotation = await getAnnotationInformation(annotationId); | ||||||
const dataset = await getDataset(annotation.datasetId); | ||||||
|
||||||
const volumeServerTracings: ServerVolumeTracing[] = await Promise.all( | ||||||
|
@@ -651,6 +726,16 @@ function AnnotationsCsvInput({ | |||||
); | ||||||
} | ||||||
} | ||||||
if (annotation.task?.boundingBox) { | ||||||
const largestId = Math.max(...userBoundingBoxes.map(({ id }) => id)); | ||||||
userBoundingBoxes.push({ | ||||||
name: "Task Bounding Box", | ||||||
boundingBox: Utils.computeBoundingBoxFromBoundingBoxObject(annotation.task.boundingBox), | ||||||
color: [0, 0, 0], | ||||||
isVisible: true, | ||||||
id: largestId + 1, | ||||||
}); | ||||||
} | ||||||
|
||||||
return { | ||||||
annotation, | ||||||
|
@@ -663,14 +748,18 @@ function AnnotationsCsvInput({ | |||||
}; | ||||||
}), | ||||||
); | ||||||
|
||||||
if (unfinishedTasks.length > 0) { | ||||||
Toast.warning( | ||||||
`The following tasks have no finished annotations: ${unfinishedTasks.join(", ")}`, | ||||||
); | ||||||
} | ||||||
onAdd(newAnnotationsWithDatasets); | ||||||
}; | ||||||
return ( | ||||||
<div> | ||||||
<FormItem | ||||||
name="annotationCsv" | ||||||
label="Annotations CSV" | ||||||
label="Annotations or Tasks CSV" | ||||||
hasFeedback | ||||||
initialValue={value} | ||||||
rules={[ | ||||||
|
@@ -693,7 +782,7 @@ function AnnotationsCsvInput({ | |||||
> | ||||||
<TextArea | ||||||
className="input-monospace" | ||||||
placeholder="annotationUrlOrId" | ||||||
placeholder="taskOrAnnotationIdOrUrl" | ||||||
autoSize={{ | ||||||
minRows: 6, | ||||||
}} | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing this!