Commit d2bd673 (parent: 81afac0)
Authored and committed by Prachi Garg on Nov 12, 2021
Showing 23 changed files with 4,916 additions and 3 deletions.
@@ -1,6 +1,63 @@
# Multi-Domain Incremental Learning for Semantic Segmentation

This is the PyTorch implementation of our work "Multi-Domain Incremental Learning for Semantic Segmentation", accepted at WACV 2022 (Algorithms Track).

Full paper: http://arxiv.org/abs/2110.12205



## Requirements

- Python 3.6
- PyTorch: install the PyTorch build for Python 3.6 with CUDA support (this code has been run with CUDA 10.2).
- Training can take up to 40 hours on 2 Nvidia GeForce GTX 1080 Ti GPUs for Step 2, and up to 90 hours on 4 Nvidia GeForce GTX 1080 Ti GPUs for Step 3.

## Datasets

- Cityscapes: https://www.cityscapes-dataset.com/
- BDD100k: https://www.bdd100k.com/
- IDD: download IDD Part 1 from https://idd.insaan.iiit.ac.in/

**Preprocessing IDD:** convert the polygon labels to segmentation masks (a quick sanity check for the generated masks is sketched below):

1. Clone [public-code](https://github.com/AutoNUE/public-code)
2. `export PYTHONPATH='../public-code/helpers/'`
3. `python ../public-code/preperation/createLabels.py --datadir <datadir> --id-type level3Id`
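The check below is a minimal sketch, not part of the repository: it assumes `<datadir>` follows the `gtFine/<subset>` layout and the `_labellevel3Ids.png` suffix that the `IDD` dataset class added in this commit expects, and simply counts the generated level-3 ID masks.

```python
import os

datadir = "<datadir>"  # same path that was passed to createLabels.py
for subset in ("train", "val"):
    root = os.path.join(datadir, "gtFine", subset)
    masks = [os.path.join(dp, f)
             for dp, _, fn in os.walk(root)
             for f in fn
             if f.endswith("_labellevel3Ids.png")]  # suffix expected by the IDD loader
    print(f"{subset}: {len(masks)} level-3 ID masks")
```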
## Launching the code

### Training
Training proceeds in incremental steps. The model at each step is initialized from the previous step's checkpoint, so Steps 2 and 3 depend on the checkpoints produced by the earlier steps (a driver sketch chaining the steps is shown after the commands below).

Sample commands for the incremental domain sequence **Cityscapes (CS) -> BDD -> IDD**:

_Step 1: Learn model on CS_ \
`python train_RAPFT_step1.py --savedir <savedir> --num-epochs 150 --batch-size 6 --state "trained_models/erfnet_encoder_pretrained.pth.tar" --num-classes 20 --current_task=0 --dataset='cityscapes'`

_Step 2: Learn CS model on BDD_ \
`python train_new_task_step2.py --savedir <savedir> --num-epochs 150 --model-name-suffix='ours-CS1-BDD2' --batch-size 6 --state <path_to_Step1_model> --dataset='BDD' --dataset_old='cityscapes' --num-classes 20 20 --current_task=1 --nb_tasks=2 --num-classes-old 20`

_Step 3: Learn CS|BDD model on IDD_ \
`python train_new_task_step3.py --savedir <savedir> --num-epochs 150 --model-name-suffix='OURS-CS1-BDD2-IDD3' --batch-size 6 --state <path_to_Step2_model> --dataset-new='IDD' --datasets 'cityscapes' 'BDD' 'IDD' --num-classes 20 20 27 --num-classes-old 20 20 --current_task=2 --nb_tasks=3 --lambdac=0.1`

Training commands for the fine-tuning model, the multi-task (joint, offline) model and the single-task (independent) models can be found in the bash scripts inside the `trainer_files` directory. Other ablation experiment files can be requested.

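For reference, a hypothetical driver that chains Steps 1 and 2 (Step 3 follows the same pattern). The `--state` argument of Step 2 is a placeholder: the README does not specify the checkpoint filename each step writes into `<savedir>`, so fill it in from your own run.

```python
import subprocess

SAVEDIR = "<savedir>"                  # placeholder, as in the commands above
STEP1_CKPT = "<path_to_Step1_model>"   # placeholder: checkpoint written by Step 1

def run(args):
    print(" ".join(args))
    subprocess.run(args, check=True)   # stop the chain if a step fails

# Step 1: learn the model on Cityscapes
run(["python", "train_RAPFT_step1.py", "--savedir", SAVEDIR,
     "--num-epochs", "150", "--batch-size", "6",
     "--state", "trained_models/erfnet_encoder_pretrained.pth.tar",
     "--num-classes", "20", "--current_task=0", "--dataset=cityscapes"])

# Step 2: adapt the Step-1 model to BDD, initializing from the Step-1 checkpoint
run(["python", "train_new_task_step2.py", "--savedir", SAVEDIR,
     "--num-epochs", "150", "--model-name-suffix=ours-CS1-BDD2",
     "--batch-size", "6", "--state", STEP1_CKPT,
     "--dataset=BDD", "--dataset_old=cityscapes",
     "--num-classes", "20", "20", "--current_task=1", "--nb_tasks=2",
     "--num-classes-old", "20"])
```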
### Pretrained Models
Coming soon.

#### Testing
#### Tensorboard use
#### t-SNE plots for segmentation

## Citation

```
@article{garg2021multi,
  title={Multi-Domain Incremental Learning for Semantic Segmentation},
  author={Garg, Prachi and Saluja, Rohit and Balasubramanian, Vineeth N and Arora, Chetan and Subramanian, Anbumani and Jawahar, CV},
  journal={arXiv preprint arXiv:2110.12205},
  year={2021}
}
```

## Acknowledgements
- Code was originally borrowed from the ERFNet PyTorch implementation: https://github.com/Eromera/erfnet_pytorch
- The implementation of the residual adapters was inspired by:
  - https://github.com/menelaoskanakis/RCM
  - https://github.com/srebuffi/residual_adapters
@@ -0,0 +1,256 @@
import numpy as np
import os

from PIL import Image

from torch.utils.data import Dataset

EXTENSIONS = ['.jpg', '.png']


def load_image(file):
    return Image.open(file)


def is_image(filename):
    return any(filename.endswith(ext) for ext in EXTENSIONS)


def is_label_city(filename):
    # Cityscapes ground-truth masks end with "_labelTrainIds.png"
    return filename.endswith("_labelTrainIds.png")


def is_label_IDD(filename):
    # IDD level-3 ID masks produced by the AutoNUE preprocessing script
    return filename.endswith("_labellevel3Ids.png")


def is_label_BDD(filename):
    # BDD100k segmentation masks end with "_train_id.png"
    return filename.endswith("_train_id.png")


def image_path(root, basename, extension):
    return os.path.join(root, f'{basename}{extension}')


def image_path_city(root, name):
    return os.path.join(root, f'{name}')


def image_basename(filename):
    return os.path.basename(os.path.splitext(filename)[0])


class VOC12(Dataset):

    def __init__(self, root, input_transform=None, target_transform=None):
        self.images_root = os.path.join(root, 'images')
        self.labels_root = os.path.join(root, 'labels')

        self.filenames = [image_basename(f)
                          for f in os.listdir(self.labels_root) if is_image(f)]
        self.filenames.sort()

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        filename = self.filenames[index]

        with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(image_path(self.labels_root, filename, '.png'), 'rb') as f:
            label = load_image(f).convert('P')

        if self.input_transform is not None:
            image = self.input_transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        return len(self.filenames)


||
class cityscapes(Dataset): | ||
|
||
def __init__(self, root, co_transform=None, subset='train'): | ||
self.images_root = os.path.join(root, 'leftImg8bit/') | ||
self.labels_root = os.path.join(root, 'gtFine/') | ||
|
||
self.images_root += subset | ||
self.labels_root += subset | ||
|
||
print(self.images_root) | ||
#self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)] | ||
self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk( | ||
os.path.expanduser(self.images_root)) for f in fn if is_image(f)] | ||
self.filenames.sort() | ||
|
||
#[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn] | ||
#self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if is_image(f)] | ||
self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk( | ||
os.path.expanduser(self.labels_root)) for f in fn if is_label_city(f)] | ||
self.filenamesGt.sort() | ||
|
||
self.co_transform = co_transform # ADDED THIS | ||
|
||
def __getitem__(self, index): | ||
filename = self.filenames[index] | ||
filenameGt = self.filenamesGt[index] | ||
|
||
with open(image_path_city(self.images_root, filename), 'rb') as f: | ||
image = load_image(f).convert('RGB') | ||
with open(image_path_city(self.labels_root, filenameGt), 'rb') as f: | ||
label = load_image(f).convert('P') | ||
|
||
if self.co_transform is not None: | ||
image, label = self.co_transform(image, label) | ||
|
||
return image, label | ||
|
||
def __len__(self): | ||
return len(self.filenames) | ||
|
||
# added | ||
|
||
|
||
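# Illustration only (not referenced elsewhere in this file): wrapping one of the
# dataset classes above in a PyTorch DataLoader. The root path and batch size are
# placeholders, and co_transform must convert the (PIL image, PIL label) pair to
# tensors so that the default collate function can batch them.
def example_cityscapes_loader(root, co_transform, batch_size=6):
    from torch.utils.data import DataLoader
    dataset = cityscapes(root, co_transform=co_transform, subset='train')
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

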
class IDD(Dataset):

    def __init__(self, root, co_transform=None, subset='train'):
        self.images_root = os.path.join(root, 'leftImg8bit/')
        self.labels_root = os.path.join(root, 'gtFine/')

        self.images_root += subset
        self.labels_root += subset

        print(self.images_root)
        self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(
            os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
        self.filenames.sort()

        self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(
            os.path.expanduser(self.labels_root)) for f in fn if is_label_IDD(f)]
        self.filenamesGt.sort()

        self.co_transform = co_transform  # joint transform applied to (image, label)

    def __getitem__(self, index):
        filename = self.filenames[index]
        filenameGt = self.filenamesGt[index]

        # image_path_city works for IDD as well, since the images already have a .png extension
        with open(image_path_city(self.images_root, filename), 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
            label = load_image(f).convert('P')

        if self.co_transform is not None:
            image, label = self.co_transform(image, label)

        return image, label

    def __len__(self):
        return len(self.filenames)


class IDD_union(Dataset):

    def __init__(self, root, co_transform=None, subset='train'):
        self.images_root = os.path.join(root, 'leftImg8bit/')
        self.labels_root = os.path.join(root, 'gtFine/')

        self.images_root += subset
        self.labels_root += subset

        print(self.images_root)

        # Remap IDD level3Ids (keys) to new label ids (values); 255 is kept as-is
        MAP_dict = {0: 0, 1: 19, 2: 1, 3: 20, 4: 11, 5: 12, 6: 17, 7: 18, 8: 21, 9: 13, 10: 14, 11: 15, 12: 22, 13: 23, 14: 3,
                    15: 4, 16: 24, 17: 25, 18: 7, 19: 6, 20: 5, 21: 26, 22: 2, 23: 27, 24: 8, 25: 10, 255: 255}
        self.k = np.array(list(MAP_dict.keys()))
        self.v = np.array(list(MAP_dict.values()))
        print('mapping keys:', self.k)
        print('mapped values: ', self.v)

        self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(
            os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
        self.filenames.sort()

        self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(
            os.path.expanduser(self.labels_root)) for f in fn if is_label_IDD(f)]
        self.filenamesGt.sort()

        self.co_transform = co_transform  # joint transform applied to (image, label)

    def __getitem__(self, index):
        filename = self.filenames[index]
        filenameGt = self.filenamesGt[index]

        # image_path_city works for IDD as well, since the images already have a .png extension
        with open(image_path_city(self.images_root, filename), 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
            label = load_image(f).convert('P')

        label = np.array(label)
        # Build a lookup table: position k holds the new value v, so a single
        # fancy-indexing pass relabels every pixel of the mask at once
        mapping_ar = np.zeros(self.k.max() + 1, dtype=self.v.dtype)
        mapping_ar[self.k] = self.v
        label = mapping_ar[label]  # map label to the new IDD label values
        label = Image.fromarray(np.uint8(label))

        if self.co_transform is not None:
            image, label = self.co_transform(image, label)

        return image, label

    def __len__(self):
        return len(self.filenames)


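# The label remapping in IDD_union.__getitem__ above is a NumPy lookup-table trick:
# lut[k] = v for every (k, v) pair, then lut[label] relabels all pixels in one
# vectorised step. A small self-contained version of the same idea (the mapping
# below is a made-up example, not the real IDD table):
def remap_label_array(label, mapping):
    keys = np.array(list(mapping.keys()))
    values = np.array(list(mapping.values()))
    lut = np.zeros(keys.max() + 1, dtype=values.dtype)
    lut[keys] = values
    return lut[label]

# remap_label_array(np.array([[0, 2], [2, 1]]), {0: 5, 1: 6, 2: 7})
# -> array([[5, 7], [7, 6]])

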
class BDD100k(Dataset):

    def __init__(self, root, co_transform=None, subset='train'):
        self.images_root = os.path.join(root, 'images/')
        self.labels_root = os.path.join(root, 'labels/')

        self.images_root += subset
        self.labels_root += subset

        print(self.images_root)
        # unlike the os.walk-based classes above, this lists a flat <subset> directory
        self.filenames = [f for f in os.listdir(self.images_root) if is_image(f)]
        self.filenames.sort()

        self.filenamesGt = [fn for fn in os.listdir(self.labels_root) if is_label_BDD(fn)]
        self.filenamesGt.sort()

        self.co_transform = co_transform  # joint transform applied to (image, label)

    def __getitem__(self, index):
        filename = self.filenames[index]
        filenameGt = self.filenamesGt[index]

        with open(image_path_city(self.images_root, filename), 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
            label = load_image(f).convert('P')

        if self.co_transform is not None:
            image, label = self.co_transform(image, label)

        return image, label

    def __len__(self):
        return len(self.filenames)
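

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original file). The environment
    # variable names are hypothetical: point them at local dataset roots to check
    # that images and ground-truth masks pair up one-to-one.
    roots = {
        'cityscapes': (cityscapes, os.environ.get('CITYSCAPES_ROOT')),
        'BDD': (BDD100k, os.environ.get('BDD_ROOT')),
        'IDD': (IDD, os.environ.get('IDD_ROOT')),
    }
    for name, (cls, root) in roots.items():
        if root is None:
            continue  # dataset not configured locally
        ds = cls(root, co_transform=None, subset='val')
        assert len(ds.filenames) == len(ds.filenamesGt), f'{name}: image/label count mismatch'
        image, label = ds[0]
        print(name, len(ds), image.size, label.size)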