Skip to content

Commit

Permalink
Style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Dec 5, 2023
1 parent 9ba7b6b commit bd84184
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
16 changes: 8 additions & 8 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,14 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
dataloader = self.predict_dataloader(ds)

results = predict._dataloader_wrapper_(model=self,
trainer=self.trainer,
annotations=df,
dataloader=dataloader,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
thickness=thickness)
trainer=self.trainer,
annotations=df,
dataloader=dataloader,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
thickness=thickness)

return results

Expand Down
16 changes: 8 additions & 8 deletions deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ def across_class_nms(predicted_boxes, iou_threshold=0.15):


def _dataloader_wrapper_(model,
trainer,
dataloader,
root_dir,
annotations,
nms_thresh,
savedir=None,
color=None,
thickness=1):
trainer,
dataloader,
root_dir,
annotations,
nms_thresh,
savedir=None,
color=None,
thickness=1):
"""Create a dataset and predict entire annotation file
Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
Expand Down
14 changes: 7 additions & 7 deletions deepforest/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def split_raster(annotations_file,
patch_overlap=0.05,
allow_empty=False,
image_name=None,
save_dir="."
):
save_dir="."):
"""Divide a large tile into smaller arrays. Each crop will be saved to
file.
Expand Down Expand Up @@ -190,15 +189,16 @@ def split_raster(annotations_file,
# Confirm that raster is H x W x C, if not, convert, assuming image is wider/taller than channels
if numpy_image.shape[0] < numpy_image.shape[-1]:
warnings.warn(
"Input rasterio had shape {}, assuming channels first. Converting to channels last".format(
numpy_image.shape), UserWarning)
"Input rasterio had shape {}, assuming channels first. Converting to channels last"
.format(numpy_image.shape), UserWarning)
numpy_image = np.moveaxis(numpy_image, 0, 2)

# Check that its 3 band
bands = numpy_image.shape[2]
if not bands == 3:
warnings.warn("Input rasterio had non-3 band shape of {}, ignoring "
"alpha channel".format(numpy_image.shape), UserWarning)
warnings.warn(
"Input rasterio had non-3 band shape of {}, ignoring "
"alpha channel".format(numpy_image.shape), UserWarning)
try:
numpy_image = numpy_image[:, :, :3].astype("uint8")
except:
Expand Down

0 comments on commit bd84184

Please sign in to comment.