Skip to content

Commit

Permalink
fix a sliced prediction bug for small images (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon authored Feb 26, 2021
1 parent 79c1f00 commit 8ba6da2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
2 changes: 1 addition & 1 deletion sahi/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.3.1"
53 changes: 35 additions & 18 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,28 +192,45 @@ def get_sliced_prediction(
# create prediction input
num_group = int(num_slices / num_batch)
if verbose == 1 or verbose == 2:
print("Number of slices:", num_slices)
if num_slices > 0:
print("Number of slices:", num_slices)
else:
print("Number of slices:", 1)
object_prediction_list = []
for group_ind in range(num_group):
# prepare batch (currently supports only 1 batch)
image_list = []
shift_amount_list = []
for image_ind in range(num_batch):
image_list.append(
slice_image_result.images[group_ind * num_batch + image_ind]
)
shift_amount_list.append(
slice_image_result.starting_pixels[group_ind * num_batch + image_ind]
if num_slices > 0: # if zero_frac < max_allowed_zeros_ratio from slice_image
for group_ind in range(num_group):
# prepare batch (currently supports only 1 batch)
image_list = []
shift_amount_list = []
for image_ind in range(num_batch):
image_list.append(
slice_image_result.images[group_ind * num_batch + image_ind]
)
shift_amount_list.append(
slice_image_result.starting_pixels[
group_ind * num_batch + image_ind
]
)
# perform batch prediction
prediction_result = get_prediction(
image=image_list[0],
detection_model=detection_model,
shift_amount=shift_amount_list[0],
full_shape=[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
)
# perform batch prediction
object_prediction_list.extend(prediction_result["object_prediction_list"])
else: # if zero_frac >= max_allowed_zeros_ratio from slice_image
prediction_result = get_prediction(
image=image_list[0],
image=image,
detection_model=detection_model,
shift_amount=shift_amount_list[0],
full_shape=[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
shift_amount=[0, 0],
full_shape=None,
merger=None,
matcher=None,
verbose=0,
)
object_prediction_list.extend(prediction_result["object_prediction_list"])

Expand Down

0 comments on commit 8ba6da2

Please sign in to comment.