Skip to content

Commit

Permalink
Merge pull request #4 from MalondaClement/test_fix_tunnel
Browse files — browse the repository at this point in the history
Fix tunnel dataset
  • Loading branch information
MalondaClement authored Nov 5, 2021
2 parents a5e9479 + 902612d commit a9be04f
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ You can read the documentation [here](https://github.com/MalondaClement/pipeline

## 2. How to use the pipeline

```python
```bash
git clone https://github.com/MalondaClement/pipeline.git
```

Expand Down
4 changes: 3 additions & 1 deletion datasets/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, root, split="train", labels_type="csv", transform=None, targe
assert split in ["train","val","test"], "Unknown value {} for argument split.".format(split)
if self.labels_type == "json":
self.target_dir = os.path.join(self.root, "jsons")
self.__read_csv()
self.__read_json()
elif self.labels_type == "csv":
self.target_dir = os.path.join(self.root, "csvs")
self.__read_csv()
Expand All @@ -62,6 +62,8 @@ def __getitem__(self, index):
draw.polygon(e[0], fill=classToVal[e[1]])
image = np.array(image)
image = image.transpose(2, 0, 1)
# int to float to fix training
image = image/255
target = np.array(target)[:, :, 0]
return image, target, filepath

Expand Down
8 changes: 5 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main():
args = ARGS("DeepLabV3_Resnet50", "tunnel", len(Dataset.validClasses), labels_type="csv", batch_size=2, epochs=2)

model, args = get_model(args)
args.save_path = "path"
args.save_path = "../"

checkpoint = torch.load(os.path.join(args.save_path, "best_weights.pth.tar"), map_location=torch.device('cpu'))

Expand All @@ -44,10 +44,12 @@ def main():
times = list()

start = time.time()
for i, file in enumerate(os.listdir("images_inf_path")):
for i, file in enumerate(os.listdir("/Users/ClementMalonda/Desktop/img_inf")):
if file[-4:] != ".png":
continue
start = time.time()

img = Image.open(os.path.join("images_inf_path",file))
img = Image.open(os.path.join("/Users/ClementMalonda/Desktop/img_inf",file))
img = np.array(img)
img = img[:,:,:3]
img = img/255
Expand Down
27 changes: 27 additions & 0 deletions learning/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ def train_epoch(dataloader, model, criterion, optimizer, lr_scheduler, epoch, va
# Iterate over data.
for epoch_step, (inputs, labels, _) in enumerate(dataloader):
data_time.update(time.time()-end)

#test
fig, (ax0, ax1, ax2) = plt.subplots(1, 3)
images = inputs.numpy()
image = images[0, :, :, :]
image = image.transpose(1, 2, 0)
ax0.imshow(image)
ax2.imshow(image)
ax0.set_title("Image d'origine")
print(type(inputs))
print(inputs.shape)
#end test

if args.copyblob:
for i in range(inputs.size()[0]):
rand_idx = np.random.randint(inputs.size()[0])
Expand Down Expand Up @@ -72,6 +85,20 @@ def train_epoch(dataloader, model, criterion, optimizer, lr_scheduler, epoch, va
if args.is_pytorch_model :
outputs = outputs['out'] #FIXME for DeepLab V3
preds = torch.argmax(outputs, 1)

#test
print(type(preds))
print(preds.shape)
pred = preds[0, :, :].cpu()
ax1.imshow(pred)
ax1.set_title("Prédiction")
ax2.imshow(pred, alpha=0.5)
ax2.set_title("Superposition de l'image avec la prédiction")
if not os.path.isdir(os.path.join(args.save_path, "inference")):
os.makedirs(os.path.join(args.save_path, "inference"))
fig.savefig(os.path.join(args.save_path, "inference", str(epoch)+".png"))
# end test

# cross-entropy loss
loss = criterion(outputs, labels)

Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from helpers.helpers import plot_learning_curves
from learning.learner import train_epoch, validate_epoch
from learning.utils import get_dataloader
# from datasets.tunnel import Tunnel
from datasets.minicity import MiniCity
from datasets.tunnel import Tunnel
# from datasets.minicity import MiniCity

def main():
# Get tunnel dataset
Dataset = MiniCity
Dataset = Tunnel

# Set up execution arguments
args = ARGS("DeepLabV3_MobileNetV3", "microcity", len(Dataset.validClasses), labels_type="csv", batch_size=2, epochs=40)
args = ARGS("DeepLabV3_Resnet50", "tunnel", len(Dataset.validClasses), labels_type="json", batch_size=4, epochs=100)

# Get model
model, args = get_model(args)
Expand Down

0 comments on commit a9be04f

Please sign in to comment.