Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set max instances for top down models #1070

Merged
merged 5 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 120 additions & 6 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

"""Export a trained SLEAP model as a frozen graph. Initializes model,
Expand All @@ -470,10 +471,20 @@ def export_model(
sleap.nn.data.utils.describe_tensors as an example)
unrag_outputs: If `True` (default), any ragged tensors will be
converted to normal tensors and padded with NaNs

max_instances: If set, determines the max number of instances that a
multi-instance model returns. This is enforced during centroid
cropping and therefore only compatible with TopDown models.
"""

self._initialize_inference_model()
predictor_name = type(self).__name__

if max_instances is not None:
if "TopDown" in predictor_name:
print(f"\n max instances set, limiting instances to {max_instances} \n")
self.inference_model.centroid_crop.max_instances = max_instances
else:
raise Exception(f"{predictor_name} does not support max instance limit")

first_inference_layer = self.inference_model.layers[0]
keras_model_shape = first_inference_layer.keras_model.input.shape
Expand Down Expand Up @@ -1469,10 +1480,17 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)

self.confmap_config.save_json(os.path.join(save_path, "confmap_config.json"))
Expand Down Expand Up @@ -1523,6 +1541,8 @@ class CentroidCrop(InferenceLayer):
the predicted peaks. This is true by default since crops are used
for finding instance peaks in a top down model. If using a centroid
only inference model, this should be set to `False`.
max_instances: If set, determines the max number of instances that a
multi-instance model returns.
"""

def __init__(
Expand All @@ -1539,6 +1559,7 @@ def __init__(
confmaps_ind: Optional[int] = None,
offsets_ind: Optional[int] = None,
return_crops: bool = True,
max_instances: Optional[int] = None,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -1576,6 +1597,7 @@ def __init__(
self.integral_patch_size = integral_patch_size
self.return_confmaps = return_confmaps
self.return_crops = return_crops
self.max_instances = max_instances

@tf.function
def call(self, inputs):
Expand Down Expand Up @@ -1669,9 +1691,67 @@ def call(self, inputs):
# Store crop offsets.
crop_offsets = centroid_points - (self.crop_size / 2)

samples = tf.shape(imgs)[0]

n_peaks = tf.shape(centroid_points)[0]

if n_peaks > 0:

if self.max_instances is not None:

centroid_points = tf.RaggedTensor.from_value_rowids(
centroid_points, crop_sample_inds, nrows=samples
)
centroid_vals = tf.RaggedTensor.from_value_rowids(
centroid_vals, crop_sample_inds, nrows=samples
)

_centroid_vals = tf.TensorArray(
size=samples,
dtype=tf.float32,
infer_shape=False,
element_shape=[None],
)

_centroid_points = tf.TensorArray(
size=samples,
dtype=tf.float32,
infer_shape=False,
element_shape=[None, 2],
)

_row_ids = tf.TensorArray(
size=samples,
dtype=tf.int32,
infer_shape=False,
element_shape=[None],
)

for sample in range(samples):

top_points = tf.math.top_k(
centroid_vals[sample], k=self.max_instances
)
top_inds = top_points.indices

_centroid_vals = _centroid_vals.write(
sample, tf.gather(centroid_vals[sample], top_inds)
)

_centroid_points = _centroid_points.write(
sample, tf.gather(centroid_points[sample], top_inds)
)

_row_ids = _row_ids.write(sample, tf.fill([len(top_inds)], sample))

centroid_vals = _centroid_vals.concat()
centroid_points = _centroid_points.concat()
crop_sample_inds = _row_ids.concat()

n_peaks = tf.shape(crop_sample_inds)[0]

crop_offsets = centroid_points - (self.crop_size / 2)

# Crop instances around centroids.
bboxes = sleap.nn.data.instance_cropping.make_centered_bboxes(
centroid_points, self.crop_size, self.crop_size
Expand All @@ -1684,6 +1764,7 @@ def call(self, inputs):
crops = tf.reshape(
crops, [n_peaks, self.crop_size, self.crop_size, full_imgs.shape[3]]
)

else:
# No peaks found, so just create a placeholder stack.
crops = tf.zeros(
Expand All @@ -1692,7 +1773,6 @@ def call(self, inputs):
)

# Group crops by sample (samples, ?, ...).
samples = tf.shape(imgs)[0]
centroids = tf.RaggedTensor.from_value_rowids(
centroid_points, crop_sample_inds, nrows=samples
)
Expand Down Expand Up @@ -2390,10 +2470,17 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)

if self.confmap_config is not None:
Expand Down Expand Up @@ -4086,10 +4173,17 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):

super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)

if self.confmap_config is not None:
Expand Down Expand Up @@ -4215,6 +4309,7 @@ def export_model(
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
max_instances: Optional[int] = None,
):
"""High level export of a trained SLEAP model as a frozen graph.

