Skip to content

Commit

Permalink
fix inference
Browse files Browse the repository at this point in the history
  • Loading branch information
MalondaClement committed Oct 28, 2021
1 parent e336c17 commit 902612d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ You can read the documentation [here](https://github.com/MalondaClement/pipeline

## 2. How to use the pipeline

```python
```bash
git clone https://github.com/MalondaClement/pipeline.git
```

Expand Down
10 changes: 7 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main():
args = ARGS("DeepLabV3_Resnet50", "tunnel", len(Dataset.validClasses), labels_type="csv", batch_size=2, epochs=2)

model, args = get_model(args)
args.save_path = "path"
args.save_path = "../"

checkpoint = torch.load(os.path.join(args.save_path, "best_weights.pth.tar"), map_location=torch.device('cpu'))

Expand All @@ -44,15 +44,19 @@ def main():
times = list()

start = time.time()
for i, file in enumerate(os.listdir("images_inf_path")):
for i, file in enumerate(os.listdir("/Users/ClementMalonda/Desktop/img_inf")):
if file[-4:] != ".png":
continue
start = time.time()

img = Image.open(os.path.join("images_inf_path",file))
img = Image.open(os.path.join("/Users/ClementMalonda/Desktop/img_inf",file))
img = np.array(img)
img = img[:,:,:3]
img = img/255

input = ToTensor()(img)
input = input.unsqueeze(0)
input = input.float()
with torch.no_grad():
output = model(input)
if args.is_pytorch_model:
Expand Down

0 comments on commit 902612d

Please sign in to comment.