
Commit

added training code
Prachi Garg authored and Prachi Garg committed Nov 12, 2021
1 parent 81afac0 commit d2bd673
Showing 23 changed files with 4,916 additions and 3 deletions.
Binary file added .DS_Store
63 changes: 60 additions & 3 deletions README.md
@@ -1,6 +1,63 @@
# Multi-Domain Incremental Learning for Semantic Segmentation
This is the PyTorch implementation of our work "Multi-Domain Incremental Learning for Semantic Segmentation", accepted at WACV 2022 (Algorithms Track).

Full paper: http://arxiv.org/abs/2110.12205

![image](final-main-diagram-wacv1.png)

## Requirements

- Python 3.6
- PyTorch: install a PyTorch build for Python 3.6 with CUDA support (this code has been run with CUDA 10.2).
- Training can take up to 40 hours on 2 Nvidia GeForce GTX 1080 Ti GPUs for Step 2, and up to 90 hours on 4 Nvidia GeForce GTX 1080 Ti GPUs for Step 3.

## Datasets

- Cityscapes: https://www.cityscapes-dataset.com/
- BDD100k: https://www.bdd100k.com/
- IDD: Download IDD Part 1 from https://idd.insaan.iiit.ac.in/

**Preprocessing IDD:** convert the polygon labels to segmentation masks (a small sanity-check sketch follows these steps):

1. Clone [public-code](https://github.com/AutoNUE/public-code)
2. `export PYTHONPATH='../public-code/helpers/'`
3. `python ../public-code/preperation/createLabels.py --datadir <datadir> --id-type level3Id`
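
After the conversion, the IDD loader in `dataset.py` matches masks ending in `_labellevel3Ids.png` under `gtFine/<subset>/`. The snippet below is a minimal sanity-check sketch; the root path is a placeholder and the directory layout is assumed from the loader's conventions.

```python
# Illustrative check: count the generated level3Id masks per subset.
import os

idd_root = '<datadir>'  # the same --datadir passed to createLabels.py
for subset in ('train', 'val'):
    labels_dir = os.path.join(idd_root, 'gtFine', subset)
    masks = [os.path.join(dp, f)
             for dp, _, fn in os.walk(labels_dir)
             for f in fn if f.endswith('_labellevel3Ids.png')]
    print(subset, len(masks), 'label masks found')
```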

## Launching the code

### Training
Training proceeds in incremental steps: the model at each step is initialized from the previous step's model, so training in Steps 2 and 3 depends on the checkpoints produced by the earlier steps (see the sketch below).
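
The following is a minimal sketch of this checkpoint chaining, not the repository's actual loading code; the checkpoint layout (`'state_dict'`) is an assumption based on ERFNet-style training scripts.

```python
# Sketch: initialize the next-step model from the previous step's checkpoint.
import torch

def init_from_previous_step(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint.get('state_dict', checkpoint)
    # strict=False copies all shared weights while newly added task-specific
    # parameters (e.g. domain-specific adapters and the new decoder head)
    # remain at their fresh initialization.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print('missing keys (new task parameters):', missing)
    print('unexpected keys:', unexpected)
    return model
```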

Sample commands for the incremental domain sequence **Cityscapes (CS) -> BDD -> IDD**:

_Step 1: Learn model on CS_ \
`python train_RAPFT_step1.py --savedir <savedir> --num-epochs 150 --batch-size 6 --state "trained_models/erfnet_encoder_pretrained.pth.tar" --num-classes 20 --current_task=0 --dataset='cityscapes'`

_Step 2: Learn CS model on BDD_ \
`python train_new_task_step2.py --savedir <savedir> --num-epochs 150 --model-name-suffix='ours-CS1-BDD2' --batch-size 6 --state <path_to_Step1_model> --dataset='BDD' --dataset_old='cityscapes' --num-classes 20 20 --current_task=1 --nb_tasks=2 --num-classes-old 20`

_Step 3: Learn CS|BDD model on IDD_ \
`python train_new_task_step3.py --savedir <savedir> --num-epochs 150 --model-name-suffix='OURS-CS1-BDD2-IDD3' --batch-size 6 --state <path_to_Step2_model> --dataset-new='IDD' --datasets 'cityscapes' 'BDD' 'IDD' --num-classes 20 20 27 --num-classes-old 20 20 --current_task=2 --nb_tasks=3 --lambdac=0.1`

Training commands for the fine-tuning model, the multi-task (joint, offline) model, and the single-task (independent) models can be found in the bash scripts inside the `trainer_files` directory. Files for other ablation experiments are available on request.

### Pretrained Models
Coming soon.
#### Testing
#### TensorBoard usage
#### t-SNE plots for segmentation

## Citation
```
@article{garg2021multi,
  title={Multi-Domain Incremental Learning for Semantic Segmentation},
  author={Garg, Prachi and Saluja, Rohit and Balasubramanian, Vineeth N and Arora, Chetan and Subramanian, Anbumani and Jawahar, CV},
  journal={arXiv preprint arXiv:2110.12205},
  year={2021}
}
```

## Acknowledgements
- Code was originally borrowed from the ERFNet PyTorch implementation: https://github.com/Eromera/erfnet_pytorch
- The implementation of the residual adapters was inspired by:
  - https://github.com/menelaoskanakis/RCM
  - https://github.com/srebuffi/residual_adapters
256 changes: 256 additions & 0 deletions dataset.py
@@ -0,0 +1,256 @@
import numpy as np
import os

from PIL import Image

from torch.utils.data import Dataset

EXTENSIONS = ['.jpg', '.png']


def load_image(file):
return Image.open(file)


def is_image(filename):
return any(filename.endswith(ext) for ext in EXTENSIONS)


def is_label_city(filename):
return filename.endswith("_labelTrainIds.png")


def is_label_IDD(filename):
return filename.endswith("_labellevel3Ids.png")


def is_label_BDD(filename):
return filename.endswith("_train_id.png")


def image_path(root, basename, extension):
return os.path.join(root, f'{basename}{extension}')


def image_path_city(root, name):
return os.path.join(root, f'{name}')


def image_basename(filename):
return os.path.basename(os.path.splitext(filename)[0])
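# Example (illustrative): image_basename('/data/aachen_000000_leftImg8bit.png')
# returns 'aachen_000000_leftImg8bit' -- directories and the extension are stripped.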


class VOC12(Dataset):

def __init__(self, root, input_transform=None, target_transform=None):
self.images_root = os.path.join(root, 'images')
self.labels_root = os.path.join(root, 'labels')

self.filenames = [image_basename(f)
for f in os.listdir(self.labels_root) if is_image(f)]
self.filenames.sort()

self.input_transform = input_transform
self.target_transform = target_transform

def __getitem__(self, index):
filename = self.filenames[index]

with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f:
image = load_image(f).convert('RGB')
with open(image_path(self.labels_root, filename, '.png'), 'rb') as f:
label = load_image(f).convert('P')

if self.input_transform is not None:
image = self.input_transform(image)
if self.target_transform is not None:
label = self.target_transform(label)

return image, label

def __len__(self):
return len(self.filenames)


class cityscapes(Dataset):

def __init__(self, root, co_transform=None, subset='train'):
self.images_root = os.path.join(root, 'leftImg8bit/')
self.labels_root = os.path.join(root, 'gtFine/')

self.images_root += subset
self.labels_root += subset

print(self.images_root)
#self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)]
self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(
os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
self.filenames.sort()

#[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn]
#self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if is_image(f)]
self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(
os.path.expanduser(self.labels_root)) for f in fn if is_label_city(f)]
self.filenamesGt.sort()

self.co_transform = co_transform # ADDED THIS

def __getitem__(self, index):
filename = self.filenames[index]
filenameGt = self.filenamesGt[index]

with open(image_path_city(self.images_root, filename), 'rb') as f:
image = load_image(f).convert('RGB')
with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
label = load_image(f).convert('P')

if self.co_transform is not None:
image, label = self.co_transform(image, label)

return image, label

def __len__(self):
return len(self.filenames)

# added


class IDD(Dataset):

def __init__(self, root, co_transform=None, subset='train'):
self.images_root = os.path.join(root, 'leftImg8bit/')
self.labels_root = os.path.join(root, 'gtFine/')

self.images_root += subset
self.labels_root += subset

print(self.images_root)
#self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)]
self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(
os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
self.filenames.sort()

#[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn]
#self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if is_image(f)]
self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(
os.path.expanduser(self.labels_root)) for f in fn if is_label_IDD(f)]
self.filenamesGt.sort()

self.co_transform = co_transform # ADDED THIS

def __getitem__(self, index):
filename = self.filenames[index]
filenameGt = self.filenamesGt[index]

# image_path_city will work for IDD also as the images already have a .png extension
with open(image_path_city(self.images_root, filename), 'rb') as f:
image = load_image(f).convert('RGB')
with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
label = load_image(f).convert('P')

if self.co_transform is not None:
image, label = self.co_transform(image, label)

return image, label

def __len__(self):
return len(self.filenames)

# added


class IDD_union(Dataset):

def __init__(self, root, co_transform=None, subset='train'):
self.images_root = os.path.join(root, 'leftImg8bit/')
self.labels_root = os.path.join(root, 'gtFine/')

self.images_root += subset
self.labels_root += subset

print(self.images_root)

MAP_dict = {0: 0, 1: 19, 2: 1, 3: 20, 4: 11, 5: 12, 6: 17, 7: 18, 8: 21, 9: 13, 10: 14, 11: 15, 12: 22, 13: 23, 14: 3,
15: 4, 16: 24, 17: 25, 18: 7, 19: 6, 20: 5, 21: 26, 22: 2, 23: 27, 24: 8, 25: 10, 255: 255}
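        # MAP_dict appears to remap IDD level3Ids into a label space shared with
        # the other domains for the union / joint-training setting; 255 remains
        # the ignore index. Below it is turned into a numpy lookup table that is
        # applied to every label in __getitem__.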
self.k = np.array(list(MAP_dict.keys()))
self.v = np.array(list(MAP_dict.values()))
print('mapping keys:', self.k)
print('mapped values: ', self.v)

#self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)]
self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(
os.path.expanduser(self.images_root)) for f in fn if is_image(f)]
self.filenames.sort()

#[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn]
#self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if is_image(f)]
self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(
os.path.expanduser(self.labels_root)) for f in fn if is_label_IDD(f)]
self.filenamesGt.sort()

self.co_transform = co_transform # ADDED THIS

def __getitem__(self, index):
filename = self.filenames[index]
filenameGt = self.filenamesGt[index]

# image_path_city will work for IDD also as the images already have a .png extension
with open(image_path_city(self.images_root, filename), 'rb') as f:
image = load_image(f).convert('RGB')
with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
label = load_image(f).convert('P')

label = np.array(label)
mapping_ar = np.zeros(self.k.max()+1, dtype=self.v.dtype)
        mapping_ar[self.k] = self.v  # lookup table: index with an original id (key) to get its mapped id (value)
label = mapping_ar[label] # map label to new values of IDD label
label = Image.fromarray(np.uint8(label))

if self.co_transform is not None:
image, label = self.co_transform(image, label)

return image, label

def __len__(self):
return len(self.filenames)

# added


class BDD100k(Dataset):

def __init__(self, root, co_transform=None, subset='train'):
self.images_root = os.path.join(root, 'images/')
self.labels_root = os.path.join(root, 'labels/')

self.images_root += subset
self.labels_root += subset

print(self.images_root)
#self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)]
self.filenames = [f for f in os.listdir(self.images_root) if is_image(f)]
self.filenames.sort()

#[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn]
#self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if is_image(f)]
self.filenamesGt = [fn for fn in os.listdir(self.labels_root) if is_label_BDD(fn)]
self.filenamesGt.sort()

self.co_transform = co_transform # ADDED THIS

def __getitem__(self, index):
filename = self.filenames[index]
filenameGt = self.filenamesGt[index]

with open(image_path_city(self.images_root, filename), 'rb') as f:
image = load_image(f).convert('RGB')
with open(image_path_city(self.labels_root, filenameGt), 'rb') as f:
label = load_image(f).convert('P')

if self.co_transform is not None:
image, label = self.co_transform(image, label)

return image, label

def __len__(self):
return len(self.filenames)
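

# Usage sketch (illustrative only): the dataset classes above expect a
# co_transform callable that jointly transforms an (image, label) pair.
# The placeholder transform and path below are assumptions; the actual
# training scripts define their own augmentation and resizing transforms.
if __name__ == '__main__':
    from torchvision.transforms import ToTensor

    def simple_co_transform(image, label):
        # Resize image and label jointly; nearest-neighbour keeps label ids intact.
        image = image.resize((1024, 512), Image.BILINEAR)
        label = label.resize((1024, 512), Image.NEAREST)
        return ToTensor()(image), label

    dataset = cityscapes('<path_to_cityscapes_root>',
                         co_transform=simple_co_transform, subset='val')
    print('validation samples:', len(dataset))
    image, label = dataset[0]
    print('image tensor shape:', tuple(image.shape), '| label size:', label.size)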
Binary file added final-main-diagram-wacv1.png
