Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Choose mag when training models on multiple annotations #8266

Merged
merged 25 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9c61a80
very much WIP: add form item to select mag per annotation
knollengewaechs Dec 6, 2024
6305026
pass mags to form
knollengewaechs Dec 9, 2024
19a1666
pass mags through to parent form
knollengewaechs Dec 9, 2024
f90cacc
Merge branch 'master' into multi-anno-trainings-choose-mag
knollengewaechs Dec 12, 2024
5d3151f
remove default value
knollengewaechs Dec 12, 2024
b7ce334
WIP: clear mag selection when new layer is selected
knollengewaechs Dec 12, 2024
5fc98f0
WIP: update mags
knollengewaechs Dec 13, 2024
9011674
select intersection of data and segmentation layer mags
knollengewaechs Dec 16, 2024
48876d9
load mags if both layers are already selected
knollengewaechs Dec 16, 2024
7fdf1e1
adjust line height for non required fields
knollengewaechs Dec 16, 2024
da1ba84
Merge branch 'master' into multi-anno-trainings-choose-mag
knollengewaechs Dec 16, 2024
0dea8e5
add changelog
knollengewaechs Dec 16, 2024
3f4bcaf
remove dev edits
knollengewaechs Dec 16, 2024
420f648
WIP: address review
knollengewaechs Dec 20, 2024
3da8f49
WIP: add Form.useWatch to update value whenever form changes
knollengewaechs Dec 20, 2024
88581c6
fix multi mag form
knollengewaechs Dec 20, 2024
0f83dba
use useRef for watcherFunction to get the current annotation infos
knollengewaechs Dec 23, 2024
0e692b4
Merge branch 'master' into multi-anno-trainings-choose-mag
knollengewaechs Dec 23, 2024
6325828
add await to async method
knollengewaechs Jan 3, 2025
0aee521
WIP: render volume annotation layer names more nicely in layer selection
knollengewaechs Jan 3, 2025
78af65e
fix that tracing ids were shown (and sent to worker) even though huma…
philippotto Jan 6, 2025
feb8a45
remove application.conf edit
knollengewaechs Jan 6, 2025
0ca90f2
merge master
knollengewaechs Jan 6, 2025
e12ab6f
lint
knollengewaechs Jan 6, 2025
76b1fa6
Merge branch 'master' into multi-anno-trainings-choose-mag
knollengewaechs Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.unreleased.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ For upgrade instructions, please check the [migration guide](MIGRATIONS.released
### Added
- Added the total volume of a dataset to a tooltip in the dataset info tab. [#8229](https://github.com/scalableminds/webknossos/pull/8229)
- Optimized performance of data loading with “fill value“ chunks. [#8271](https://github.com/scalableminds/webknossos/pull/8271)
- It is now possible to select the magnification of the layers on which an AI model will be trained. [#8266](https://github.com/scalableminds/webknossos/pull/8266)

### Changed
- Renamed "resolution" to "magnification" in more places within the codebase, including local variables. [#8168](https://github.com/scalableminds/webknossos/pull/8168)
Expand Down
4 changes: 2 additions & 2 deletions conf/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ features {
taskReopenAllowedInSeconds = 30
allowDeleteDatasets = true
# to enable jobs for local development, use "yarn enable-jobs" to also activate it in the database
jobsEnabled = false
voxelyticsEnabled = false
jobsEnabled = true
voxelyticsEnabled = true
knollengewaechs marked this conversation as resolved.
Show resolved Hide resolved
# For new users, the dashboard will show a banner which encourages the user to check out the following dataset.
# If isWkorgInstance == true, `/createExplorative/hybrid/true` is appended to the URL so that a new tracing is opened.
# If isWkorgInstance == false, `/view` is appended to the URL so that it's opened in view mode (since the user might not
Expand Down
8 changes: 4 additions & 4 deletions frontend/javascripts/admin/voxelytics/ai_model_list_view.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function TrainNewAiJobModal({ onClose }: { onClose: () => void }) {
AnnotationInfoForAIJob<APIAnnotation>[]
>([]);

const getMagForSegmentationLayer = async (annotationId: string, layerName: string) => {
const getMagsForSegmentationLayer = (annotationId: string, layerName: string) => {
// The layer name is a human-readable one. It can either belong to an annotationLayer
// (therefore, also to a volume tracing) or to the actual dataset.
// Both are checked below. This won't be ambiguous because annotationLayers must not
Expand All @@ -130,10 +130,10 @@ function TrainNewAiJobModal({ onClose }: { onClose: () => void }) {
(tracing) => tracing.tracingId === annotationLayer.tracingId,
);
const mags = volumeTracingMags[volumeTracingIndex] || ([[1, 1, 1]] as Vector3[]);
return getMagInfo(mags).getFinestMag();
return getMagInfo(mags);
} else {
const segmentationLayer = getSegmentationLayerByName(dataset, layerName);
return getMagInfo(segmentationLayer.resolutions).getFinestMag();
return getMagInfo(segmentationLayer.resolutions);
}
};

Expand All @@ -152,7 +152,7 @@ function TrainNewAiJobModal({ onClose }: { onClose: () => void }) {
maskClosable={false}
>
<TrainAiModelTab
getMagForSegmentationLayer={getMagForSegmentationLayer}
getMagsForSegmentationLayer={getMagsForSegmentationLayer}
onClose={onClose}
annotationInfos={annotationInfosForAiJob}
onAddAnnotationsInfos={(newItems) => {
Expand Down
3 changes: 3 additions & 0 deletions frontend/javascripts/components/layer_selection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type LayerSelectionProps<L extends { name: string }> = {
getReadableNameForLayer: (layer: L) => string;
fixedLayerName?: string;
label?: string;
onChange?: (a: string) => void;
};

export function LayerSelection<L extends { name: string }>({
Expand Down Expand Up @@ -65,6 +66,7 @@ export function LayerSelectionFormItem<L extends { name: string }>({
getReadableNameForLayer,
fixedLayerName,
label,
onChange,
}: LayerSelectionProps<L>): JSX.Element {
const layerType = chooseSegmentationLayer ? "segmentation" : "color";
return (
Expand All @@ -85,6 +87,7 @@ export function LayerSelectionFormItem<L extends { name: string }>({
fixedLayerName={fixedLayerName}
layerType={layerType}
getReadableNameForLayer={getReadableNameForLayer}
onChange={onChange}
/>
</Form.Item>
);
Expand Down
73 changes: 73 additions & 0 deletions frontend/javascripts/components/mag_selection.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import { Form, Select } from "antd";
import { V3 } from "libs/mjs";
import { clamp } from "libs/utils";
import type { Vector3 } from "oxalis/constants";
import type { MagInfo } from "oxalis/model/helpers/mag_info";

export function MagSelectionFormItem({
name,
magInfo,
}: {
name: string | Array<string | number>;
magInfo: MagInfo | undefined;
}): JSX.Element {
return (
<Form.Item
name={name}
label={"Magnification"}
rules={[
{
required: true,
message: "Please select the magnification.",
},
]}
>
<MagSelection magInfo={magInfo} />
</Form.Item>
);
}

function MagSelection({
magInfo,
value,
onChange,
}: {
magInfo: MagInfo | undefined;
value?: Vector3;
onChange?: (newValue: Vector3) => void;
}): JSX.Element {
const allMags = magInfo != null ? magInfo.getMagList() : [];

const onSelect = (index: number | undefined) => {
if (onChange == null || index == null) return;
const newMag = allMags[index];
if (newMag != null) onChange(newMag);
};

return (
<Select
placeholder="Select a magnification"
value={
// Using the index of the mag *in the mag list* as value internally,
// this is different from the mag index.
value == null || magInfo == null
? null
: clamp(
0,
allMags.findIndex((v) => V3.equals(v, value)),
allMags.length - 1,
)
}
onSelect={onSelect}
>
{allMags.map((mag, index) => {
const readableName = mag.join("-");
return (
<Select.Option key={index} value={index}>
{readableName}
</Select.Option>
);
})}
</Select>
);
}
133 changes: 102 additions & 31 deletions frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useState } from "react";
import React, { useRef, useState } from "react";
import {
Alert,
Form,
Expand Down Expand Up @@ -34,11 +34,19 @@ import _ from "lodash";
import BoundingBox from "oxalis/model/bucket_data_handling/bounding_box";
import { formatVoxels } from "libs/format_utils";
import * as Utils from "libs/utils";
import type { APIAnnotation, APIDataset, ServerVolumeTracing } from "types/api_flow_types";
import type {
APIAnnotation,
APIDataLayer,
APIDataset,
ServerVolumeTracing,
} from "types/api_flow_types";
import type { Vector3, Vector6 } from "oxalis/constants";
import { serverVolumeToClientVolumeTracing } from "oxalis/model/reducers/volumetracing_reducer";
import { convertUserBoundingBoxesFromServerToFrontend } from "oxalis/model/reducers/reducer_helpers";
import { computeArrayFromBoundingBox } from "libs/utils";
import { MagSelectionFormItem } from "components/mag_selection";
import { MagInfo } from "oxalis/model/helpers/mag_info";
import { V3 } from "libs/mjs";

const { TextArea } = Input;
const FormItem = Form.Item;
Expand Down Expand Up @@ -126,15 +134,15 @@ export function TrainAiModelFromAnnotationTab({ onClose }: { onClose: () => void
const tracing = useSelector((state: OxalisState) => state.tracing);
const dataset = useSelector((state: OxalisState) => state.dataset);

const getMagForSegmentationLayer = async (_annotationId: string, layerName: string) => {
const getMagsForSegmentationLayer = (_annotationId: string, layerName: string) => {
const segmentationLayer = getSegmentationLayerByHumanReadableName(dataset, tracing, layerName);
return getMagInfo(segmentationLayer.resolutions).getFinestMag();
return getMagInfo(segmentationLayer.resolutions);
knollengewaechs marked this conversation as resolved.
Show resolved Hide resolved
};
const userBoundingBoxes = getSomeTracing(tracing).userBoundingBoxes;

return (
<TrainAiModelTab
getMagForSegmentationLayer={getMagForSegmentationLayer}
getMagsForSegmentationLayer={getMagsForSegmentationLayer}
ensureSavedState={() => Model.ensureSavedState()}
onClose={onClose}
annotationInfos={[
Expand All @@ -150,42 +158,91 @@ export function TrainAiModelFromAnnotationTab({ onClose }: { onClose: () => void
);
}

type TrainingAnnotation = {
annotationId: string;
imageDataLayer: string;
layerName: string;
mag: Vector3;
};

export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | HybridTracing>({
getMagForSegmentationLayer,
getMagsForSegmentationLayer,
onClose,
ensureSavedState,
annotationInfos,
onAddAnnotationsInfos,
}: {
getMagForSegmentationLayer: (annotationId: string, layerName: string) => Promise<Vector3>;
getMagsForSegmentationLayer: (annotationId: string, layerName: string) => MagInfo;
onClose: () => void;
ensureSavedState?: (() => Promise<void>) | null;
annotationInfos: Array<AnnotationInfoForAIJob<GenericAnnotation>>;
onAddAnnotationsInfos?: (newItems: Array<AnnotationInfoForAIJob<APIAnnotation>>) => void;
}) {
const [form] = Form.useForm();

const watcherFunctionRef = useRef(() => {
return [new MagInfo([])];
});
watcherFunctionRef.current = () => {
const getIntersectingMags = (idx: number, annotationId: string, dataset: APIDataset) => {
const segmentationLayerName = form.getFieldValue(["trainingAnnotations", idx, "layerName"]);
const imageDataLayerName = form.getFieldValue(["trainingAnnotations", idx, "imageDataLayer"]);
if (segmentationLayerName != null && imageDataLayerName != null) {
return new MagInfo(
getIntersectingMagList(annotationId, dataset, segmentationLayerName, imageDataLayerName),
);
}
return new MagInfo([]);
};

return annotationInfos.map((annotationInfo, idx: number) => {
const annotation = annotationInfo.annotation;
const annotationId = "id" in annotation ? annotation.id : annotation.annotationId;
return getIntersectingMags(idx, annotationId, annotationInfo.dataset);
});
};

const magInfoForLayer: Array<MagInfo> = Form.useWatch(() => {
return watcherFunctionRef.current();
}, form);

const [useCustomWorkflow, setUseCustomWorkflow] = React.useState(false);

const getTrainingAnnotations = async (values: any) => {
return Promise.all(
values.trainingAnnotations.map(
async (trainingAnnotation: {
annotationId: string;
imageDataLayer: string;
layerName: string;
}) => {
const { annotationId, imageDataLayer, layerName } = trainingAnnotation;
return {
annotationId,
colorLayerName: imageDataLayer,
segmentationLayerName: layerName,
mag: await getMagForSegmentationLayer(annotationId, layerName),
};
},
),
const getIntersectingMagList = (
annotationId: string,
dataset: APIDataset,
groundTruthLayerName: string,
imageDataLayerName: string,
) => {
const colorLayers = getColorLayers(dataset);
const dataLayerMags = getMagsForColorLayer(colorLayers, imageDataLayerName);
const groundTruthLayerMags = getMagsForSegmentationLayer(
annotationId,
groundTruthLayerName,
).getMagList();

return groundTruthLayerMags?.filter((groundTruthMag) =>
dataLayerMags?.find((mag) => V3.equals(mag, groundTruthMag)),
);
};

const getMagsForColorLayer = (colorLayers: APIDataLayer[], layerName: string) => {
const colorLayer = colorLayers.find((layer) => layer.name === layerName);
return colorLayer != null ? getMagInfo(colorLayer.resolutions).getMagList() : null;
};
knollengewaechs marked this conversation as resolved.
Show resolved Hide resolved

const getTrainingAnnotations = (values: any) => {
return values.trainingAnnotations.map((trainingAnnotation: TrainingAnnotation) => {
const { annotationId, imageDataLayer, layerName, mag } = trainingAnnotation;
return {
annotationId,
colorLayerName: imageDataLayer,
segmentationLayerName: layerName,
mag,
};
});
};

const onFinish = async (form: FormInstance<any>, useCustomWorkflow: boolean, values: any) => {
form.validateFields();

Expand All @@ -194,8 +251,8 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
await ensureSavedState();
}

await runTraining({
trainingAnnotations: await getTrainingAnnotations(values),
philippotto marked this conversation as resolved.
Show resolved Hide resolved
runTraining({
trainingAnnotations: getTrainingAnnotations(values),
name: values.modelName,
aiModelCategory: values.modelCategory,
workflowYaml: useCustomWorkflow ? values.workflowYaml : undefined,
Expand Down Expand Up @@ -240,7 +297,6 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
const hasWarnings = hasBBoxWarnings;
const errors = [...annotationErrors, ...bboxErrors];
const warnings = bboxWarnings;

return (
<Form
onFinish={(values) => onFinish(form, useCustomWorkflow, values)}
Expand Down Expand Up @@ -281,19 +337,24 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
);
const fixedSelectedColorLayer = colorLayers.length === 1 ? colorLayers[0] : null;
const annotationId = "id" in annotation ? annotation.id : annotation.annotationId;

const onChangeLayer = () => {
form.setFieldValue(["trainingAnnotations", idx, "mag"], undefined);
};

return (
<Row key={annotationId} gutter={8}>
<Col span={8}>
<Col span={6}>
<FormItem
hasFeedback
name={["trainingAnnotations", idx, "annotationId"]}
label="Annotation ID"
label={<div style={{ minHeight: 24 }}>Annotation ID</div>} // balance height with labels of required fields
initialValue={annotationId}
>
<Input disabled />
</FormItem>
</Col>
<Col span={8}>
<Col span={6}>
<FormItem
hasFeedback
name={["trainingAnnotations", idx, "imageDataLayer"]}
Expand All @@ -311,19 +372,29 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
getReadableNameForLayer={(layer) => layer.name}
fixedLayerName={fixedSelectedColorLayer?.name || undefined}
style={{ width: "100%" }}
onChange={onChangeLayer}
/>
</FormItem>
</Col>
<Col span={8}>
<Col span={6}>
<LayerSelectionFormItem
name={["trainingAnnotations", idx, "layerName"]}
chooseSegmentationLayer
layers={segmentationLayers}
getReadableNameForLayer={(layer) => {
return layer.name;
//TODO_c fix that fallback layers are shown at least with the correct name?
// eg. name (active layer)
}}
fixedLayerName={fixedSelectedSegmentationLayer?.name || undefined}
label="Ground Truth Layer"
onChange={onChangeLayer}
/>
</Col>
<Col span={6}>
<MagSelectionFormItem
name={["trainingAnnotations", idx, "mag"]}
magInfo={magInfoForLayer != null ? magInfoForLayer[idx] : new MagInfo([])}
/>
</Col>
</Row>
Expand Down