Expand All @@ -4232,10 +4327,20 @@ def export_model(
sleap.nn.data.utils.describe_tensors as an example).
unrag_outputs: If `True` (default), any ragged tensors will be
converted to normal tensors and padded with NaNs
max_instances: If set, determines the max number of instances that a
multi-instance model returns. This is enforced during centroid
cropping and therefore only compatible with TopDown models.
"""
predictor = load_model(model_path)

predictor.export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
save_path,
signatures,
save_traces,
model_name,
tensors,
unrag_outputs,
max_instances,
)


Expand Down Expand Up @@ -4273,6 +4378,15 @@ def export_cli():
"Defaults to True."
),
)
parser.add_argument(
"-m",
"--max_instances",
type=int,
help=(
"Limit maximum number of instances in multi-instance models"
"Defaults to None"
),
)

args, _ = parser.parse_known_args()
export_model(args.models, args.export_path, unrag_outputs=args.unrag)
Expand Down
42 changes: 42 additions & 0 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
predictor = TopDownPredictor.from_trained_models(
centroid_model_path=min_centroid_model_path
)

predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels)
assert len(labels_pr) == 1
Expand All @@ -603,13 +604,21 @@ def test_topdown_predictor_centroid(min_labels, min_centroid_model_path):
inds1, inds2 = sleap.nn.utils.match_points(points_gt, points_pr)
assert_allclose(points_gt[inds1.numpy()], points_pr[inds2.numpy()], atol=1.5)

# test max_instances (>2 will fail)
predictor.inference_model.centroid_crop.max_instances = 2
labels_pr = predictor.predict(min_labels)

assert len(labels_pr) == 1
assert len(labels_pr[0].instances) == 2


def test_topdown_predictor_centered_instance(
min_labels, min_centered_instance_model_path
):
predictor = TopDownPredictor.from_trained_models(
confmap_model_path=min_centered_instance_model_path
)

predictor.verbosity = "none"
labels_pr = predictor.predict(min_labels)
assert len(labels_pr) == 1
Expand Down Expand Up @@ -859,6 +868,14 @@ def test_centroid_inference():
assert preds["centroids"].shape == (1, 3, 2)
assert preds["centroid_vals"].shape == (1, 3)

# test max instances (>3 will fail)
layer.max_instances = 3
out = layer(cms)

model = CentroidInferenceModel(layer)

preds = model.predict(cms)


def export_frozen_graph(model, preds, output_path):

Expand Down Expand Up @@ -1008,6 +1025,15 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm
unrag_outputs=False,
)

# max_instances should raise an exception for single instance
with pytest.raises(Exception):
export_model(
min_single_instance_robot_model_path,
save_path=tmp_path.as_posix(),
unrag_outputs=False,
max_instances=1,
)


def test_topdown_predictor_save(
min_centroid_model_path, min_centered_instance_model_path, tmp_path
Expand Down Expand Up @@ -1039,6 +1065,14 @@ def test_topdown_predictor_save(
unrag_outputs=False,
)

# test max instances
export_model(
[min_centroid_model_path, min_centered_instance_model_path],
save_path=tmp_path.as_posix(),
unrag_outputs=False,
max_instances=4,
)


def test_topdown_id_predictor_save(
min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path
Expand Down Expand Up @@ -1070,6 +1104,14 @@ def test_topdown_id_predictor_save(
unrag_outputs=False,
)

# test max instances
export_model(
[min_centroid_model_path, min_topdown_multiclass_model_path],
save_path=tmp_path.as_posix(),
unrag_outputs=False,
max_instances=4,
)


@pytest.mark.parametrize(
"output_path,tracker_method", [("not_default", "flow"), (None, "simple")]
Expand Down