diff --git a/src/main/java/org/ilastik/ilastik4ij/executors/PixelClassification.java b/src/main/java/org/ilastik/ilastik4ij/executors/PixelClassification.java index 252f6b46..a592fbed 100644 --- a/src/main/java/org/ilastik/ilastik4ij/executors/PixelClassification.java +++ b/src/main/java/org/ilastik/ilastik4ij/executors/PixelClassification.java @@ -20,8 +20,9 @@ public PixelClassification(File executableFilePath, File projectFileName, LogSer } public > ImgPlus classifyPixels(ImgPlus> rawInputImg, + ImgPlus> predictionMask, PixelPredictionType pixelPredictionType) throws IOException { - return executeIlastik(rawInputImg, null, pixelPredictionType); + return executeIlastik(rawInputImg, predictionMask, pixelPredictionType); } @Override @@ -36,6 +37,10 @@ else if (pixelPredictionType == PixelPredictionType.Probabilities) { commandLine.add("--raw_data=" + tempFiles.get(rawInputTempFile)); commandLine.add("--output_filename_format=" + tempFiles.get(outputTempFile)); + if(tempFiles.containsKey(secondInputTempFile)){ + commandLine.add("--prediction_mask=" + tempFiles.get(secondInputTempFile)); + } + return commandLine; } } diff --git a/src/main/java/org/ilastik/ilastik4ij/ui/IlastikPixelClassificationCommand.java b/src/main/java/org/ilastik/ilastik4ij/ui/IlastikPixelClassificationCommand.java index 02943d57..556ce0c0 100644 --- a/src/main/java/org/ilastik/ilastik4ij/ui/IlastikPixelClassificationCommand.java +++ b/src/main/java/org/ilastik/ilastik4ij/ui/IlastikPixelClassificationCommand.java @@ -7,6 +7,7 @@ import org.scijava.ItemIO; import org.scijava.app.StatusService; import org.scijava.command.Command; +import org.scijava.command.DynamicCommand; import org.scijava.log.LogService; import org.scijava.options.OptionsService; import org.scijava.plugin.Parameter; @@ -19,7 +20,7 @@ import static org.ilastik.ilastik4ij.executors.AbstractIlastikExecutor.PixelPredictionType; @Plugin(type = Command.class, headless = true, menuPath = "Plugins>ilastik>Run Pixel Classification Prediction") -public class IlastikPixelClassificationCommand implements Command { +public class IlastikPixelClassificationCommand extends DynamicCommand { @Parameter public LogService logService; @@ -42,6 +43,21 @@ public class IlastikPixelClassificationCommand implements Command { @Parameter(label = "Output type", choices = {UiConstants.PIXEL_PREDICTION_TYPE_PROBABILITIES, UiConstants.PIXEL_PREDICTION_TYPE_SEGMENTATION}, style = "radioButtonHorizontal") public String pixelClassificationType; + @Parameter(label = "Use Mask?", persist=false, initializer = "initUseMask") + public boolean useMask=false; + + protected void initUseMask(){ + useMask = false; + //resolveInput("useMask"); //this makes the input not be rendered -.- + } + + @Parameter( + label = "Prediction Mask", + required = false, + description = "An image with same dimensions as Raw Data, where the black pixels will be masked out of the predictions" + ) + public Dataset predictionMask; + @Parameter(type = ItemIO.OUTPUT) private ImgPlus> predictions; @@ -69,7 +85,11 @@ private void runClassification() throws IOException { projectFileName, logService, statusService, ilastikOptions.getNumThreads(), ilastikOptions.getMaxRamMb()); PixelPredictionType pixelPredictionType = PixelPredictionType.valueOf(pixelClassificationType); - this.predictions = pixelClassification.classifyPixels(inputImage.getImgPlus(), pixelPredictionType); + this.predictions = pixelClassification.classifyPixels( + inputImage.getImgPlus(), + useMask ? predictionMask.getImgPlus() : null, + pixelPredictionType + ); // DisplayUtils.showOutput(uiService, predictions); } diff --git a/src/test/java/org/ilastik/ilastik4ij/PixelClassificationDemo.java b/src/test/java/org/ilastik/ilastik4ij/PixelClassificationDemo.java index f742591f..e8e99462 100644 --- a/src/test/java/org/ilastik/ilastik4ij/PixelClassificationDemo.java +++ b/src/test/java/org/ilastik/ilastik4ij/PixelClassificationDemo.java @@ -62,7 +62,7 @@ public static & NativeType> void main(String[] args) t 1024 ); - final ImgPlus classifiedPixels = prediction.classifyPixels(inputDataset.getImgPlus(), PixelPredictionType.Probabilities); + final ImgPlus classifiedPixels = prediction.classifyPixels(inputDataset.getImgPlus(), null, PixelPredictionType.Probabilities); ImageJFunctions.show(classifiedPixels, "Probability maps"); }