Skip to content
This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Commit

Permalink
Merge pull request #70 from dstansby/use-default-model
Browse files Browse the repository at this point in the history
Add checkbox to use pre-trained weights
  • Loading branch information
dstansby authored Mar 8, 2022
2 parents fee16f4 + 3ae901b commit f1e1baf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
10 changes: 8 additions & 2 deletions cellfinder_napari/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def widget(
max_cluster_size: int,
classification_options,
trained_model: Optional[Path],
use_pre_trained_weights: bool,
misc_options,
start_plane: int,
end_plane: int,
Expand Down Expand Up @@ -104,6 +105,8 @@ def widget(
max_cluster_size : int
Largest putative cell cluster (in cubic um) where splitting
should be attempted
use_pre_trained_weights : bool
Select to use pre-trained model weights
trained_model : Optional[Path]
Trained model file path (home directory (default) -> pretrained weights)
start_plane : int
Expand Down Expand Up @@ -138,8 +141,11 @@ def widget(
max_cluster_size,
)

trained_model = None if trained_model == Path.home() else trained_model
classification_inputs = ClassificationInputs(trained_model)
if use_pre_trained_weights:
trained_model = None
classification_inputs = ClassificationInputs(
use_pre_trained_weights, trained_model
)

end_plane = len(signal_image.data) if end_plane == 0 else end_plane

Expand Down
8 changes: 7 additions & 1 deletion cellfinder_napari/input_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,21 @@ def widget_representation(cls) -> dict:
class ClassificationInputs(InputContainer):
"""Container for classification inputs."""

use_pre_trained_weights: bool = True
trained_model: Optional[Path] = Path.home()

def as_core_arguments(self) -> dict:
return super().as_core_arguments()
args = super().as_core_arguments()
del args["use_pre_trained_weights"]
return args

@classmethod
def widget_representation(cls) -> dict:
return dict(
classification_options=html_label_widget("Classification:"),
use_pre_trained_weights=dict(
value=cls.defaults()["use_pre_trained_weights"]
),
trained_model=dict(value=cls.defaults()["trained_model"]),
)

Expand Down

0 comments on commit f1e1baf

Please sign in to comment.