Add option to train AI using Tasks and their bounding boxes #8310

Merged
merged 16 commits on Jan 16, 2025
Changes from 9 commits
1 change: 1 addition & 0 deletions CHANGELOG.unreleased.md
@@ -22,6 +22,7 @@ For upgrade instructions, please check the [migration guide](MIGRATIONS.released
- Renamed "resolution" to "magnification" in more places within the codebase, including local variables. [#8168](https://github.com/scalableminds/webknossos/pull/8168)
- Layer names are now allowed to contain `$` as special characters. [#8241](https://github.com/scalableminds/webknossos/pull/8241)
- Datasets can now be renamed and can have duplicate names. [#8075](https://github.com/scalableminds/webknossos/pull/8075)
- Starting an AI training job using multiple annotations now supports inputting task IDs and considers their task bounding boxes. [#8310](https://github.com/scalableminds/webknossos/pull/8310)
- Improved the default colors for skeleton trees. [#8228](https://github.com/scalableminds/webknossos/pull/8228)
- Allowed to train an AI model using differently sized bounding boxes. We recommend all bounding boxes to have equal dimensions or to have dimensions which are multiples of the smallest bounding box. [#8222](https://github.com/scalableminds/webknossos/pull/8222)
- Within the bounding box tool, the cursor updates immediately after pressing `ctrl`, indicating that a bounding box can be moved instead of resized. [#8253](https://github.com/scalableminds/webknossos/pull/8253)
2 changes: 1 addition & 1 deletion app/controllers/AiModelController.scala
@@ -153,7 +153,7 @@ class AiModelController @Inject()(
)
existingAiModelsCount <- aiModelDAO.countByNameAndOrganization(request.body.name,
request.identity._organization)
_ <- bool2Fox(existingAiModelsCount == 0) ?~> "model.nameInUse"
_ <- bool2Fox(existingAiModelsCount == 0) ?~> "aiModel.nameInUse"
Member: Thanks for fixing this!

newTrainingJob <- jobService
.submitJob(jobCommand, commandArgs, request.identity, dataStore.name) ?~> "job.couldNotRunTrainModel"
newAiModel = AiModel(
1 change: 1 addition & 0 deletions conf/messages
@@ -373,5 +373,6 @@ shortLink.notFound=No shortlink with this key could be found
aiModel.delete.referencedByInferences=Cannot delete AI models that are referenced by existing inferences.
aiModel.notFound=Could not find requested AI model.
aiModel.training.zeroAnnotations=Need at least one training annotation for model training.
aiModel.nameInUse=The AI model name is already in use. Please choose a different name.

aiInference.notFound=Could not find requested AI inference.
7 changes: 5 additions & 2 deletions frontend/javascripts/admin/api/tasks.ts
@@ -29,8 +29,11 @@ export async function requestTask(): Promise<APIAnnotationWithTask> {
const { messages: _messages, ...task } = taskWithMessages;
return task;
}
export function getAnnotationsForTask(taskId: string): Promise<Array<APIAnnotation>> {
return Request.receiveJSON(`/api/tasks/${taskId}/annotations`);
export function getAnnotationsForTask(
taskId: string,
options?: RequestOptions,
): Promise<Array<APIAnnotation>> {
return Request.receiveJSON(`/api/tasks/${taskId}/annotations`, options);
}
export function deleteTask(taskId: string): Promise<void> {
return Request.receiveJSON(`/api/tasks/${taskId}`, {
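For context, the optional options parameter added to getAnnotationsForTask above is what later lets the training form probe whether a pasted ID belongs to a task without surfacing an error toast. A minimal usage sketch, assuming the usual webknossos import paths; the helper name is illustrative and not part of this PR:

import type { APIAnnotation } from "types/api_flow_types";
import { getAnnotationsForTask } from "admin/api/tasks";

// Probe whether `id` is a task ID. getAnnotationsForTask rejects for non-task IDs,
// so callers can silently fall back to treating the input as an annotation ID.
async function tryGetTaskAnnotations(id: string): Promise<Array<APIAnnotation> | null> {
  try {
    // Suppress the default error toast; a failure here is expected for annotation IDs.
    return await getAnnotationsForTask(id, { showErrorToast: false });
  } catch (_e) {
    return null;
  }
}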
@@ -396,7 +396,7 @@ function TaskTypeCreateView({ taskTypeId, history }: Props) {
valuePropName="checked"
>
<Checkbox>
Allow Volume Interpolation
Allow Volume Interpolation{" "}
<Tooltip
title="When enabled, it suffices to only label every 2nd slice. The skipped slices will be filled automatically by interpolating between the labeled slices."
placement="right"
7 changes: 5 additions & 2 deletions frontend/javascripts/admin/voxelytics/ai_model_list_view.tsx
@@ -12,7 +12,10 @@ import { JobState } from "admin/job/job_list_view";
import { Link } from "react-router-dom";
import { useGuardedFetch } from "libs/react_helpers";
import { PageNotAvailableToNormalUser } from "components/permission_enforcer";
import { type AnnotationInfoForAIJob, TrainAiModelTab } from "oxalis/view/jobs/train_ai_model";
import {
type AnnotationInfoForAITrainingJob,
TrainAiModelTab,
} from "oxalis/view/jobs/train_ai_model";
import { getMagInfo, getSegmentationLayerByName } from "oxalis/model/accessors/dataset_accessor";
import type { Vector3 } from "oxalis/constants";
import type { Key } from "react";
@@ -106,7 +109,7 @@ export default function AiModelListView() {

function TrainNewAiJobModal({ onClose }: { onClose: () => void }) {
const [annotationInfosForAiJob, setAnnotationInfosForAiJob] = useState<
AnnotationInfoForAIJob<APIAnnotation>[]
AnnotationInfoForAITrainingJob<APIAnnotation>[]
>([]);

const getMagsForSegmentationLayer = (annotationId: string, layerName: string) => {
153 changes: 117 additions & 36 deletions frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx
@@ -47,6 +47,7 @@ import { computeArrayFromBoundingBox } from "libs/utils";
import { MagSelectionFormItem } from "components/mag_selection";
import { MagInfo } from "oxalis/model/helpers/mag_info";
import { V3 } from "libs/mjs";
import { getAnnotationsForTask } from "admin/api/tasks";

const { TextArea } = Input;
const FormItem = Form.Item;
@@ -55,10 +56,12 @@ const FormItem = Form.Item;
// only the APIAnnotations of the given annotations to train on are loaded from the backend.
// Thus, the code needs to handle both HybridTracing | APIAnnotation where APIAnnotation is missing some information.
// Therefore, volumeTracings with the matching volumeTracingMags are needed to get more details on each volume annotation layer and its magnifications.
// The userBoundingBoxes are needed for checking for equal bounding box sizes. As training on fallback data is supported and an annotation is not required to have VolumeTracings,
// As the userBoundingBoxes should have multiple sizes of the smallest one, a check with a warning should be included.
Member suggested change:
// As the userBoundingBoxes should have multiple sizes of the smallest one, a check with a warning should be included.
// As the userBoundingBoxes should have extents that are multiples of the smallest extent, a check with a warning should be included.

// As training on fallback data is supported and an annotation is not required to have VolumeTracings,
// it is necessary to save userBoundingBoxes separately and not load them from volumeTracings entries to support skeleton only annotations.
// Moreover, in case an annotations is a task, its task bounding box should also be used for training.
Member suggested change:
// Moreover, in case an annotations is a task, its task bounding box should also be used for training.
// Moreover, in case an annotation is a task, its task bounding box should also be used for training.

// Note that a copy of the userBoundingBoxes is included in each volume and skeleton tracing of an annotation. Thus, it doesn't matter from which the userBoundingBoxes are taken.
export type AnnotationInfoForAIJob<GenericAnnotation> = {
export type AnnotationInfoForAITrainingJob<GenericAnnotation> = {
annotation: GenericAnnotation;
dataset: APIDataset;
volumeTracings: VolumeTracing[];
@@ -175,8 +178,8 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
getMagsForSegmentationLayer: (annotationId: string, layerName: string) => MagInfo;
onClose: () => void;
ensureSavedState?: (() => Promise<void>) | null;
annotationInfos: Array<AnnotationInfoForAIJob<GenericAnnotation>>;
onAddAnnotationsInfos?: (newItems: Array<AnnotationInfoForAIJob<APIAnnotation>>) => void;
annotationInfos: Array<AnnotationInfoForAITrainingJob<GenericAnnotation>>;
onAddAnnotationsInfos?: (newItems: Array<AnnotationInfoForAITrainingJob<APIAnnotation>>) => void;
}) {
const [form] = Form.useForm();

@@ -205,6 +208,10 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
const magInfoForLayer: Array<MagInfo> = Form.useWatch(() => {
return watcherFunctionRef.current();
}, form);
const trainingAnnotationsInfo = Form.useWatch("trainingAnnotations", form) as Array<{
annotationId: string;
mag: Vector3;
}>;

const [useCustomWorkflow, setUseCustomWorkflow] = React.useState(false);

@@ -220,7 +227,6 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
annotationId,
groundTruthLayerName,
).getMagList();
console.log("getintersectingmaglist", dataLayerMags, groundTruthLayerMags);

return groundTruthLayerMags?.filter((groundTruthMag) =>
dataLayerMags?.find((mag) => V3.equals(mag, groundTruthMag)),
@@ -275,12 +281,16 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
modelCategory: AiModelCategory.EM_NEURONS,
};

const userBoundingBoxes = annotationInfos.flatMap(({ userBoundingBoxes, annotation }) =>
userBoundingBoxes.map((box) => ({
const userBoundingBoxes = annotationInfos.flatMap(({ userBoundingBoxes, annotation }) => {
const annotationId = "id" in annotation ? annotation.id : annotation.annotationId;
return userBoundingBoxes.map((box) => ({
...box,
annotationId: "id" in annotation ? annotation.id : annotation.annotationId,
})),
);
annotationId: annotationId,
trainingMag: trainingAnnotationsInfo?.find(
(formInfo) => formInfo.annotationId === annotationId,
)?.mag,
}));
});

const bboxesVoxelCount = _.sum(
(userBoundingBoxes || []).map((bbox) => new BoundingBox(bbox.boundingBox).getVolume()),
@@ -495,7 +505,7 @@ export function CollapsibleWorkflowYamlEditor({
}

function checkAnnotationsForErrorsAndWarnings<T extends HybridTracing | APIAnnotation>(
annotationsWithDatasets: Array<AnnotationInfoForAIJob<T>>,
annotationsWithDatasets: Array<AnnotationInfoForAITrainingJob<T>>,
): {
hasAnnotationErrors: boolean;
errors: string[];
@@ -525,8 +535,12 @@ function checkAnnotationsForErrorsAndWarnings<T extends HybridTracing | APIAnnot
return { hasAnnotationErrors: false, errors: [] };
}

const MIN_BBOX_EXTENT_IN_EACH_DIM = 32;
function checkBoundingBoxesForErrorsAndWarnings(
userBoundingBoxes: (UserBoundingBox & { annotationId: string })[],
userBoundingBoxes: (UserBoundingBox & {
annotationId: string;
trainingMag: Vector3 | undefined;
})[],
): {
hasBBoxErrors: boolean;
hasBBoxWarnings: boolean;
@@ -543,22 +557,51 @@ function checkBoundingBoxesForErrorsAndWarnings(
}
// Find smallest bounding box dimensions
const minDimensions = userBoundingBoxes.reduce(
(min, { boundingBox: box }) => ({
x: Math.min(min.x, box.max[0] - box.min[0]),
y: Math.min(min.y, box.max[1] - box.min[1]),
z: Math.min(min.z, box.max[2] - box.min[2]),
}),
(min, { boundingBox: box, trainingMag }) => {
let bbox = new BoundingBox(box);
if (trainingMag) {
bbox = bbox.alignWithMag(trainingMag, "shrink");
}
const size = bbox.getSize();
return {
x: Math.min(min.x, size[0]),
y: Math.min(min.y, size[1]),
z: Math.min(min.z, size[2]),
};
},
{ x: Number.POSITIVE_INFINITY, y: Number.POSITIVE_INFINITY, z: Number.POSITIVE_INFINITY },
);

// Validate minimum size and multiple requirements
type BoundingBoxWithAnnotationId = { boundingBox: Vector6; name: string; annotationId: string };
const tooSmallBoxes: BoundingBoxWithAnnotationId[] = [];
const nonMultipleBoxes: BoundingBoxWithAnnotationId[] = [];
userBoundingBoxes.forEach(({ boundingBox: box, name, annotationId }) => {
const arrayBox = computeArrayFromBoundingBox(box);
const notMagAlignedBoundingBoxes: (BoundingBoxWithAnnotationId & {
alignedBoundingBox: Vector6;
})[] = [];
userBoundingBoxes.forEach(({ boundingBox: box, name, annotationId, trainingMag }) => {
const boundingBox = new BoundingBox(box);
let arrayBox = computeArrayFromBoundingBox(box);
if (trainingMag) {
const alignedBoundingBox = boundingBox.alignWithMag(trainingMag, "shrink");
if (!alignedBoundingBox.equals(boundingBox)) {
const alignedArrayBox = computeArrayFromBoundingBox(alignedBoundingBox);
notMagAlignedBoundingBoxes.push({
boundingBox: arrayBox,
name,
annotationId,
alignedBoundingBox: alignedArrayBox,
});
// Update the arrayBox as the aligned version of the bounding box will be used for training.
arrayBox = alignedArrayBox;
}
}
const [_x, _y, _z, width, height, depth] = arrayBox;
if (width < 10 || height < 10 || depth < 10) {
if (
width < MIN_BBOX_EXTENT_IN_EACH_DIM ||
height < MIN_BBOX_EXTENT_IN_EACH_DIM ||
depth < MIN_BBOX_EXTENT_IN_EACH_DIM
) {
tooSmallBoxes.push({ boundingBox: arrayBox, name, annotationId });
}
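For readers unfamiliar with alignWithMag, the "shrink" strategy used above can be pictured as moving the lower corner up and the upper corner down to the nearest multiples of the magnification. A simplified stand-alone sketch for illustration only; the real BoundingBox.alignWithMag may differ in details:

type Vec3 = [number, number, number];

// Shrink a box so that both corners lie on multiples of the magnification.
function shrinkAlignToMag(min: Vec3, max: Vec3, mag: Vec3): { min: Vec3; max: Vec3 } {
  const alignedMin = min.map((value, i) => Math.ceil(value / mag[i]) * mag[i]) as Vec3;
  const alignedMax = max.map((value, i) => Math.floor(value / mag[i]) * mag[i]) as Vec3;
  return { min: alignedMin, max: alignedMax };
}

// Example: shrinkAlignToMag([3585, 3585, 1024], [3647, 3647, 1088], [2, 2, 1])
// yields { min: [3586, 3586, 1024], max: [3646, 3646, 1088] }.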

@@ -571,14 +614,25 @@ function checkBoundingBoxesForErrorsAndWarnings(
}
});

if (notMagAlignedBoundingBoxes.length > 0) {
hasBBoxWarnings = true;
const notMagAlignedBoundingBoxesStrings = notMagAlignedBoundingBoxes.map(
({ boundingBox, name, annotationId, alignedBoundingBox }) =>
`'${name}' of annotation ${annotationId}: ${boundingBox.join(", ")} -> ${alignedBoundingBox.join(", ")}`,
);
warnings.push(
`The following bounding boxes are not aligned with the selected magnification. They will be automatically shrunk to be aligned with the magnification:\n${notMagAlignedBoundingBoxesStrings.join("\n")}`,
Member:
I think it would be helpful to include the actual magnification here, for example at the end of the first sentence.

Contributor Author:
Yeah, the problem is that one can select different mags for different annotations. Therefore, I added the mag to the end of each bbox entry in the warning message.
But the problem is that this makes the text even longer. I don't know if that's a benefit.
Maybe a better tree structure of the error message would help here 🤔. But this would also complicate the code a little.

Here is what I currently have:
[screenshot]

And this is what I mean with an improved hierarchy structure:

The following bounding boxes are not aligned with the selected magnification. They will be automatically shrunk to be aligned with the magnification:
- Annotation 677e8b3f5f0100a80b1dc6ae
  - 'Bounding box 1' (3584, 3584, 1024, 64, 64, 64) will be 1792, 1792, 1024, 32, 32, 64 in mag 2, 2, 1
  - 'Bounding box 2' (3648, 3648, 1088, 64, 64, 64) will be 1824, 1824, 1088, 32, 32, 64 in mag 2, 2, 1
  - 'Bounding box 3' (3712, 3712, 1152, 64, 64, 64) will be 1856, 1856, 1152, 32, 32, 64 in mag 2, 2, 1 

What do you think about this tree structured text?

Member:
I like that a lot 👍

Contributor Author (@MichaelBuessemeyer, Jan 15, 2025):
Ok, I made it work 🎉
[screenshot]

Could you please check the newest changes?

);
}
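As an illustration of the tree-structured message discussed in the thread above (a rough sketch, not the actual follow-up commit), the non-aligned boxes could be grouped per annotation with lodash, which is already imported in this file as _. Appending the selected mag per line, as in the screenshot, would additionally require storing trainingMag on each entry:

// Rough illustration only: group the non-aligned boxes by annotation and emit
// one indented line per box, mirroring the tree structure shown above.
const groupedByAnnotation = _.groupBy(notMagAlignedBoundingBoxes, "annotationId");
const treeStructuredMessage = Object.entries(groupedByAnnotation)
  .map(([annotationId, boxes]) =>
    [
      `- Annotation ${annotationId}`,
      ...boxes.map(
        ({ name, boundingBox, alignedBoundingBox }) =>
          `  - '${name}' (${boundingBox.join(", ")}) will be ${alignedBoundingBox.join(", ")}`,
      ),
    ].join("\n"),
  )
  .join("\n");
warnings.push(
  `The following bounding boxes are not aligned with the selected magnification. They will be automatically shrunk to be aligned with the magnification:\n${treeStructuredMessage}`,
);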

const boxWithIdToString = ({ boundingBox, name, annotationId }: BoundingBoxWithAnnotationId) =>
`'${name}' of annotation ${annotationId}: ${boundingBox.join(", ")}`;

if (tooSmallBoxes.length > 0) {
hasBBoxWarnings = true;
const tooSmallBoxesStrings = tooSmallBoxes.map(boxWithIdToString);
warnings.push(
`The following bounding boxes are not at least 10 Vx in each dimension which is suboptimal for the training:\n${tooSmallBoxesStrings.join("\n")}`,
`The following bounding boxes are not at least ${MIN_BBOX_EXTENT_IN_EACH_DIM} Vx in each dimension which is suboptimal for the training:\n${tooSmallBoxesStrings.join("\n")}`,
);
}

@@ -596,31 +650,44 @@ function checkBoundingBoxesForErrorsAndWarnings(
function AnnotationsCsvInput({
onAdd,
}: {
onAdd: (newItems: Array<AnnotationInfoForAIJob<APIAnnotation>>) => void;
onAdd: (newItems: Array<AnnotationInfoForAITrainingJob<APIAnnotation>>) => void;
}) {
const [value, setValue] = useState("");
const onClickAdd = async () => {
const newItems = [];
const annotationIdsForTraining = [];
const unfinishedTasks = [];

const lines = value
.split("\n")
.map((line) => line.trim())
.filter((line) => line !== "");
for (const annotationUrlOrId of lines) {
if (annotationUrlOrId.includes("/")) {
newItems.push({
annotationId: annotationUrlOrId.split("/").at(-1) as string,
});
for (const taskOrAnnotationIdOrUrl of lines) {
if (taskOrAnnotationIdOrUrl.includes("/")) {
annotationIdsForTraining.push(taskOrAnnotationIdOrUrl.split("/").at(-1) as string);
} else {
newItems.push({
annotationId: annotationUrlOrId,
});
let isTask = true;
try {
const annotations = await getAnnotationsForTask(taskOrAnnotationIdOrUrl, {
showErrorToast: false,
});
const finishedAnnotations = annotations.filter(({ state }) => state === "Finished");
if (annotations.length > 0) {
annotationIdsForTraining.push(...finishedAnnotations.map(({ id }) => id));
} else {
unfinishedTasks.push(taskOrAnnotationIdOrUrl);
}
} catch (_e) {
isTask = false;
}
if (!isTask) {
annotationIdsForTraining.push(taskOrAnnotationIdOrUrl);
}
}
}

const newAnnotationsWithDatasets = await Promise.all(
newItems.map(async (item) => {
const annotation = await getAnnotationInformation(item.annotationId);
annotationIdsForTraining.map(async (annotationId) => {
const annotation = await getAnnotationInformation(annotationId);
const dataset = await getDataset(annotation.datasetId);

const volumeServerTracings: ServerVolumeTracing[] = await Promise.all(
@@ -651,6 +718,16 @@ function AnnotationsCsvInput({
);
}
}
if (annotation.task?.boundingBox) {
const largestId = Math.max(...userBoundingBoxes.map(({ id }) => id));
userBoundingBoxes.push({
name: "Task Bounding Box",
boundingBox: Utils.computeBoundingBoxFromBoundingBoxObject(annotation.task.boundingBox),
color: [0, 0, 0],
isVisible: true,
id: largestId + 1,
});
}

return {
annotation,
@@ -663,14 +740,18 @@ function AnnotationsCsvInput({
};
}),
);

if (unfinishedTasks.length > 0) {
Toast.warning(
`The following tasks have no finished annotations: ${unfinishedTasks.join(", ")}`,
);
}
onAdd(newAnnotationsWithDatasets);
};
return (
<div>
<FormItem
name="annotationCsv"
label="Annotations CSV"
label="Annotations or Tasks CSV"
hasFeedback
initialValue={value}
rules={[
@@ -693,7 +774,7 @@ function AnnotationsCsvInput({
>
<TextArea
className="input-monospace"
placeholder="annotationUrlOrId"
placeholder="taskOrAnnotationIdOrUrl"
autoSize={{
minRows: 6,
}}
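To illustrate the extended input, each line pasted into the textarea may now be an annotation URL, an annotation ID, or a task ID; task IDs are resolved to their finished annotations and additionally contribute the task's bounding box. The values below are purely hypothetical placeholders:

https://my-webknossos.example.org/annotations/<annotation-id>
<annotation-id>
<task-id>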