Restore fixed batch size
GoldenAnpu committed Jan 9, 2025
1 parent 4c3a249 commit c929cf9
Showing 1 changed file with 26 additions and 46 deletions.
72 changes: 26 additions & 46 deletions train/src/main.py
@@ -2130,29 +2130,19 @@ def get_image_infos_by_split(split: list):
         if "/" in dataset_name:
             dataset_name = dataset_name.split("/")[-1]
         ds_info = ds_infos_dict[dataset_name]
-        batch_size = 500
-        min_batch_size = 50
-        while batch_size >= min_batch_size:
-            try:
-                for batched_names in sly.batched(image_names, batch_size):
-                    batch = api.image.get_list(
-                        ds_info.id,
-                        filters=[
-                            {
-                                "field": "name",
-                                "operator": "in",
-                                "value": batched_names,
-                            }
-                        ],
-                        force_metadata_for_links=False,
-                    )
-                    image_infos.extend(batch)
-                break
-            except Exception as e:
-                sly.logger.warning(f"Error occurred: {e}. Reducing batch size for filter to {batch_size // 2}")
-                batch_size //= 2
-                if batch_size < min_batch_size:
-                    raise RuntimeError(f"Batch size for filter in listing images reduced to {batch_size} and still not working. Aborting.")
+        for batched_names in sly.batched(image_names, 200):
+            batch = api.image.get_list(
+                ds_info.id,
+                filters=[
+                    {
+                        "field": "name",
+                        "operator": "in",
+                        "value": batched_names,
+                    }
+                ],
+                force_metadata_for_links=False,
+            )
+            image_infos.extend(batch)
     return image_infos
 
 val_image_infos = get_image_infos_by_split(val_set)
@@ -3188,29 +3178,19 @@ def get_image_infos_by_split(split: list):
         if "/" in dataset_name:
             dataset_name = dataset_name.split("/")[-1]
         ds_info = ds_infos_dict[dataset_name]
-        batch_size = 500
-        min_batch_size = 50
-        while batch_size >= min_batch_size:
-            try:
-                for batched_names in sly.batched(image_names, batch_size):
-                    batch = api.image.get_list(
-                        ds_info.id,
-                        filters=[
-                            {
-                                "field": "name",
-                                "operator": "in",
-                                "value": batched_names,
-                            }
-                        ],
-                        force_metadata_for_links=False,
-                    )
-                    image_infos.extend(batch)
-                break
-            except Exception as e:
-                sly.logger.warning(f"Error occurred: {e}. Reducing batch size for filter to {batch_size // 2}")
-                batch_size //= 2
-                if batch_size < min_batch_size:
-                    raise RuntimeError(f"Batch size for filter in listing images reduced to {batch_size} and still not working. Aborting.")
+        for batched_names in sly.batched(image_names, 200):
+            batch = api.image.get_list(
+                ds_info.id,
+                filters=[
+                    {
+                        "field": "name",
+                        "operator": "in",
+                        "value": batched_names,
+                    }
+                ],
+                force_metadata_for_links=False,
+            )
+            image_infos.extend(batch)
     return image_infos
 
 val_image_infos = get_image_infos_by_split(val_set)

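For reference, the restored fixed-batch listing pattern can be exercised in isolation roughly as below. This is a minimal sketch, not code from the repository: only sly.batched and api.image.get_list with a "name"/"in" filter and force_metadata_for_links=False are taken from the diff above, while the dataset id and image names are hypothetical placeholders and the API credentials are assumed to be in the environment.

# Minimal standalone sketch of the restored listing pattern (illustration only).
import supervisely as sly

# Assumes SERVER_ADDRESS and API_TOKEN are available in the environment.
api = sly.Api.from_env()

dataset_id = 12345                                        # hypothetical dataset id
image_names = [f"img_{i:04d}.jpg" for i in range(1000)]   # hypothetical names from a split

image_infos = []
# Fixed batches of 200 names per request, as in the commit above; this replaces the
# previous retry loop that halved the batch size whenever a request failed.
for batched_names in sly.batched(image_names, 200):
    batch = api.image.get_list(
        dataset_id,
        filters=[{"field": "name", "operator": "in", "value": batched_names}],
        force_metadata_for_links=False,  # do not force metadata retrieval for linked images
    )
    image_infos.extend(batch)

print(f"Fetched {len(image_infos)} image infos")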