diff --git a/README.md b/README.md
index 41a076b..2be2c79 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ You can read the documentation [here](https://github.com/MalondaClement/pipeline
 
 ## 2. How to use the pipeline
 
-```python
+```bash
 git clone https://github.com/MalondaClement/pipeline.git
 ```
 
diff --git a/datasets/tunnel.py b/datasets/tunnel.py
index 7f4c7c7..316658f 100644
--- a/datasets/tunnel.py
+++ b/datasets/tunnel.py
@@ -43,7 +43,7 @@ def __init__(self, root, split="train", labels_type="csv", transform=None, targe
         assert split in ["train","val","test"], "Unknown value {} for argument split.".format(split)
         if self.labels_type == "json":
             self.target_dir = os.path.join(self.root, "jsons")
-            self.__read_csv()
+            self.__read_json()
         elif self.labels_type == "csv":
             self.target_dir = os.path.join(self.root, "csvs")
             self.__read_csv()
@@ -62,6 +62,8 @@ def __getitem__(self, index):
             draw.polygon(e[0], fill=classToVal[e[1]])
         image = np.array(image)
         image = image.transpose(2, 0, 1)
+        # int to float to fix training
+        image = image/255
         target = np.array(target)[:, :, 0]
         return image, target, filepath
 
diff --git a/inference.py b/inference.py
index 7de8491..fcc59c3 100644
--- a/inference.py
+++ b/inference.py
@@ -24,7 +24,7 @@ def main():
     args = ARGS("DeepLabV3_Resnet50", "tunnel", len(Dataset.validClasses), labels_type="csv", batch_size=2, epochs=2)
 
     model, args = get_model(args)
-    args.save_path = "path"
+    args.save_path = "../"
 
     checkpoint = torch.load(os.path.join(args.save_path, "best_weights.pth.tar"), map_location=torch.device('cpu'))
 
@@ -44,10 +44,12 @@ def main():
     times = list()
 
     start = time.time()
-    for i, file in enumerate(os.listdir("images_inf_path")):
+    for i, file in enumerate(os.listdir("/Users/ClementMalonda/Desktop/img_inf")):
+        if file[-4:] != ".png":
+            continue
         start = time.time()
-        img = Image.open(os.path.join("images_inf_path",file))
+        img = Image.open(os.path.join("/Users/ClementMalonda/Desktop/img_inf",file))
         img = np.array(img)
         img = img[:,:,:3]
         img = img/255
diff --git a/learning/learner.py b/learning/learner.py
index 2fcba69..e5e3d12 100644
--- a/learning/learner.py
+++ b/learning/learner.py
@@ -41,6 +41,19 @@ def train_epoch(dataloader, model, criterion, optimizer, lr_scheduler, epoch, va
     # Iterate over data.
     for epoch_step, (inputs, labels, _) in enumerate(dataloader):
         data_time.update(time.time()-end)
+
+        #test
+        fig, (ax0, ax1, ax2) = plt.subplots(1, 3)
+        images = inputs.numpy()
+        image = images[0, :, :, :]
+        image = image.transpose(1, 2, 0)
+        ax0.imshow(image)
+        ax2.imshow(image)
+        ax0.set_title("Image d'origine")
+        print(type(inputs))
+        print(inputs.shape)
+        #end test
+
         if args.copyblob:
             for i in range(inputs.size()[0]):
                 rand_idx = np.random.randint(inputs.size()[0])
@@ -72,6 +85,20 @@ def train_epoch(dataloader, model, criterion, optimizer, lr_scheduler, epoch, va
         if args.is_pytorch_model :
             outputs = outputs['out'] #FIXME for DeepLab V3
         preds = torch.argmax(outputs, 1)
+
+        #test
+        print(type(preds))
+        print(preds.shape)
+        pred = preds[0, :, :].cpu()
+        ax1.imshow(pred)
+        ax1.set_title("Prédiction")
+        ax2.imshow(pred, alpha=0.5)
+        ax2.set_title("Superposition de l'image avec la prédiction")
+        if not os.path.isdir(os.path.join(args.save_path, "inference")):
+            os.makedirs(os.path.join(args.save_path, "inference"))
+        fig.savefig(os.path.join(args.save_path, "inference", str(epoch)+".png"))
+        # end test
+
         # cross-entropy loss
         loss = criterion(outputs, labels)
 
diff --git a/train.py b/train.py
index b77e84d..0b7a05f 100644
--- a/train.py
+++ b/train.py
@@ -14,15 +14,15 @@ from helpers.helpers import plot_learning_curves
 from learning.learner import train_epoch, validate_epoch
 from learning.utils import get_dataloader
 
-# from datasets.tunnel import Tunnel
-from datasets.minicity import MiniCity
+from datasets.tunnel import Tunnel
+# from datasets.minicity import MiniCity
 
 def main():
     # Get tunnel dataset
-    Dataset = MiniCity
+    Dataset = Tunnel
 
     # Set up execution arguments
-    args = ARGS("DeepLabV3_MobileNetV3", "microcity", len(Dataset.validClasses), labels_type="csv", batch_size=2, epochs=40)
+    args = ARGS("DeepLabV3_Resnet50", "tunnel", len(Dataset.validClasses), labels_type="json", batch_size=4, epochs=100)
 
     # Get model
     model, args = get_model(args)
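
For context, both `datasets/tunnel.py` and `inference.py` now apply the same image preprocessing before feeding a model: keep only the RGB channels, rearrange to channel-first layout, and rescale the uint8 pixels to floats in [0, 1] (the "int to float to fix training" change). The sketch below gathers those steps into one helper, assuming a plain PIL/NumPy pipeline; the function name `load_image_chw_float` is illustrative and does not exist in the repository.

```python
import numpy as np
from PIL import Image

def load_image_chw_float(path):
    """Minimal sketch of the preprocessing this patch adds in
    datasets/tunnel.py and inference.py."""
    img = np.array(Image.open(path))  # HWC array, dtype uint8
    img = img[:, :, :3]               # keep RGB, drop a possible alpha channel
    img = img.transpose(2, 0, 1)      # HWC -> CHW, the layout the models expect
    return img / 255                  # uint8 -> float in [0, 1]
```

In the patch itself, `Tunnel.__getitem__` divides by 255 after the transpose while `inference.py` divides before any layout change; the two operations are independent, so the order does not affect the result.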