diff --git a/examples/lifelong_learning/RFNet/accuracy.py b/examples/lifelong_learning/RFNet/accuracy.py new file mode 100644 index 000000000..8d356fed1 --- /dev/null +++ b/examples/lifelong_learning/RFNet/accuracy.py @@ -0,0 +1,38 @@ +from basemodel import val_args +from utils.metrics import Evaluator +from tqdm import tqdm +from dataloaders import make_data_loader +from sedna.common.class_factory import ClassType, ClassFactory + +__all__ = ('accuracy') + +@ClassFactory.register(ClassType.GENERAL) +def accuracy(y_true, y_pred, **kwargs): + args = val_args() + _, _, test_loader, num_class = make_data_loader(args, test_data=y_true) + evaluator = Evaluator(num_class) + + tbar = tqdm(test_loader, desc='\r') + for i, (sample, img_path) in enumerate(tbar): + if args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + if args.cuda: + image, target = image.cuda(), target.cuda() + if args.depth: + depth = depth.cuda() + + target[target > evaluator.num_class-1] = 255 + target = target.cpu().numpy() + # Add batch sample into evaluator + evaluator.add_batch(target, y_pred[i]) + + # Test during the training + # Acc = evaluator.Pixel_Accuracy() + CPA = evaluator.Pixel_Accuracy_Class() + mIoU = evaluator.Mean_Intersection_over_Union() + FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() + + print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU)) + return CPA diff --git a/examples/lifelong_learning/RFNet/basemodel.py b/examples/lifelong_learning/RFNet/basemodel.py new file mode 100644 index 000000000..dba4cfdf2 --- /dev/null +++ b/examples/lifelong_learning/RFNet/basemodel.py @@ -0,0 +1,315 @@ +import os +import numpy as np +import torch +from PIL import Image +import argparse +from train import Trainer +from eval import Validator +from tqdm import tqdm +from eval import load_my_state_dict +from utils.metrics import Evaluator +from dataloaders import make_data_loader +from dataloaders import custom_transforms as tr +from torchvision import transforms +from sedna.common.class_factory import ClassType, ClassFactory +from sedna.common.config import Context +from sedna.datasources import TxtDataParse +from torch.utils.data import DataLoader +from sedna.common.file_ops import FileOps +from utils.lr_scheduler import LR_Scheduler + +def preprocess(image_urls): + transformed_images = [] + for paths in image_urls: + if len(paths) == 2: + img_path, depth_path = paths + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(depth_path) + else: + img_path = paths[0] + _img = Image.open(img_path).convert('RGB') + _depth = _img + + sample = {'image': _img, 'depth': _depth, 'label': _img} + composed_transforms = transforms.Compose([ + # tr.CropBlackArea(), + # tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + transformed_images.append((composed_transforms(sample), img_path)) + + return transformed_images + +class Model: + def __init__(self, **kwargs): + self.val_args = val_args() + self.train_args = train_args() + + self.train_args.lr = kwargs.get("learning_rate", 1e-4) + self.train_args.epochs = kwargs.get("epochs", 2) + self.train_args.eval_interval = kwargs.get("eval_interval", 2) + self.train_args.no_val = kwargs.get("no_val", True) + # self.train_args.resume = Context.get_parameters("PRETRAINED_MODEL_URL", None) + self.trainer = None + + label_save_dir = Context.get_parameters("INFERENCE_RESULT_DIR", 
"./inference_results") + self.val_args.color_label_save_path = os.path.join(label_save_dir, "color") + self.val_args.merge_label_save_path = os.path.join(label_save_dir, "merge") + self.val_args.label_save_path = os.path.join(label_save_dir, "label") + self.validator = Validator(self.val_args) + + def train(self, train_data, valid_data=None, **kwargs): + self.trainer = Trainer(self.train_args, train_data=train_data) + print("Total epoches:", self.trainer.args.epochs) + for epoch in range(self.trainer.args.start_epoch, self.trainer.args.epochs): + if epoch == 0 and self.trainer.val_loader: + self.trainer.validation(epoch) + self.trainer.training(epoch) + + if self.trainer.args.no_val and \ + (epoch % self.trainer.args.eval_interval == (self.trainer.args.eval_interval - 1) + or epoch == self.trainer.args.epochs - 1): + # save checkpoint when it meets eval_interval or the training finished + is_best = False + checkpoint_path = self.trainer.saver.save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': self.trainer.model.state_dict(), + 'optimizer': self.trainer.optimizer.state_dict(), + 'best_pred': self.trainer.best_pred, + }, is_best) + + # if not self.trainer.args.no_val and \ + # epoch % self.train_args.eval_interval == (self.train_args.eval_interval - 1) \ + # and self.trainer.val_loader: + # self.trainer.validation(epoch) + + self.trainer.writer.close() + + return checkpoint_path + + def predict(self, data, **kwargs): + if not isinstance(data[0][0], dict): + data = preprocess(data) + + if type(data) is np.ndarray: + data = data.tolist() + + self.validator.test_loader = DataLoader(data, batch_size=self.val_args.test_batch_size, shuffle=False, + pin_memory=True) + return self.validator.validate() + + def evaluate(self, data, **kwargs): + self.val_args.save_predicted_image = kwargs.get("save_predicted_image", True) + samples = preprocess(data.x) + predictions = self.predict(samples) + return accuracy(data.y, predictions) + + def load(self, model_url, **kwargs): + if model_url: + self.validator.new_state_dict = torch.load(model_url, map_location=torch.device("cpu")) + self.train_args.resume = model_url + else: + raise Exception("model url does not exist.") + self.validator.model = load_my_state_dict(self.validator.model, self.validator.new_state_dict['state_dict']) + + def save(self, model_path=None): + # TODO: how to save unstructured data model + pass + +def train_args(): + parser = argparse.ArgumentParser(description="PyTorch RFNet Training") + parser.add_argument('--depth', action="store_true", default=False, + help='training with depth image or not (default: False)') + parser.add_argument('--dataset', type=str, default='cityscapes', + choices=['citylostfound', 'cityscapes', 'cityrand', 'target', 'xrlab', 'e1', 'mapillary'], + help='dataset name (default: cityscapes)') + parser.add_argument('--workers', type=int, default=4, + metavar='N', help='dataloader threads') + parser.add_argument('--base-size', type=int, default=1024, + help='base image size') + parser.add_argument('--crop-size', type=int, default=768, + help='crop image size') + parser.add_argument('--loss-type', type=str, default='ce', + choices=['ce', 'focal'], + help='loss func type (default: ce)') + # training hyper params + # parser.add_argument('--epochs', type=int, default=None, metavar='N', + # help='number of epochs to train (default: auto)') + parser.add_argument('--epochs', type=int, default=None, metavar='N', + help='number of epochs to train (default: auto)') + parser.add_argument('--start_epoch', type=int, 
default=0, + metavar='N', help='start epochs (default:0)') + parser.add_argument('--batch-size', type=int, default=None, + metavar='N', help='input batch size for \ + training (default: auto)') + parser.add_argument('--val-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--test-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--use-balanced-weights', action='store_true', default=False, + help='whether to use balanced weights (default: True)') + parser.add_argument('--num-class', type=int, default=24, + help='number of training classes (default: 24') + # optimizer params + parser.add_argument('--lr', type=float, default=1e-4, metavar='LR', + help='learning rate (default: auto)') + parser.add_argument('--lr-scheduler', type=str, default='cos', + choices=['poly', 'step', 'cos', 'inv'], + help='lr scheduler mode: (default: cos)') + parser.add_argument('--momentum', type=float, default=0.9, + metavar='M', help='momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=2.5e-5, + metavar='M', help='w-decay (default: 5e-4)') + # cuda, seed and logging + parser.add_argument('--no-cuda', action='store_true', default= + False, help='disables CUDA training') + parser.add_argument('--gpu-ids', type=str, default='0', + help='use which gpu to train, must be a \ + comma-separated list of integers only (default=0)') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + # checking point + parser.add_argument('--resume', type=str, + default=None, + help='put the path to resuming file if needed') + parser.add_argument('--checkname', type=str, default=None, + help='set the checkpoint name') + # finetuning pre-trained models + parser.add_argument('--ft', action='store_true', default=True, + help='finetuning on a different dataset') + # evaluation option + parser.add_argument('--eval-interval', type=int, default=1, + help='evaluation interval (default: 1)') + parser.add_argument('--no-val', action='store_true', default=False, + help='skip validation during training') + + args = parser.parse_args() + args.cuda = not args.no_cuda and torch.cuda.is_available() + print(torch.cuda.is_available()) + if args.cuda: + try: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] + except ValueError: + raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') + + if args.epochs is None: + epoches = { + 'cityscapes': 200, + 'citylostfound': 200, + } + args.epochs = epoches[args.dataset.lower()] + + if args.batch_size is None: + args.batch_size = 4 * len(args.gpu_ids) + + if args.test_batch_size is None: + args.test_batch_size = args.batch_size + + if args.lr is None: + lrs = { + 'cityscapes': 0.0001, + 'citylostfound': 0.0001, + 'cityrand': 0.0001 + } + args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size + + if args.checkname is None: + args.checkname = 'RFNet' + print(args) + torch.manual_seed(args.seed) + + return args + +def val_args(): + parser = argparse.ArgumentParser(description="PyTorch RFNet validation") + parser.add_argument('--dataset', type=str, default='cityscapes', + choices=['citylostfound', 'cityscapes', 'xrlab', 'mapillary'], + help='dataset name (default: cityscapes)') + parser.add_argument('--workers', type=int, default=4, + metavar='N', help='dataloader threads') + parser.add_argument('--base-size', type=int, 
default=1024, + help='base image size') + parser.add_argument('--crop-size', type=int, default=768, + help='crop image size') + parser.add_argument('--batch-size', type=int, default=6, + help='batch size for training') + parser.add_argument('--val-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + validating (default: auto)') + parser.add_argument('--test-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--num-class', type=int, default=24, + help='number of training classes (default: 24') + parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') + parser.add_argument('--gpu-ids', type=str, default='0', + help='use which gpu to train, must be a \ + comma-separated list of integers only (default=0)') + parser.add_argument('--checkname', type=str, default=None, + help='set the checkpoint name') + parser.add_argument('--weight-path', type=str, default="./models/530_exp3_2.pth", + help='enter your path of the weight') + parser.add_argument('--save-predicted-image', action='store_true', default=False, + help='save predicted images') + parser.add_argument('--color-label-save-path', type=str, + default='./test/color/', + help='path to save label') + parser.add_argument('--merge-label-save-path', type=str, + default='./test/merge/', + help='path to save merged label') + parser.add_argument('--label-save-path', type=str, default='./test/label/', + help='path to save merged label') + parser.add_argument('--merge', action='store_true', default=True, help='merge image and label') + parser.add_argument('--depth', action='store_true', default=False, help='add depth image or not') + + args = parser.parse_args() + args.cuda = not args.no_cuda and torch.cuda.is_available() + if args.cuda: + try: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] + except ValueError: + raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') + + return args + +def accuracy(y_true, y_pred, **kwargs): + args = val_args() + _, _, test_loader, num_class = make_data_loader(args, test_data=y_true) + evaluator = Evaluator(num_class) + + tbar = tqdm(test_loader, desc='\r') + for i, (sample, img_path) in enumerate(tbar): + if args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + if args.cuda: + image, target = image.cuda(), target.cuda() + if args.depth: + depth = depth.cuda() + + target[target > evaluator.num_class-1] = 255 + target = target.cpu().numpy() + # Add batch sample into evaluator + evaluator.add_batch(target, y_pred[i]) + + # Test during the training + # Acc = evaluator.Pixel_Accuracy() + CPA = evaluator.Pixel_Accuracy_Class() + mIoU = evaluator.Mean_Intersection_over_Union() + FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() + + print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU)) + return CPA + +if __name__ == '__main__': + model_path = "/tmp/RFNet/" + if not os.path.exists(model_path): + os.makedirs(model_path) + + p1 = Process(target=exp_train, args=(10,)) + p1.start() + p1.join() diff --git a/examples/lifelong_learning/RFNet/dataloaders/__init__.py b/examples/lifelong_learning/RFNet/dataloaders/__init__.py new file mode 100644 index 000000000..2c71f7b10 --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/__init__.py @@ -0,0 +1,119 @@ +from dataloaders.datasets import cityscapes, citylostfound, cityrand, 
target, xrlab, e1, mapillary +from torch.utils.data import DataLoader + +def make_data_loader(args, train_data=None, valid_data=None, test_data=None, **kwargs): + + if args.dataset == 'cityscapes': + if train_data is not None: + train_set = cityscapes.CityscapesSegmentation(args, data=train_data, split='train') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + else: + train_loader, num_class = None, cityscapes.CityscapesSegmentation.NUM_CLASSES + + if valid_data is not None: + val_set = cityscapes.CityscapesSegmentation(args, data=valid_data, split='val') + num_class = val_set.NUM_CLASSES + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + else: + val_loader, num_class = None, cityscapes.CityscapesSegmentation.NUM_CLASSES + + if test_data is not None: + test_set = cityscapes.CityscapesSegmentation(args, data=test_data, split='test') + num_class = test_set.NUM_CLASSES + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + else: + test_loader, num_class = None, cityscapes.CityscapesSegmentation.NUM_CLASSES + + # custom_set = cityscapes.CityscapesSegmentation(args, split='custom_resize') + # custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + # return train_loader, val_loader, test_loader, custom_loader, num_class + return train_loader, val_loader, test_loader, num_class + + if args.dataset == 'citylostfound': + if args.depth: + train_set = citylostfound.CitylostfoundSegmentation(args, split='train') + val_set = citylostfound.CitylostfoundSegmentation(args, split='val') + test_set = citylostfound.CitylostfoundSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + else: + train_set = citylostfound.CitylostfoundSegmentation_rgb(args, split='train') + val_set = citylostfound.CitylostfoundSegmentation_rgb(args, split='val') + test_set = citylostfound.CitylostfoundSegmentation_rgb(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, num_class + if args.dataset == 'cityrand': + train_set = cityrand.CityscapesSegmentation(args, split='train') + val_set = cityrand.CityscapesSegmentation(args, split='val') + test_set = cityrand.CityscapesSegmentation(args, split='test') + custom_set = cityrand.CityscapesSegmentation(args, split='custom_resize') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 'target': + train_set = 
target.CityscapesSegmentation(args, split='train') + val_set = target.CityscapesSegmentation(args, split='val') + test_set = target.CityscapesSegmentation(args, split='test') + custom_set = target.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 'xrlab': + train_set = xrlab.CityscapesSegmentation(args, split='train') + val_set = xrlab.CityscapesSegmentation(args, split='val') + test_set = xrlab.CityscapesSegmentation(args, split='test') + custom_set = xrlab.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 'e1': + train_set = e1.CityscapesSegmentation(args, split='train') + val_set = e1.CityscapesSegmentation(args, split='val') + test_set = e1.CityscapesSegmentation(args, split='test') + custom_set = e1.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 'mapillary': + train_set = mapillary.CityscapesSegmentation(args, split='train') + val_set = mapillary.CityscapesSegmentation(args, split='val') + test_set = mapillary.CityscapesSegmentation(args, split='test') + custom_set = mapillary.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + else: + raise NotImplementedError + diff --git a/examples/lifelong_learning/RFNet/dataloaders/custom_transforms.py b/examples/lifelong_learning/RFNet/dataloaders/custom_transforms.py new file mode 100644 index 000000000..d63f200a0 --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/custom_transforms.py @@ -0,0 +1,240 @@ +import torch +import random +import numpy as np + +from PIL import Image, ImageOps, ImageFilter + +class Normalize(object): + """Normalize a tensor image with mean and 
standard deviation. + Args: + mean (tuple): means for each channel. + std (tuple): standard deviations for each channel. + """ + def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + img = np.array(img).astype(np.float32) + depth = np.array(depth).astype(np.float32) + mask = np.array(mask).astype(np.float32) + img /= 255.0 + img -= self.mean + img /= self.std + + # mean and std for original depth images + mean_depth = 0.12176 + std_depth = 0.09752 + + depth /= 255.0 + depth -= mean_depth + depth /= std_depth + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class ToTensor(object): + """Convert Image object in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + img = np.array(img).astype(np.float32).transpose((2, 0, 1)) + depth = np.array(depth).astype(np.float32) + mask = np.array(mask).astype(np.float32) + + img = torch.from_numpy(img).float() + depth = torch.from_numpy(depth).float() + mask = torch.from_numpy(mask).float() + + return {'image': img, + 'depth': depth, + 'label': mask} + +class CropBlackArea(object): + """ + crop black area for depth image + """ + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + width, height = img.size + left = 140 + top = 30 + right = 2030 + bottom = 900 + # crop + img = img.crop((left, top, right, bottom)) + depth = depth.crop((left, top, right, bottom)) + mask = mask.crop((left, top, right, bottom)) + # resize + img = img.resize((width,height), Image.BILINEAR) + depth = depth.resize((width,height), Image.BILINEAR) + mask = mask.resize((width,height), Image.NEAREST) + # img = img.resize((512,1024), Image.BILINEAR) + # depth = depth.resize((512,1024), Image.BILINEAR) + # mask = mask.resize((512,1024), Image.NEAREST) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomHorizontalFlip(object): + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + depth = depth.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomRotate(object): + def __init__(self, degree): + self.degree = degree + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + rotate_degree = random.uniform(-1*self.degree, self.degree) + img = img.rotate(rotate_degree, Image.BILINEAR) + depth = depth.rotate(rotate_degree, Image.BILINEAR) + mask = mask.rotate(rotate_degree, Image.NEAREST) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomGaussianBlur(object): + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur( + radius=random.random())) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomScaleCrop(object): + def __init__(self, base_size, crop_size, fill=0): + self.base_size = base_size + self.crop_size = crop_size + self.fill = fill + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + # 
random scale (short edge) + short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + depth = depth.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # pad crop + if short_size < self.crop_size: + padh = self.crop_size - oh if oh < self.crop_size else 0 + padw = self.crop_size - ow if ow < self.crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + depth = ImageOps.expand(depth, border=(0, 0, padw, padh), fill=0) # depth多余的部分填0 + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - self.crop_size) + y1 = random.randint(0, h - self.crop_size) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + depth = depth.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class FixScaleCrop(object): + def __init__(self, crop_size): + self.crop_size = crop_size + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + w, h = img.size + if w > h: + oh = self.crop_size + ow = int(1.0 * w * oh / h) + else: + ow = self.crop_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + depth = depth.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - self.crop_size) / 2.)) + y1 = int(round((h - self.crop_size) / 2.)) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + depth = depth.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'depth': depth, + 'label': mask} + +class FixedResize(object): + def __init__(self, size): + self.size = (size, size) # size: (h, w) + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + + assert img.size == depth.size == mask.size + + img = img.resize(self.size, Image.BILINEAR) + depth = depth.resize(self.size, Image.BILINEAR) + mask = mask.resize(self.size, Image.NEAREST) + + return {'image': img, + 'depth': depth, + 'label': mask} + +class Relabel(object): + def __init__(self, olabel, nlabel): # change trainid label from olabel to nlabel + self.olabel = olabel + self.nlabel = nlabel + + def __call__(self, tensor): + # assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, + # torch.ByteTensor)), 'tensor needs to be LongTensor' + tensor[tensor == self.olabel] = self.nlabel + return tensor \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/dataloaders/custom_transforms_rgb.py b/examples/lifelong_learning/RFNet/dataloaders/custom_transforms_rgb.py new file mode 100644 index 000000000..e04ef5a38 --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/custom_transforms_rgb.py @@ -0,0 +1,230 @@ +import torch +import random +import numpy as np + +from PIL import Image, ImageOps, ImageFilter + +class Normalize(object): + """Normalize a tensor image with mean and standard deviation. + Args: + mean (tuple): means for each channel. + std (tuple): standard deviations for each channel. 
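+ Example (illustrative, using the ImageNet statistics applied throughout this example, mean=(0.485, 0.456, 0.406) and std=(0.229, 0.224, 0.225)): a red-channel pixel of 255 is first scaled to 1.0 and then mapped to (1.0 - 0.485) / 0.229 ≈ 2.25.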
+ """ + def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + img = np.array(img).astype(np.float32) + mask = np.array(mask).astype(np.float32) + img /= 255.0 + img -= self.mean + img /= self.std + + return {'image': img, + 'label': mask} + + +class Normalize_test(object): + def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = sample + img = np.array(img).astype(np.float32) + img /= 255.0 + img -= self.mean + img /= self.std + + return img + + +class ToTensor(object): + """Convert Image object in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'] + mask = sample['label'] + img = np.array(img).astype(np.float32).transpose((2, 0, 1)) + mask = np.array(mask).astype(np.float32) + + img = torch.from_numpy(img).float() + mask = torch.from_numpy(mask).float() + + return {'image': img, + 'label': mask} + +class CropBlackArea(object): + """ + crop black area for depth image + """ + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + width, height = img.size + left = 140 + top = 30 + right = 2030 + bottom = 900 + # crop + img = img.crop((left, top, right, bottom)) + mask = mask.crop((left, top, right, bottom)) + # resize + img = img.resize((width,height), Image.BILINEAR) + mask = mask.resize((width,height), Image.NEAREST) + # img = img.resize((512,1024), Image.BILINEAR) + # mask = mask.resize((512,1024), Image.NEAREST) + print(img.size) + + return {'image': img, + 'label': mask} + +class ToTensor_test(object): + """Convert Image object in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample + img = np.array(img).astype(np.float32).transpose((2, 0, 1)) + + img = torch.from_numpy(img).float() + + return img + + +class RandomHorizontalFlip(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + + return {'image': img, + 'label': mask} + + +class RandomRotate(object): + def __init__(self, degree): + self.degree = degree + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + rotate_degree = random.uniform(-1*self.degree, self.degree) + img = img.rotate(rotate_degree, Image.BILINEAR) + mask = mask.rotate(rotate_degree, Image.NEAREST) + + return {'image': img, + 'label': mask} + + +class RandomGaussianBlur(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur( + radius=random.random())) + + return {'image': img, + 'label': mask} + + +class RandomScaleCrop(object): + def __init__(self, base_size, crop_size, fill=0): + self.base_size = base_size + self.crop_size = crop_size + self.fill = fill + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + # random scale (short edge) + short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), 
Image.NEAREST) + # pad crop + if short_size < self.crop_size: + padh = self.crop_size - oh if oh < self.crop_size else 0 + padw = self.crop_size - ow if ow < self.crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - self.crop_size) + y1 = random.randint(0, h - self.crop_size) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'label': mask} + + +class FixScaleCrop(object): + def __init__(self, crop_size): + self.crop_size = crop_size + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + w, h = img.size + if w > h: + oh = self.crop_size + ow = int(1.0 * w * oh / h) + else: + ow = self.crop_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - self.crop_size) / 2.)) + y1 = int(round((h - self.crop_size) / 2.)) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'label': mask} + +class FixedResize(object): + def __init__(self, size): + self.size = (size, size) # size: (h, w) + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + + assert img.size == mask.size + + img = img.resize(self.size, Image.BILINEAR) + mask = mask.resize(self.size, Image.NEAREST) + + return {'image': img, + 'label': mask} + +class Relabel(object): + def __init__(self, olabel, nlabel): # change trainid label from olabel to nlabel + self.olabel = olabel + self.nlabel = nlabel + + def __call__(self, tensor): + # assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, + # torch.ByteTensor)), 'tensor needs to be LongTensor' + tensor[tensor == self.olabel] = self.nlabel + return tensor \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/__init__.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/citylostfound.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/citylostfound.py new file mode 100644 index 000000000..6ffd0a4b3 --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/citylostfound.py @@ -0,0 +1,276 @@ +import os +import numpy as np +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr +from dataloaders import custom_transforms_rgb as tr_rgb + +class CitylostfoundSegmentation(data.Dataset): + NUM_CLASSES = 20 + + def __init__(self, args, root=Path.db_root_dir('citylostfound'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root,'disparity',self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix= '.png') + self.images[split].sort() + + self.disparities[split] = 
self.recursive_glob(rootdir=self.disparities_base, suffix= '.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, + suffix='labelTrainIds.png') + self.labels[split].sort() + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + lbl_path = self.labels[self.split][index].rstrip() + + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _tmp = np.array(Image.open(lbl_path), dtype=np.uint8) + if self.split == 'train': + if index < 1036: # lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + elif self.split == 'val': + if index < 1203: # lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + _target = Image.fromarray(_tmp) + + sample = {'image': _img, 'depth': _depth, 'label': _target} + + # data augment + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample) + + + def relabel_lostandfound(self, input): + input = tr.Relabel(0, self.ignore_index)(input) # background->255 ignore + input = tr.Relabel(1, 0)(input) # road 1->0 + # input = Relabel(255, 20)(input) # unlabel 20 + input = tr.Relabel(2, 19)(input) # obstacle 19 + return input + + def recursive_glob(self, rootdir='.', suffix=None): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + if isinstance(suffix, str): + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + elif isinstance(suffix, list): + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for x in suffix for filename in filenames if filename.startswith(x)] + + + def transform_tr(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + # tr.CropBlackArea(), + tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + +class CitylostfoundSegmentation_rgb(data.Dataset): + NUM_CLASSES = 19 + + def __init__(self, args, root=Path.db_root_dir('citylostfound'), 
split="train"): + + self.root = root + self.split = split + self.args = args + self.files = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.files[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='labelTrainIds.png') + self.labels[split].sort() + + self.ignore_index = 255 + + if not self.files[split]: + raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) + + print("Found %d %s images" % (len(self.files[split]), split)) + + def __len__(self): + return len(self.files[self.split]) + + def __getitem__(self, index): + + img_path = self.files[self.split][index].rstrip() + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _tmp = np.array(Image.open(lbl_path), dtype=np.uint8) + if self.split == 'train': + if index < 1036: # lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + elif self.split == 'val': + if index < 1203: # lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + _target = Image.fromarray(_tmp) + + sample = {'image': _img, 'label': _target} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample) + + + def relabel_lostandfound(self, input): + input = tr.Relabel(0, self.ignore_index)(input) + input = tr.Relabel(1, 0)(input) # road 1->0 + input = tr.Relabel(2, 19)(input) # obstacle 19 + return input + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr_rgb.CropBlackArea(), + tr_rgb.RandomHorizontalFlip(), + tr_rgb.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr_rgb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr_rgb.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr_rgb.CropBlackArea(), + tr_rgb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr_rgb.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr_rgb.FixedResize(size=self.args.crop_size), + tr_rgb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr_rgb.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CitylostfoundSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in 
range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/cityrand.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/cityrand.py new file mode 100644 index 000000000..74eddb672 --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/cityrand.py @@ -0,0 +1,151 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 19 + + def __init__(self, args, root=Path.db_root_dir('cityrand'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='TrainIds.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for 
filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + #tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/cityscapes.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/cityscapes.py new file mode 100644 index 000000000..c491df29a --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/cityscapes.py @@ -0,0 +1,151 @@ +import os +import numpy as np +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 24 # 25 + + def __init__(self, args, root=Path.db_root_dir('cityscapes'), data=None, split="train"): + + # self.root = root + self.root = "/home/lsq/Dataset/" + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.disparities_base = os.path.join(self.root, self.split, "depth", "cityscapes_real") + self.images[split] = [img[0] for img in data.x] if hasattr(data, "x") else data + + if hasattr(data, "x") and len(data.x[0]) == 1: + # TODO: fit the case that depth images don't exist. 
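+ # When each sample carries only an RGB path, the RGB image path is reused as the disparity path below, so the RGB-D pipeline in __getitem__ still runs without real depth data.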
+ self.disparities[split] = self.images[split] + elif hasattr(data, "x") and len(data.x[0]) == 2: + self.disparities[split] = [img[1] for img in data.x] + else: + self.disparities[split] = data + + self.labels[split] = data.y if hasattr(data, "y") else data + + self.ignore_index = 255 + + if len(self.images[split]) == 0: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if len(self.disparities[split]) == 0: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + #tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = 
np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/e1.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/e1.py new file mode 100644 index 000000000..40e06e981 --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/e1.py @@ -0,0 +1,151 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 24 + + def __init__(self, args, root=Path.db_root_dir('e1'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='.png') + self.labels[split].sort() + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + #tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, 
crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + #tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/mapillary.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/mapillary.py new file mode 100644 index 000000000..d665649bc --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/mapillary.py @@ -0,0 +1,152 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 24 + + def __init__(self, args, root=Path.db_root_dir('mapillary'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + 
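+ # NOTE: __getitem__ below falls back to {'image': _img, 'depth': _depth, 'label': _img} whenever the label file cannot be opened, so unlabeled samples can still be batched for inference.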
+ + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 768 + args.crop_size = 768 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/target.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/target.py new file mode 100644 index 000000000..739e85f83 --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/target.py @@ -0,0 +1,152 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image 
+from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 24 + + def __init__(self, args, root=Path.db_root_dir('target'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='TrainIds.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return 
composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/temp.txt b/examples/lifelong_learning/RFNet/dataloaders/datasets/temp.txt new file mode 100644 index 000000000..3c81afefe --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/temp.txt @@ -0,0 +1,38 @@ +for i in range(len(nam_label)): + img_label = cv2.imread(label_ori_path+nam_label[i], -1)[:,:,2] + img_label_temp = img_label.copy() + img_label_temp[img_label == 0] = 22 + img_label_temp[img_label == 1] = 10 + img_label_temp[img_label == 2] = 2 + img_label_temp[img_label == 3] = 0 + img_label_temp[img_label == 4] = 1 + img_label_temp[img_label == 5] = 4 + img_label_temp[img_label == 6] = 8 + img_label_temp[img_label == 7] = 5 + img_label_temp[img_label == 8] = 13 + img_label_temp[img_label == 9] = 7 + img_label_temp[img_label == 10] = 11 + img_label_temp[img_label == 11] = 18 + img_label_temp[img_label == 12] = 17 + img_label_temp[img_label == 13] = 21 + img_label_temp[img_label == 14] = 20 + img_label_temp[img_label == 15] = 6 + img_label_temp[img_label == 16] = 9 + img_label_temp[img_label == 17] = 12 + img_label_temp[img_label == 18] = 14 + img_label_temp[img_label == 19] = 15 + img_label_temp[img_label == 20] = 16 + img_label_temp[img_label == 21] = 3 + img_label_temp[img_label == 22] = 19 + #print(img_label) + #img_label[img_label == 0] = 10 + #img_label[img_label == 6] = 0 + #img_label[img_label == 5] = 11 + #img_label[img_label == 1] = 5 + #img_label[img_label == 2] = 1 + #img_label[img_label == 4] = 9 + #img_label[img_label == 3] = 4 + #img_label[img_label == 7] = 8 + #img_label[img_label == 11] = 2 + img_resize_lab = cv2.resize(img_label_temp, (2048,1024), interpolation=cv2.INTER_NEAREST) + cv2.imwrite(label_save_path+str(i)+'TrainIds.png', img_resize_lab.astype(np.uint16)) \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/dataloaders/datasets/xrlab.py b/examples/lifelong_learning/RFNet/dataloaders/datasets/xrlab.py new file mode 100644 index 000000000..4b261fcdd --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/datasets/xrlab.py @@ -0,0 +1,152 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 25 + + def __init__(self, args, root=Path.db_root_dir('xrlab'), split="train"): + + self.root 
= root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 
513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/lifelong_learning/RFNet/dataloaders/utils.py b/examples/lifelong_learning/RFNet/dataloaders/utils.py new file mode 100644 index 000000000..ef572332a --- /dev/null +++ b/examples/lifelong_learning/RFNet/dataloaders/utils.py @@ -0,0 +1,244 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch + +def decode_seg_map_sequence(label_masks, dataset='pascal'): + rgb_masks = [] + for label_mask in label_masks: + rgb_mask = decode_segmap(label_mask, dataset) + rgb_masks.append(rgb_mask) + rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) # change for val + return rgb_masks + + +def decode_segmap(label_mask, dataset, plot=False): + """Decode segmentation class labels into a color image + Args: + label_mask (np.ndarray): an (M,N) array of integer values denoting + the class label at each spatial location. + plot (bool, optional): whether to show the resulting color image + in a figure. + Returns: + (np.ndarray, optional): the resulting decoded color image. + """ + if dataset == 'pascal' or dataset == 'coco': + n_classes = 21 + label_colours = get_pascal_labels() + elif dataset == 'cityscapes': + n_classes = 19 + label_colours = get_cityscapes_labels() + elif dataset == 'target': + n_classes = 24 + label_colours = get_cityscapes_labels() + elif dataset == 'cityrand': + n_classes = 19 + label_colours = get_cityscapes_labels() + elif dataset == 'citylostfound': + n_classes = 20 + label_colours = get_citylostfound_labels() + elif dataset == 'xrlab': + n_classes = 25 + label_colours = get_cityscapes_labels() + elif dataset == 'e1': + n_classes = 24 + label_colours = get_cityscapes_labels() + elif dataset == 'mapillary': + n_classes = 24 + label_colours = get_cityscapes_labels() + else: + raise NotImplementedError + + r = label_mask.copy() + g = label_mask.copy() + b = label_mask.copy() + for ll in range(0, n_classes): + r[label_mask == ll] = label_colours[ll, 0] + g[label_mask == ll] = label_colours[ll, 1] + b[label_mask == ll] = label_colours[ll, 2] + rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) # change for val + # rgb = torch.ByteTensor(3, label_mask.shape[0], label_mask.shape[1]).fill_(0) + rgb[:, :, 0] = r / 255.0 + rgb[:, :, 1] = g / 255.0 + rgb[:, :, 2] = b / 255.0 + # r = torch.from_numpy(r) + # g = torch.from_numpy(g) + # b = torch.from_numpy(b) + + rgb[:, :, 0] = r / 255.0 + rgb[:, :, 1] = g / 255.0 + rgb[:, :, 2] = b / 255.0 + if plot: + plt.imshow(rgb) + plt.show() + else: + return rgb + + +def encode_segmap(mask): + """Encode segmentation label images as pascal classes + Args: + mask (np.ndarray): raw segmentation label image of dimension + (M, N, 3), in which the Pascal classes are encoded as colours. 
+ Returns: + (np.ndarray): class map with dimensions (M,N), where the value at + a given location is the integer denoting the class index. + """ + mask = mask.astype(int) + label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) + for ii, label in enumerate(get_pascal_labels()): + label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii + label_mask = label_mask.astype(int) + return label_mask + + +def get_cityscapes_labels(): + return np.array([ + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [0, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + [119, 11, 119], + [128, 64, 64], + [102, 10, 156], + [102, 102, 15], + [10, 102, 156], + [10, 102, 156], + [10, 102, 156], + [10, 102, 156]]) + +def get_citylostfound_labels(): + return np.array([ + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [0, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + [111, 74, 0]]) + + +def get_pascal_labels(): + """Load the mapping that associates pascal classes with label colors + Returns: + np.ndarray with dimensions (21, 3) + """ + return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], + [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], + [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], + [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], + [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], + [0, 64, 128]]) + + +def colormap_bdd(n): + cmap=np.zeros([n, 3]).astype(np.uint8) + cmap[0,:] = np.array([128, 64, 128]) + cmap[1,:] = np.array([244, 35, 232]) + cmap[2,:] = np.array([ 70, 70, 70]) + cmap[3,:] = np.array([102, 102, 156]) + cmap[4,:] = np.array([190, 153, 153]) + cmap[5,:] = np.array([153, 153, 153]) + + cmap[6,:] = np.array([250, 170, 30]) + cmap[7,:] = np.array([220, 220, 0]) + cmap[8,:] = np.array([107, 142, 35]) + cmap[9,:] = np.array([152, 251, 152]) + cmap[10,:]= np.array([70, 130, 180]) + + cmap[11,:]= np.array([220, 20, 60]) + cmap[12,:]= np.array([255, 0, 0]) + cmap[13,:]= np.array([0, 0, 142]) + cmap[14,:]= np.array([0, 0, 70]) + cmap[15,:]= np.array([0, 60, 100]) + + cmap[16,:]= np.array([0, 80, 100]) + cmap[17,:]= np.array([0, 0, 230]) + cmap[18,:]= np.array([119, 11, 32]) + cmap[19,:]= np.array([111, 74, 0]) #多加了一类small obstacle + + return cmap + +def colormap_bdd0(n): + cmap=np.zeros([n, 3]).astype(np.uint8) + cmap[0,:] = np.array([0, 0, 0]) + cmap[1,:] = np.array([70, 130, 180]) + cmap[2,:] = np.array([70, 70, 70]) + cmap[3,:] = np.array([128, 64, 128]) + cmap[4,:] = np.array([244, 35, 232]) + cmap[5,:] = np.array([64, 64, 128]) + + cmap[6,:] = np.array([107, 142, 35]) + cmap[7,:] = np.array([153, 153, 153]) + cmap[8,:] = np.array([0, 0, 142]) + cmap[9,:] = np.array([220, 220, 0]) + cmap[10,:]= np.array([220, 20, 60]) + + cmap[11,:]= np.array([119, 11, 32]) + cmap[12,:]= np.array([0, 0, 230]) + cmap[13,:]= np.array([250, 170, 160]) + cmap[14,:]= np.array([128, 64, 64]) + cmap[15,:]= np.array([250, 170, 30]) + + cmap[16,:]= np.array([152, 251, 152]) + cmap[17,:]= np.array([255, 0, 0]) + cmap[18,:]= np.array([0, 0, 70]) + cmap[19,:]= np.array([0, 60, 100]) #small obstacle + cmap[20,:]= 
np.array([0, 80, 100]) + cmap[21,:]= np.array([102, 102, 156]) + cmap[22,:]= np.array([102, 102, 156]) + + return cmap + +class Colorize: + + def __init__(self, n=24): # n = nClasses + # self.cmap = colormap(256) + self.cmap = colormap_bdd(256) + self.cmap[n] = self.cmap[-1] + self.cmap = torch.from_numpy(self.cmap[:n]) + + def __call__(self, gray_image): + size = gray_image.size() + # print(size) + color_images = torch.ByteTensor(size[0], 3, size[1], size[2]).fill_(0) + # color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) + + # for label in range(1, len(self.cmap)): + for i in range(color_images.shape[0]): + for label in range(0, len(self.cmap)): + mask = gray_image[0] == label + # mask = gray_image == label + + color_images[i][0][mask] = self.cmap[label][0] + color_images[i][1][mask] = self.cmap[label][1] + color_images[i][2][mask] = self.cmap[label][2] + + return color_images diff --git a/examples/lifelong_learning/RFNet/eval.py b/examples/lifelong_learning/RFNet/eval.py new file mode 100644 index 000000000..482315c92 --- /dev/null +++ b/examples/lifelong_learning/RFNet/eval.py @@ -0,0 +1,247 @@ +import argparse +import os +import numpy as np +from tqdm import tqdm +import time +import torch +from torchvision.transforms import ToPILImage +from PIL import Image + +from dataloaders import make_data_loader +from dataloaders.utils import decode_seg_map_sequence, Colorize +from utils.metrics import Evaluator +from models.rfnet import RFNet +from models import rfnet_for_unseen +from models.resnet.resnet_single_scale_single_attention import * +from models.resnet import resnet_single_scale_single_attention_unseen +import torch.backends.cudnn as cudnn + +class Validator(object): + def __init__(self, args, data=None, unseen_detection=False): + self.args = args + self.time_train = [] + self.num_class = args.num_class + + # Define Dataloader + kwargs = {'num_workers': args.workers, 'pin_memory': False} + # _, self.val_loader, _, self.custom_loader, self.num_class = make_data_loader(args, **kwargs) + _, _, self.test_loader, _ = make_data_loader(args, test_data=data, **kwargs) + print('un_classes:'+str(self.num_class)) + + # Define evaluator + self.evaluator = Evaluator(self.num_class) + + # Define network + if unseen_detection: + self.resnet = resnet_single_scale_single_attention_unseen.\ + resnet18(pretrained=False, efficient=False, use_bn=True) + self.model = rfnet_for_unseen.RFNet(self.resnet, num_classes=self.num_class, use_bn=True) + else: + self.resnet = resnet18(pretrained=False, efficient=False, use_bn=True) + self.model = RFNet(self.resnet, num_classes=self.num_class, use_bn=True) + + if args.cuda: + self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) + self.model = self.model.cuda() + cudnn.benchmark = True # accelarate speed + print('Model loaded successfully!') + + # Load weights + assert os.path.exists(args.weight_path), 'weight-path:{} doesn\'t exit!'.format(args.weight_path) + self.new_state_dict = torch.load(args.weight_path, map_location=torch.device("cpu")) + self.model = load_my_state_dict(self.model, self.new_state_dict['state_dict']) + + def validate(self): + self.model.eval() + self.evaluator.reset() + # tbar = tqdm(self.test_loader, desc='\r') + predictions = [] + for sample, image_name in self.test_loader: + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + # spec = time.time() + image, target = sample['image'], sample['label'] + + if self.args.cuda: + image = image.cuda() + if 
self.args.depth: + depth = depth.cuda() + + with torch.no_grad(): + if self.args.depth: + output = self.model(image, depth) + else: + output = self.model(image) + + if self.args.cuda: + torch.cuda.synchronize() + + pred = output.data.cpu().numpy() + # todo + pred = np.argmax(pred, axis=1) + predictions.append(pred) + + if not self.args.save_predicted_image: + continue + + pre_colors = Colorize()(torch.max(output, 1)[1].detach().cpu().byte()) + pre_labels = torch.max(output, 1)[1].detach().cpu().byte() + print(pre_labels.shape) + # save + for i in range(pre_colors.shape[0]): + print(image_name[0]) + + if not image_name[0]: + img_name = "test.png" + else: + img_name = os.path.basename(image_name[0]) + + color_label_name = os.path.join(self.args.color_label_save_path, img_name) + label_name = os.path.join(self.args.label_save_path, img_name) + merge_label_name = os.path.join(self.args.merge_label_save_path, img_name) + + os.makedirs(os.path.dirname(color_label_name), exist_ok=True) + os.makedirs(os.path.dirname(merge_label_name), exist_ok=True) + os.makedirs(os.path.dirname(label_name), exist_ok=True) + + pre_color_image = ToPILImage()(pre_colors[i]) # pre_colors.dtype = float64 + pre_color_image.save(color_label_name) + + pre_label_image = ToPILImage()(pre_labels[i]) + pre_label_image.save(label_name) + + if (self.args.merge): + image_merge(image[i], pre_color_image, merge_label_name) + print('save image: {}'.format(merge_label_name)) + + return predictions + + def task_divide(self): + seen_task_samples, unseen_task_samples = [], [] + self.model.eval() + self.evaluator.reset() + tbar = tqdm(self.test_loader, desc='\r') + for i, (sample, image_name) in enumerate(tbar): + + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + if self.args.cuda: + image = image.cuda() + if self.args.depth: + depth = depth.cuda() + start_time = time.time() + with torch.no_grad(): + if self.args.depth: + output_, output, _ = self.model(image, depth) + else: + output_, output, _ = self.model(image) + if self.args.cuda: + torch.cuda.synchronize() + if i != 0: + fwt = time.time() - start_time + self.time_train.append(fwt) + print("Forward time per img (bath size=%d): %.3f (Mean: %.3f)" % ( + self.args.val_batch_size, fwt / self.args.val_batch_size, + sum(self.time_train) / len(self.time_train) / self.args.val_batch_size)) + time.sleep(0.1) # to avoid overheating the GPU too much + + # pred colorize + pre_colors = Colorize()(torch.max(output, 1)[1].detach().cpu().byte()) + pre_labels = torch.max(output, 1)[1].detach().cpu().byte() + for i in range(pre_colors.shape[0]): + task_sample = dict() + task_sample.update(image=sample["image"][i]) + task_sample.update(label=sample["label"][i]) + if self.args.depth: + task_sample.update(depth=sample["depth"][i]) + + if torch.max(pre_labels) == output.shape[1] - 1: + unseen_task_samples.append((task_sample, image_name[i])) + else: + seen_task_samples.append((task_sample, image_name[i])) + + return seen_task_samples, unseen_task_samples + +def image_merge(image, label, save_name): + image = ToPILImage()(image.detach().cpu().byte()) + # width, height = image.size + left = 140 + top = 30 + right = 2030 + bottom = 900 + # crop + image = image.crop((left, top, right, bottom)) + # resize + image = image.resize(label.size, Image.BILINEAR) + + image = image.convert('RGBA') + label = label.convert('RGBA') + image = Image.blend(image, label, 0.6) + image.save(save_name) + +def 
load_my_state_dict(model, state_dict): # custom function to load model when not all dict elements + own_state = model.state_dict() + for name, param in state_dict.items(): + if name not in own_state: + print('{} not in model_state'.format(name)) + continue + else: + own_state[name].copy_(param) + + return model + +def main(): + parser = argparse.ArgumentParser(description="PyTorch RFNet validation") + parser.add_argument('--dataset', type=str, default='cityscapes', + choices=['citylostfound', 'cityscapes', 'xrlab', 'mapillary'], + help='dataset name (default: cityscapes)') + parser.add_argument('--workers', type=int, default=4, + metavar='N', help='dataloader threads') + parser.add_argument('--base-size', type=int, default=1024, + help='base image size') + parser.add_argument('--crop-size', type=int, default=768, + help='crop image size') + parser.add_argument('--batch-size', type=int, default=6, + help='batch size for training') + parser.add_argument('--val-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + validating (default: auto)') + parser.add_argument('--test-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--no-cuda', action='store_true', default= + False, help='disables CUDA training') + parser.add_argument('--gpu-ids', type=str, default='0', + help='use which gpu to train, must be a \ + comma-separated list of integers only (default=0)') + parser.add_argument('--checkname', type=str, default=None, + help='set the checkpoint name') + parser.add_argument('--weight-path', type=str, default=None, + help='enter your path of the weight') + parser.add_argument('--color-label-save-path', type=str, default='D:/m0063/project/RFNet-master/test/color/', + help='path to save label') + parser.add_argument('--merge-label-save-path', type=str, default='D:/m0063/project/RFNet-master/test/merge/', + help='path to save merged label') + parser.add_argument('--label-save-path', type=str, default='D:/m0063/project/RFNet-master/test/label/', + help='path to save merged label') + parser.add_argument('--merge', action='store_true', default=False, help='merge image and label') + parser.add_argument('--depth', action='store_true', default=False, help='add depth image or not') + + + args = parser.parse_args() + args.cuda = not args.no_cuda and torch.cuda.is_available() + if args.cuda: + try: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] + except ValueError: + raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') + + validator = Validator(args) + validator.validate() + + +if __name__ == "__main__": + main() diff --git a/examples/lifelong_learning/RFNet/models/replicate.py b/examples/lifelong_learning/RFNet/models/replicate.py new file mode 100644 index 000000000..3734266ea --- /dev/null +++ b/examples/lifelong_learning/RFNet/models/replicate.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
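# Editorial note (illustrative sketch, not part of the original code): this module is
# vendored from Synchronized-BatchNorm-PyTorch. A typical way to enable the replication
# callback for a model that contains synchronized BatchNorm layers would be either
#
#     model = DataParallelWithCallback(model, device_ids=[0, 1])
#
# or, to keep an existing torch.nn.DataParallel wrapper,
#
#     patch_replication_callback(model)
#
# Both variants only take effect on sub-modules that implement
# `__data_parallel_replicate__`; see the docstrings below.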
+ +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
+ Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/models/resnet/__init__.py b/examples/lifelong_learning/RFNet/models/resnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/lifelong_learning/RFNet/models/resnet/resnet_single_scale_single_attention.py b/examples/lifelong_learning/RFNet/models/resnet/resnet_single_scale_single_attention.py new file mode 100644 index 000000000..63d819910 --- /dev/null +++ b/examples/lifelong_learning/RFNet/models/resnet/resnet_single_scale_single_attention.py @@ -0,0 +1,391 @@ +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from itertools import chain +import torch.utils.checkpoint as cp + +from ..util import _Upsample, SpatialPyramidPooling + +__all__ = ['ResNet', 'resnet18'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +def conv1x1(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + +def _bn_function_factory(conv, norm, relu=None): + """return a conv-bn-relu function""" + def bn_function(x): + x = conv(x) + if norm is not None: + x = norm(x) + if relu is not None: + x = relu(x) + return x + + return bn_function + + +def do_efficient_fwd(block, x, efficient): + if efficient and x.requires_grad: + return cp.checkpoint(block, x) + else: + return block(x) + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + + # reshape + x = x.view(batchsize, groups, + channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + x = x.view(batchsize, -1, height, width) + + return x + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=False, use_bn=True): + super(BasicBlock, self).__init__() + self.use_bn = use_bn + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) if self.use_bn else None + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) if self.use_bn else None + self.downsample = downsample + self.stride = stride + self.efficient = efficient + + def forward(self, x): + residual = x + + bn_1 = _bn_function_factory(self.conv1, self.bn1, self.relu) + bn_2 = _bn_function_factory(self.conv2, self.bn2) + + out = do_efficient_fwd(bn_1, x, self.efficient) + out = do_efficient_fwd(bn_2, out, self.efficient) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + relu = 
self.relu(out) + + return relu, out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, use_bn=True): + super(Bottleneck, self).__init__() + self.use_bn = use_bn + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) if self.use_bn else None + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) if self.use_bn else None + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) if self.use_bn else None + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.efficient = efficient + + def forward(self, x): + residual = x + + bn_1 = _bn_function_factory(self.conv1, self.bn1, self.relu) + bn_2 = _bn_function_factory(self.conv2, self.bn2, self.relu) + bn_3 = _bn_function_factory(self.conv3, self.bn3, self.relu) + + out = do_efficient_fwd(bn_1, x, self.efficient) + out = do_efficient_fwd(bn_2, out, self.efficient) + out = do_efficient_fwd(bn_3, out, self.efficient) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + relu = self.relu(out) + + return relu, out + + +class ResNet(nn.Module): + def __init__(self, block, layers, *, num_features=128, k_up=3, efficient=True, use_bn=True, + spp_grids=(8, 4, 2, 1), spp_square_grid=False, **kwargs): + super(ResNet, self).__init__() + self.inplanes = 64 + self.efficient = efficient + self.use_bn = use_bn + + # rgb branch + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) if self.use_bn else lambda x: x + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # depth branch + self.conv1_d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,bias=False) + self.bn1_d = nn.BatchNorm2d(64) if self.use_bn else lambda x: x + self.relu_d = nn.ReLU(inplace=True) + self.maxpool_d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + upsamples = [] + # 修改 _make_layer_rgb _make_layer + self.layer1 = self._make_layer_rgb(block, 64, 64, layers[0]) + self.layer1_d = self._make_layer_d(block, 64, 64, layers[0]) + self.attention_1 = self.attention(64) + self.attention_1_d = self.attention(64) + upsamples += [_Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up)] # num_maps_in, skip_maps_in, num_maps_out, k: kernel size of blend conv + + self.layer2 = self._make_layer_rgb(block, 64, 128, layers[1], stride=2) + self.layer2_d = self._make_layer_d(block, 64, 128, layers[1], stride=2) + self.attention_2 = self.attention(128) + self.attention_2_d = self.attention(128) + upsamples += [_Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up)] + + self.layer3 = self._make_layer_rgb(block, 128, 256, layers[2], stride=2) + self.layer3_d = self._make_layer_d(block, 128, 256, layers[2], stride=2) + self.attention_3 = self.attention(256) + self.attention_3_d = self.attention(256) + upsamples += [_Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up)] + + self.layer4 = self._make_layer_rgb(block, 256, 512, layers[3], stride=2) + self.layer4_d = self._make_layer_d(block, 256, 512, layers[3], stride=2) + self.attention_4 = self.attention(512) + self.attention_4_d = self.attention(512) + + self.fine_tune = [self.conv1, self.maxpool, 
self.layer1, self.layer2, self.layer3, self.layer4, + self.conv1_d, self.maxpool_d, self.layer1_d, self.layer2_d, self.layer3_d, self.layer4_d] + if self.use_bn: + self.fine_tune += [self.bn1, self.bn1_d, self.attention_1, self.attention_1_d, self.attention_2, self.attention_2_d, + self.attention_3, self.attention_3_d, self.attention_4, self.attention_4_d] + + num_levels = 3 + self.spp_size = num_features + bt_size = self.spp_size + + level_size = self.spp_size // num_levels + + self.spp = SpatialPyramidPooling(self.inplanes, num_levels, bt_size=bt_size, level_size=level_size, + out_size=self.spp_size, grids=spp_grids, square_grid=spp_square_grid, + bn_momentum=0.01 / 2, use_bn=self.use_bn) + self.upsample = nn.ModuleList(list(reversed(upsamples))) + + self.random_init = []#[ self.spp, self.upsample] + self.fine_tune += [self.spp, self.upsample] + + self.num_features = num_features + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer_rgb(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + layers = [nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False)] + if self.use_bn: + layers += [nn.BatchNorm2d(planes * block.expansion)] + downsample = nn.Sequential(*layers) + layers = [block(inplanes, planes, stride, downsample, efficient=self.efficient, use_bn=self.use_bn)] + inplanes = planes * block.expansion + for i in range(1, blocks): + layers += [block(inplanes, planes, efficient=self.efficient, use_bn=self.use_bn)] + + return nn.Sequential(*layers) + + def _make_layer_d(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + layers = [nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False)] + if self.use_bn: + layers += [nn.BatchNorm2d(planes * block.expansion)] + downsample = nn.Sequential(*layers) + layers = [block(inplanes, planes, stride, downsample, efficient=self.efficient, use_bn=self.use_bn)] + inplanes = planes * block.expansion + self.inplanes = inplanes + for i in range(1, blocks): + layers += [block(inplanes, planes, efficient=self.efficient, use_bn=self.use_bn)] + + return nn.Sequential(*layers) + + def channel_attention(self, rgb_skip, depth_skip, attention): + assert rgb_skip.shape == depth_skip.shape, 'rgb skip shape:{} != depth skip shape:{}'.format(rgb_skip.shape, depth_skip.shape) + # single_attenton + rgb_attention = attention(rgb_skip) + depth_attention = attention(depth_skip) + rgb_after_attention = torch.mul(rgb_skip, rgb_attention) + depth_after_attention = torch.mul(depth_skip, depth_attention) + skip_after_attention = rgb_after_attention + depth_after_attention + return skip_after_attention + + def attention(self, num_channels): + pool_attention = nn.AdaptiveAvgPool2d(1) + conv_attention = nn.Conv2d(num_channels, num_channels, kernel_size=1) + activate = nn.Sigmoid() + + return nn.Sequential(pool_attention, conv_attention, activate) + + + def random_init_params(self): + return chain(*[f.parameters() for f in self.random_init]) + + def fine_tune_params(self): + return chain(*[f.parameters() for f in self.fine_tune]) + + def forward_resblock(self, x, layers): + skip = None + for l in layers: + x = l(x) + if isinstance(x, tuple): + x, skip = x + return x, skip 
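    # Editorial note (illustrative, not part of the original code): attention() above
    # builds a squeeze-style per-channel gate (global average pool -> 1x1 conv ->
    # sigmoid). forward_down_fusion() below applies one gate per modality after each
    # residual block and then sums the two streams, roughly:
    #
    #     x = torch.mul(x, attention_rgb(x))      # per-channel weights in [0, 1]
    #     y = torch.mul(y, attention_depth(y))
    #     fused = x + y
    #
    # (attention_rgb / attention_depth are placeholder names for the self.attention_N
    # and self.attention_N_d gates.) channel_attention() implements a shared-gate
    # variant for skip connections but does not appear to be called in the forward
    # passes of this file.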
+ + def forward_down(self, rgb): + x = self.conv1(rgb) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + features = [] + x, skip = self.forward_resblock(x, self.layer1) + features += [skip] + x, skip = self.forward_resblock(x, self.layer2) + features += [skip] + x, skip = self.forward_resblock(x, self.layer3) + features += [skip] + x, skip = self.forward_resblock(x.detach(), self.layer4) + features += [self.spp.forward(skip)] + return features + + def forward_down_fusion(self, rgb, depth): + x = self.conv1(rgb) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + depth = depth.unsqueeze(1) + y = self.conv1_d(depth) + y = self.bn1_d(y) + y = self.relu_d(y) + y = self.maxpool_d(y) + + features = [] + # block 1 + x, skip_rgb = self.forward_resblock(x.detach(), self.layer1) + y, skip_depth = self.forward_resblock(y.detach(), self.layer1_d) + x_attention = self.attention_1(x) + y_attention = self.attention_1_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [skip_rgb.detach()] + # block 2 + x, skip_rgb = self.forward_resblock(x.detach(), self.layer2) + y, skip_depth = self.forward_resblock(y.detach(), self.layer2_d) + x_attention = self.attention_2(x) + y_attention = self.attention_2_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [skip_rgb.detach()] + # block 3 + x, skip_rgb = self.forward_resblock(x.detach(), self.layer3) + y, skip_depth = self.forward_resblock(y.detach(), self.layer3_d) + x_attention = self.attention_3(x) + y_attention = self.attention_3_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [skip_rgb.detach()] + # block 4 + x, skip_rgb = self.forward_resblock(x.detach(), self.layer4) + y, skip_depth = self.forward_resblock(y.detach(), self.layer4_d) + x_attention = self.attention_4(x) + y_attention = self.attention_4_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [self.spp.forward(x)] + return features + + + def forward_up(self, features): + features = features[::-1] + + x = features[0] + + upsamples = [] + i = 0 + for skip, up in zip(features[1:], self.upsample): + i += 1 + #print(len(self.upsample)) + if i < len(self.upsample): + x = up(x, skip) + else: + x = up(x, skip) + upsamples += [x] + return x, {'features': features, 'upsamples': upsamples} + + def forward(self, rgb, depth = None): + if depth is None: + return self.forward_up(self.forward_down(rgb)) + else: + return self.forward_up(self.forward_down_fusion(rgb, depth)) + + def _load_resnet_pretrained(self, url): + pretrain_dict = model_zoo.load_url(model_urls[url]) + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + # print('%%%%% ', k) + if k in state_dict: + if k.startswith('conv1'): + model_dict[k] = v + # print('##### ', k) + model_dict[k.replace('conv1', 'conv1_d')] = torch.mean(v, 1).data. \ + view_as(state_dict[k.replace('conv1', 'conv1_d')]) + + elif k.startswith('bn1'): + model_dict[k] = v + model_dict[k.replace('bn1', 'bn1_d')] = v + elif k.startswith('layer'): + model_dict[k] = v + model_dict[k[:6]+'_d'+k[6:]] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +def resnet18(pretrained=True, **kwargs): + """Constructs a ResNet-18 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) + print('pretrained dict loaded sucessfully') + return model \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/models/resnet/resnet_single_scale_single_attention_unseen.py b/examples/lifelong_learning/RFNet/models/resnet/resnet_single_scale_single_attention_unseen.py new file mode 100644 index 000000000..9668734e3 --- /dev/null +++ b/examples/lifelong_learning/RFNet/models/resnet/resnet_single_scale_single_attention_unseen.py @@ -0,0 +1,396 @@ +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from itertools import chain +import torch.utils.checkpoint as cp +import cv2 +import numpy as np + +from ..util import _Upsample, SpatialPyramidPooling + +__all__ = ['ResNet', 'resnet18'] + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +def conv1x1(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) + +def _bn_function_factory(conv, norm, relu=None): + """return a conv-bn-relu function""" + def bn_function(x): + x = conv(x) + if norm is not None: + x = norm(x) + if relu is not None: + x = relu(x) + return x + + return bn_function + + +def do_efficient_fwd(block, x, efficient): + if efficient and x.requires_grad: + return cp.checkpoint(block, x) + else: + return block(x) + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = num_channels // groups + + # reshape + x = x.view(batchsize, groups, + channels_per_group, height, width) + + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + x = x.view(batchsize, -1, height, width) + + return x + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=False, use_bn=True): + super(BasicBlock, self).__init__() + self.use_bn = use_bn + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) if self.use_bn else None + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) if self.use_bn else None + self.downsample = downsample + self.stride = stride + self.efficient = efficient + + def forward(self, x): + residual = x + + bn_1 = _bn_function_factory(self.conv1, self.bn1, self.relu) + bn_2 = _bn_function_factory(self.conv2, self.bn2) + + out = do_efficient_fwd(bn_1, x, self.efficient) + out = do_efficient_fwd(bn_2, out, self.efficient) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + relu = self.relu(out) + + return relu, out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, use_bn=True): + super(Bottleneck, self).__init__() + self.use_bn = use_bn + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) if self.use_bn else None + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = 
nn.BatchNorm2d(planes) if self.use_bn else None + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) if self.use_bn else None + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.efficient = efficient + + def forward(self, x): + residual = x + + bn_1 = _bn_function_factory(self.conv1, self.bn1, self.relu) + bn_2 = _bn_function_factory(self.conv2, self.bn2, self.relu) + bn_3 = _bn_function_factory(self.conv3, self.bn3, self.relu) + + out = do_efficient_fwd(bn_1, x, self.efficient) + out = do_efficient_fwd(bn_2, out, self.efficient) + out = do_efficient_fwd(bn_3, out, self.efficient) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + relu = self.relu(out) + + return relu, out + + +class ResNet(nn.Module): + def __init__(self, block, layers, *, num_features=128, k_up=3, efficient=True, use_bn=True, + spp_grids=(8, 4, 2, 1), spp_square_grid=False, **kwargs): + super(ResNet, self).__init__() + self.inplanes = 64 + self.efficient = efficient + self.use_bn = use_bn + + # rgb branch + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) if self.use_bn else lambda x: x + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # depth branch + self.conv1_d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,bias=False) + self.bn1_d = nn.BatchNorm2d(64) if self.use_bn else lambda x: x + self.relu_d = nn.ReLU(inplace=True) + self.maxpool_d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + upsamples = [] + # 修改 _make_layer_rgb _make_layer + self.layer1 = self._make_layer_rgb(block, 64, 64, layers[0]) + self.layer1_d = self._make_layer_d(block, 64, 64, layers[0]) + self.attention_1 = self.attention(64) + self.attention_1_d = self.attention(64) + upsamples += [_Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up)] # num_maps_in, skip_maps_in, num_maps_out, k: kernel size of blend conv + + self.layer2 = self._make_layer_rgb(block, 64, 128, layers[1], stride=2) + self.layer2_d = self._make_layer_d(block, 64, 128, layers[1], stride=2) + self.attention_2 = self.attention(128) + self.attention_2_d = self.attention(128) + upsamples += [_Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up)] + + self.layer3 = self._make_layer_rgb(block, 128, 256, layers[2], stride=2) + self.layer3_d = self._make_layer_d(block, 128, 256, layers[2], stride=2) + self.attention_3 = self.attention(256) + self.attention_3_d = self.attention(256) + upsamples += [_Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up)] + + self.layer4 = self._make_layer_rgb(block, 256, 512, layers[3], stride=2) + self.layer4_d = self._make_layer_d(block, 256, 512, layers[3], stride=2) + self.attention_4 = self.attention(512) + self.attention_4_d = self.attention(512) + + self.fine_tune = [self.conv1, self.maxpool, self.layer1, self.layer2, self.layer3, self.layer4, + self.conv1_d, self.maxpool_d, self.layer1_d, self.layer2_d, self.layer3_d, self.layer4_d] + if self.use_bn: + self.fine_tune += [self.bn1, self.bn1_d, self.attention_1, self.attention_1_d, self.attention_2, self.attention_2_d, + self.attention_3, self.attention_3_d, self.attention_4, self.attention_4_d] + + num_levels = 3 + self.spp_size = num_features + bt_size = self.spp_size + + level_size = self.spp_size // num_levels + + 
self.spp = SpatialPyramidPooling(self.inplanes, num_levels, bt_size=bt_size, level_size=level_size, + out_size=self.spp_size, grids=spp_grids, square_grid=spp_square_grid, + bn_momentum=0.01 / 2, use_bn=self.use_bn) + self.upsample = nn.ModuleList(list(reversed(upsamples))) + + self.random_init = [ self.spp, self.upsample] + + self.num_features = num_features + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def output_num(self): + return self.__in_features + + def _make_layer_rgb(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + layers = [nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False)] + if self.use_bn: + layers += [nn.BatchNorm2d(planes * block.expansion)] + downsample = nn.Sequential(*layers) + layers = [block(inplanes, planes, stride, downsample, efficient=self.efficient, use_bn=self.use_bn)] + inplanes = planes * block.expansion + for i in range(1, blocks): + layers += [block(inplanes, planes, efficient=self.efficient, use_bn=self.use_bn)] + + return nn.Sequential(*layers) + + def _make_layer_d(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + layers = [nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False)] + if self.use_bn: + layers += [nn.BatchNorm2d(planes * block.expansion)] + downsample = nn.Sequential(*layers) + layers = [block(inplanes, planes, stride, downsample, efficient=self.efficient, use_bn=self.use_bn)] + inplanes = planes * block.expansion + self.inplanes = inplanes + for i in range(1, blocks): + layers += [block(inplanes, planes, efficient=self.efficient, use_bn=self.use_bn)] + + return nn.Sequential(*layers) + + def channel_attention(self, rgb_skip, depth_skip, attention): + assert rgb_skip.shape == depth_skip.shape, 'rgb skip shape:{} != depth skip shape:{}'.format(rgb_skip.shape, depth_skip.shape) + # single_attenton + rgb_attention = attention(rgb_skip) + depth_attention = attention(depth_skip) + rgb_after_attention = torch.mul(rgb_skip, rgb_attention) + depth_after_attention = torch.mul(depth_skip, depth_attention) + skip_after_attention = rgb_after_attention + depth_after_attention + return skip_after_attention + + def attention(self, num_channels): + pool_attention = nn.AdaptiveAvgPool2d(1) + conv_attention = nn.Conv2d(num_channels, num_channels, kernel_size=1) + activate = nn.Sigmoid() + + return nn.Sequential(pool_attention, conv_attention, activate) + + + def random_init_params(self): + return chain(*[f.parameters() for f in self.random_init]) + + def fine_tune_params(self): + return chain(*[f.parameters() for f in self.fine_tune]) + + def forward_resblock(self, x, layers): + skip = None + for l in layers: + x = l(x) + if isinstance(x, tuple): + x, skip = x + return x, skip + + def forward_down(self, rgb): + x = self.conv1(rgb) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + features = [] + x, skip = self.forward_resblock(x, self.layer1) + features += [skip] + x, skip = self.forward_resblock(x, self.layer2) + features += [skip] + x, skip = self.forward_resblock(x, self.layer3) + features += [skip] + x, skip = self.forward_resblock(x, self.layer4) + features += [self.spp.forward(skip)] + features_da = self.spp.forward(skip) + 
return features, features_da + + def forward_down_fusion(self, rgb, depth): + x = self.conv1(rgb) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + depth = depth.unsqueeze(1) + y = self.conv1_d(depth) + y = self.bn1_d(y) + y = self.relu_d(y) + y = self.maxpool_d(y) + + features = [] + # block 1 + x, skip_rgb = self.forward_resblock(x, self.layer1) + y, skip_depth = self.forward_resblock(y, self.layer1_d) + x_attention = self.attention_1(x) + y_attention = self.attention_1_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [skip_rgb] + # block 2 + x, skip_rgb = self.forward_resblock(x, self.layer2) + y, skip_depth = self.forward_resblock(y, self.layer2_d) + x_attention = self.attention_2(x) + y_attention = self.attention_2_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [skip_rgb] + # block 3 + x, skip_rgb = self.forward_resblock(x, self.layer3) + y, skip_depth = self.forward_resblock(y, self.layer3_d) + x_attention = self.attention_3(x) + y_attention = self.attention_3_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [skip_rgb] + # block 4 + x, skip_rgb = self.forward_resblock(x, self.layer4) + y, skip_depth = self.forward_resblock(y, self.layer4_d) + x_attention = self.attention_4(x) + y_attention = self.attention_4_d(y) + x = torch.mul(x, x_attention) + y = torch.mul(y, y_attention) + x = x + y + features += [self.spp.forward(x)] + features_da = self.spp.forward(x) + return features, features_da + + + def forward_up(self, features): + features = features[::-1] + + x = features[0] + + upsamples = [] + for skip, up in zip(features[1:], self.upsample): + x = up(x, skip) + upsamples += [x] + return x, {'features': features, 'upsamples': upsamples} + + def forward(self, rgb, depth = None): + if depth is None: + down_features, da_features = self.forward_down(rgb) + x, additional = self.forward_up(down_features) + return x, additional, da_features#self.forward_up(self.forward_down(rgb)), self.forward_down(rgb) + else: + down_features, da_features = self.forward_down_fusion(rgb, depth) + x, additional = self.forward_up(down_features) + #print(down_features.shape) + return x, additional, da_features#self.forward_up(self.forward_down_fusion(rgb, depth)), self.forward_down_fusion(rgb, depth) + + def _load_resnet_pretrained(self, url): + pretrain_dict = model_zoo.load_url(model_urls[url]) + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + # print('%%%%% ', k) + if k in state_dict: + if k.startswith('conv1'): + model_dict[k] = v + # print('##### ', k) + model_dict[k.replace('conv1', 'conv1_d')] = torch.mean(v, 1).data. \ + view_as(state_dict[k.replace('conv1', 'conv1_d')]) + + elif k.startswith('bn1'): + model_dict[k] = v + model_dict[k.replace('bn1', 'bn1_d')] = v + elif k.startswith('layer'): + model_dict[k] = v + model_dict[k[:6]+'_d'+k[6:]] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +def resnet18(pretrained=True, **kwargs): + """Constructs a ResNet-18 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) + print('pretrained dict loaded sucessfully') + return model \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/models/rfnet.py b/examples/lifelong_learning/RFNet/models/rfnet.py new file mode 100644 index 000000000..87f02863a --- /dev/null +++ b/examples/lifelong_learning/RFNet/models/rfnet.py @@ -0,0 +1,27 @@ +import torch.nn as nn +from itertools import chain # 串联多个迭代对象 + +from .util import _BNReluConv, upsample + + +class RFNet(nn.Module): + def __init__(self, backbone, num_classes, use_bn=True): + super(RFNet, self).__init__() + self.backbone = backbone + self.num_classes = num_classes + print(self.backbone.num_features) + self.logits = _BNReluConv(self.backbone.num_features, self.num_classes, batch_norm=use_bn) + + def forward(self, rgb_inputs, depth_inputs = None): + x, additional = self.backbone(rgb_inputs, depth_inputs) + logits = self.logits.forward(x) + logits_upsample = upsample(logits, rgb_inputs.shape[2:]) + #print(logits_upsample.size) + return logits_upsample + + + def random_init_params(self): + return chain(*([self.logits.parameters(), self.backbone.random_init_params()])) + + def fine_tune_params(self): + return self.backbone.fine_tune_params() diff --git a/examples/lifelong_learning/RFNet/models/rfnet_for_unseen.py b/examples/lifelong_learning/RFNet/models/rfnet_for_unseen.py new file mode 100644 index 000000000..f61eb1ce2 --- /dev/null +++ b/examples/lifelong_learning/RFNet/models/rfnet_for_unseen.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from itertools import chain # 串联多个迭代对象 + +from .util import _BNReluConv, upsample + + +class RFNet(nn.Module): + def __init__(self, backbone, num_classes, use_bn=True): + super(RFNet, self).__init__() + self.backbone = backbone + self.num_classes = num_classes + #self.bottleneck = _BNReluConv(self.backbone.num_features, 128, k = 3, batch_norm=use_bn) + #self.logits = _BNReluConv(128, self.num_classes+1, k = 1, batch_norm=use_bn) + self.logits = _BNReluConv(self.backbone.num_features, self.num_classes, batch_norm=use_bn) + #self.logits_target = _BNReluConv(self.backbone.num_features, self.num_classes, batch_norm=use_bn) + self.logits_aux = _BNReluConv(self.backbone.num_features, self.num_classes, batch_norm=use_bn) + + def forward(self, rgb_inputs, depth_inputs = None): + x, additional, da_features = self.backbone(rgb_inputs, depth_inputs) + #print(additional['features'][0].shape) + #bottleneck = self.bottleneck(x) + logits = self.logits.forward(x) + logits_aux = self.logits_aux.forward(x) + #print(logits_aux.shape) + logits_upsample = upsample(logits, rgb_inputs.shape[2:]) + logits_aux_upsample = upsample(logits_aux, rgb_inputs.shape[2:]) + return logits_upsample, logits_aux_upsample, da_features + + def random_init_params(self): + return chain(*([self.logits.parameters(), self.logits_aux.parameters(), self.backbone.random_init_params()])) + + def fine_tune_params(self): + return self.backbone.fine_tune_params() diff --git a/examples/lifelong_learning/RFNet/models/util.py b/examples/lifelong_learning/RFNet/models/util.py new file mode 100644 index 000000000..5c86e7598 --- /dev/null +++ b/examples/lifelong_learning/RFNet/models/util.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +upsample = lambda x, size: F.interpolate(x, size, 
mode='bilinear', align_corners=False) +batchnorm_momentum = 0.01 / 2 + + +def get_n_params(parameters): + pp = 0 + for p in parameters: + nn = 1 + for s in list(p.size()): + nn = nn * s + pp += nn + return pp + + +class _BNReluConv(nn.Sequential): + def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1): + super(_BNReluConv, self).__init__() + if batch_norm: + self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum)) + self.add_module('relu', nn.ReLU(inplace=batch_norm is True)) + padding = k // 2 # same conv + self.add_module('conv', nn.Conv2d(num_maps_in, num_maps_out, + kernel_size=k, padding=padding, bias=bias, dilation=dilation)) + + +class _Upsample(nn.Module): + def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3): + super(_Upsample, self).__init__() + print(f'Upsample layer: in = {num_maps_in}, skip = {skip_maps_in}, out = {num_maps_out}') + self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn) + self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn) + + def forward(self, x, skip): + skip = self.bottleneck.forward(skip) + skip_size = skip.size()[2:4] + x = upsample(x, skip_size) + x = x + skip + x = self.blend_conv.forward(x) + return x + + +class SpatialPyramidPooling(nn.Module): + def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128, + grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True): + super(SpatialPyramidPooling, self).__init__() + self.grids = grids + self.square_grid = square_grid + self.spp = nn.Sequential() + self.spp.add_module('spp_bn', + _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn)) + num_features = bt_size + final_size = num_features + for i in range(num_levels): + final_size += level_size + self.spp.add_module('spp' + str(i), + _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn)) + self.spp.add_module('spp_fuse', + _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn)) + + def forward(self, x): + levels = [] + target_size = x.size()[2:4] + + ar = target_size[1] / target_size[0] + + x = self.spp[0].forward(x) + levels.append(x) + num = len(self.spp) - 1 + + for i in range(1, num): + if not self.square_grid: + grid_size = (self.grids[i - 1], max(1, round(ar * self.grids[i - 1]))) + x_pooled = F.adaptive_avg_pool2d(x, grid_size) + else: + x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1]) + level = self.spp[i].forward(x_pooled) + + level = upsample(level, target_size) + levels.append(level) + x = torch.cat(levels, 1) + x = self.spp[-1].forward(x) + return x + + +class _UpsampleBlend(nn.Module): + def __init__(self, num_features, use_bn=True): + super(_UpsampleBlend, self).__init__() + self.blend_conv = _BNReluConv(num_features, num_features, k=3, batch_norm=use_bn) + + def forward(self, x, skip): + skip_size = skip.size()[2:4] + x = upsample(x, skip_size) + x = x + skip + x = self.blend_conv.forward(x) + return x diff --git a/examples/lifelong_learning/RFNet/mypath.py b/examples/lifelong_learning/RFNet/mypath.py new file mode 100644 index 000000000..640544e7d --- /dev/null +++ b/examples/lifelong_learning/RFNet/mypath.py @@ -0,0 +1,20 @@ +class Path(object): + @staticmethod + def db_root_dir(dataset): + if dataset == 'cityscapes': + return '/home/robo/m0063/project/RFNet-master/Data/cityscapes/' # folder that contains leftImg8bit/ + elif dataset == 'citylostfound': + 
return '/home/robo/m0063/project/RFNet-master/Data/cityscapesandlostandfound/' # folder that mixes Cityscapes and Lost and Found + elif dataset == 'cityrand': + return '/home/robo/m0063/project/RFNet-master/Data/cityrand/' + elif dataset == 'target': + return '/home/robo/m0063/project/RFNet-master/Data/target/' + elif dataset == 'xrlab': + return '/home/robo/m0063/project/RFNet-master/Data/xrlab/' + elif dataset == 'e1': + return '/home/robo/m0063/project/RFNet-master/Data/e1/' + elif dataset == 'mapillary': + return '/home/robo/m0063/project/RFNet-master/Data/mapillary/' + else: + print('Dataset {} not available.'.format(dataset)) + raise NotImplementedError diff --git a/examples/lifelong_learning/RFNet/predict.py b/examples/lifelong_learning/RFNet/predict.py new file mode 100644 index 000000000..82b527a20 --- /dev/null +++ b/examples/lifelong_learning/RFNet/predict.py @@ -0,0 +1,98 @@ +import os +os.environ['BACKEND_TYPE'] = 'PYTORCH' +# set at yaml +# os.environ["PREDICT_RESULT_DIR"] = "./inference_results" +# os.environ["EDGE_OUTPUT_URL"] = "./edge_kb" +# os.environ["video_url"] = "./video/radio.mp4" +# os.environ["MODEL_URLS"] = "./cloud_next_kb/index.pkl" + + +import cv2 +import time +import torch +import numpy as np +from PIL import Image +import base64 +import tempfile +import warnings +from io import BytesIO + +from sedna.datasources import BaseDataSource +from sedna.core.lifelong_learning import LifelongLearning +from sedna.common.config import Context + +from dataloaders import custom_transforms as tr +from torchvision import transforms + +from accuracy import accuracy +from basemodel import preprocess, val_args, Model + +def preprocess(samples): + composed_transforms = transforms.Compose([ + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + data = BaseDataSource(data_type="test") + data.x = [(composed_transforms(samples), "")] + return data + +def init_ll_job(): + estimator = Model() + + task_allocation = { + "method": "TaskAllocationByOrigin", + "param": { + "origins": ["real", "sim"], + "default": "real" + } + } + unseen_task_allocation = { + "method": "UnseenTaskAllocationDefault" + } + + ll_job = LifelongLearning( + estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=task_allocation, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=unseen_task_allocation, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None) + + return ll_job + +def predict(): + ll_job = init_ll_job() + + camera_address = Context.get_parameters('video_url') + # use video streams for testing + camera = cv2.VideoCapture(camera_address) + fps = 10 + nframe = 0 + while 1: + ret, input_yuv = camera.read() + if not ret: + time.sleep(5) + camera = cv2.VideoCapture(camera_address) + continue + + if nframe % fps: + nframe += 1 + continue + + img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB) + nframe += 1 + if nframe % 1000 == 1: # logs every 1000 frames + warnings.warn(f"camera is open, current frame index is {nframe}") + + img_rgb = cv2.resize(np.array(img_rgb), (2048, 1024), interpolation=cv2.INTER_CUBIC) + img_rgb = Image.fromarray(img_rgb) + sample = {'image': img_rgb, "depth": img_rgb, "label": img_rgb} + data = preprocess(sample) + print("Inference results:", ll_job.inference(data=data)) + +if __name__ == '__main__': + predict() diff --git a/examples/lifelong_learning/RFNet/run_server.py b/examples/lifelong_learning/RFNet/run_server.py new file mode 100644 
index 000000000..438cd70a5 --- /dev/null +++ b/examples/lifelong_learning/RFNet/run_server.py @@ -0,0 +1,252 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from io import BytesIO +from typing import Optional, Any + +import cv2 +import numpy as np +from PIL import Image +import uvicorn +import time +from pydantic import BaseModel +from fastapi import FastAPI, UploadFile, File +from fastapi.routing import APIRoute +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse +import sedna_predict +from sedna.common.utils import get_host_ip +from dataloaders.datasets.cityscapes import CityscapesSegmentation + + +class ImagePayload(BaseModel): + image: UploadFile = File(...) + depth: Optional[UploadFile] = None + + +class ResultModel(BaseModel): + type: int = 0 + box: Any = None + curr: str = None + future: str = None + img: str = None + + +class ResultResponse(BaseModel): + msg: str = "" + result: Optional[ResultModel] = None + code: int + + +class BaseServer: + # pylint: disable=too-many-instance-attributes,too-many-arguments + DEBUG = True + WAIT_TIME = 15 + + def __init__( + self, + servername: str, + host: str, + http_port: int = 8080, + grpc_port: int = 8081, + workers: int = 1, + ws_size: int = 16 * 1024 * 1024, + ssl_key=None, + ssl_cert=None, + timeout=300): + self.server_name = servername + self.app = None + self.host = host or '0.0.0.0' + self.http_port = http_port or 80 + self.grpc_port = grpc_port + self.workers = workers + self.keyfile = ssl_key + self.certfile = ssl_cert + self.ws_size = int(ws_size) + self.timeout = int(timeout) + protocal = "https" if self.certfile else "http" + self.url = f"{protocal}://{self.host}:{self.http_port}" + + def run(self, app, **kwargs): + if hasattr(app, "add_middleware"): + app.add_middleware( + CORSMiddleware, allow_origins=["*"], allow_credentials=True, + allow_methods=["*"], allow_headers=["*"], + ) + + uvicorn.run( + app, + host=self.host, + port=self.http_port, + ssl_keyfile=self.keyfile, + ssl_certfile=self.certfile, + workers=self.workers, + timeout_keep_alive=self.timeout, + **kwargs) + + def get_all_urls(self): + url_list = [{"path": route.path, "name": route.name} + for route in getattr(self.app, 'routes', [])] + return url_list + + +class InferenceServer(BaseServer): # pylint: disable=too-many-arguments + """ + rest api server for inference + """ + + def __init__( + self, + servername, + host: str, + http_port: int = 5000, + max_buffer_size: int = 104857600, + workers: int = 1): + super( + InferenceServer, + self).__init__( + servername=servername, + host=host, + http_port=http_port, + workers=workers) + + self.job, self.detection_validator = sedna_predict.init_ll_job() + + self.max_buffer_size = max_buffer_size + self.app = FastAPI( + routes=[ + APIRoute( + f"/{servername}", + self.model_info, + methods=["GET"], + ), + APIRoute( + f"/{servername}/predict", + self.predict, + methods=["POST"], + response_model=ResultResponse + ), + ], + 
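+            # NOTE: log_level and timeout are not documented FastAPI constructor
+            # arguments; FastAPI appears to keep unknown keyword arguments in app.extra,
+            # so the effective server timeouts are the ones passed to uvicorn.run().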
log_level="trace", + timeout=600, + ) + self.index_frame = 0 + + def start(self): + return self.run(self.app) + + @staticmethod + def model_info(): + return HTMLResponse( + """

+                Welcome to the RestNet API!
+
+                To use this service, send a POST HTTP request to {this-url}/predict
+
+                The JSON payload has the following format:
+                {"image": "BASE64_STRING_OF_IMAGE",
+                "depth": "BASE64_STRING_OF_DEPTH"}

+ """) + + async def predict(self, image: UploadFile = File(...), depth: Optional[UploadFile] = None) -> ResultResponse: + contents = await image.read() + recieve_img_time = time.time() + print("Recieve image from the robo:", recieve_img_time) + + image = Image.open(BytesIO(contents)).convert('RGB') + + img_dep = None + self.index_frame = self.index_frame + 1 + + if depth: + depth_contents = await depth.read() + depth = Image.open(BytesIO(depth_contents)).convert('RGB') + img_dep = cv2.resize(np.array(depth), (2048, 1024), interpolation=cv2.INTER_CUBIC) + img_dep = Image.fromarray(img_dep) + + img_rgb = cv2.resize(np.array(image), (2048, 1024), interpolation=cv2.INTER_CUBIC) + img_rgb = Image.fromarray(img_rgb) + + sample = {'image': img_rgb, "depth": img_dep, "label": img_rgb} + results = sedna_predict.predict(self.job, data=sample, validator=self.detection_validator) + + predict_finish_time = time.time() + print(f"Prediction costs {predict_finish_time - recieve_img_time} seconds") + + post_process = True + if results["result"]["box"] is None: + results["result"]["curr"] = None + results["result"]["future"] = None + elif post_process: + curr, future = get_curb(results["result"]["box"]) + results["result"]["curr"] = curr + results["result"]["future"] = future + results["result"]["box"] = None + print("Post process cost at worker:", (time.time()-predict_finish_time)) + else: + results["result"]["curr"] = None + results["result"]["future"] = None + + print("Result transmit to robo time:", time.time()) + return results + +def parse_result(label, count): + label_map = ['road', 'sidewalk', ] + count_d = dict(zip(label, count)) + curb_count = count_d.get(19, 0) + if curb_count / np.sum(count) > 0.3: + return "curb" + r = sorted(label, key=count_d.get, reverse=True)[0] + try: + c = label_map[r] + except: + c = "other" + + return c + +def get_curb(results): + results = np.array(results[0]) + input_height, input_width = results.shape + + closest = np.array([ + [0, int(input_height)], + [int(input_width), + int(input_height)], + [int(0.118 * input_width + .5), + int(.8 * input_height + .5)], + [int(0.882 * input_width + .5), + int(.8 * input_height + .5)], + ]) + + future = np.array([ + [int(0.118 * input_width + .5), + int(.8 * input_height + .5)], + [int(0.882 * input_width + .5), + int(.8 * input_height + .5)], + [int(.765 * input_width + .5), + int(.66 * input_height + .5)], + [int(.235 * input_width + .5), + int(.66 * input_height + .5)] + ]) + + mask = np.zeros((input_height, input_width), dtype=np.uint8) + mask = cv2.fillPoly(mask, [closest], 1) + mask = cv2.fillPoly(mask, [future], 2) + d1, c1 = np.unique(results[mask == 1], return_counts=True) + d2, c2 = np.unique(results[mask == 2], return_counts=True) + c = parse_result(d1, c1) + f = parse_result(d2, c2) + + return c, f + +if __name__ == '__main__': + web_app = InferenceServer("lifelong-learning-robo", host=get_host_ip()) + web_app.start() diff --git a/examples/lifelong_learning/RFNet/sedna_evaluate.py b/examples/lifelong_learning/RFNet/sedna_evaluate.py new file mode 100644 index 000000000..566333472 --- /dev/null +++ b/examples/lifelong_learning/RFNet/sedna_evaluate.py @@ -0,0 +1,50 @@ +import os +os.environ['BACKEND_TYPE'] = 'PYTORCH' +# os.environ["KB_SERVER"] = "http://0.0.0.0:9020" +# os.environ["test_dataset_url"] = "./data_txt/sedna_data.txt" +# os.environ["MODEL_URLS"] = "./cloud_next_kb/index.pkl" +# os.environ["operator"] = "<" +# os.environ["model_threshold"] = "0" + +from sedna.core.lifelong_learning import LifelongLearning 
+from sedna.datasources import IndexDataParse +from sedna.common.config import Context + +from accuracy import accuracy +from basemodel import Model + +def _load_txt_dataset(dataset_url): + # use original dataset url + original_dataset_url = Context.get_parameters('original_dataset_url') + return os.path.join(os.path.dirname(original_dataset_url), dataset_url) + +def eval(): + estimator = Model() + eval_dataset_url = Context.get_parameters("test_dataset_url") + eval_data = IndexDataParse(data_type="eval", func=_load_txt_dataset) + eval_data.parse(eval_dataset_url, use_raw=False) + + task_allocation = { + "method": "TaskAllocationByOrigin", + "param": { + "origins": ["real", "sim"] + } + } + + ll_job = LifelongLearning(estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=task_allocation, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None + ) + + ll_job.evaluate(eval_data, metrics=accuracy) + + +if __name__ == '__main__': + print(eval()) diff --git a/examples/lifelong_learning/RFNet/sedna_predict.py b/examples/lifelong_learning/RFNet/sedna_predict.py new file mode 100644 index 000000000..b32c01d2d --- /dev/null +++ b/examples/lifelong_learning/RFNet/sedna_predict.py @@ -0,0 +1,132 @@ +import os + +os.environ['BACKEND_TYPE'] = 'PYTORCH' +# os.environ["UNSEEN_SAVE_URL"] = "s3://kubeedge/sedna-robo/unseen_samples/" +# set at yaml +# os.environ["PREDICT_RESULT_DIR"] = "./inference_results" +os.environ["TEST_DATASET_URL"] = "./data_txt/door_test.txt" +os.environ["EDGE_OUTPUT_URL"] = "./edge_kb" +os.environ["ORIGINAL_DATASET_URL"] = "/tmp" + +import torch +import numpy as np +from PIL import Image +import base64 +import tempfile +from io import BytesIO +from torchvision.transforms import ToPILImage +from torchvision import transforms +from torch.utils.data import DataLoader + +from sedna.datasources import IndexDataParse +from sedna.core.lifelong_learning import LifelongLearning +from sedna.common.config import Context + +from eval import Validator +from accuracy import accuracy +from basemodel import preprocess, val_args, Model +from dataloaders.utils import Colorize +from dataloaders import custom_transforms as tr +from dataloaders.datasets.cityscapes import CityscapesSegmentation + +def _load_txt_dataset(dataset_url): + # use original dataset url, + # see https://github.com/kubeedge/sedna/issues/35 + original_dataset_url = Context.get_parameters('original_dataset_url') + return os.path.join(os.path.dirname(original_dataset_url), dataset_url) + +def fetch_data(): + test_dataset_url = Context.get_parameters("test_dataset_url") + test_data = IndexDataParse(data_type="test", func=_load_txt_dataset) + test_data.parse(test_dataset_url, use_raw=False) + return test_data + +def pre_data_process(samples): + composed_transforms = transforms.Compose([ + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + data = BaseDataSource(data_type="test") + data.x = [(composed_transforms(samples), "")] + return data + +def post_process(res, is_unseen_task): + if is_unseen_task: + res, base64_string = None, None + else: + res = res[0].tolist() + + type = 0 if not is_unseen_task else 1 + mesg = { + "msg": "", + "result": { + "type": type, + "box": res + }, + "code": 0 + } + return mesg + +def image_merge(raw_img, result): + raw_img = ToPILImage()(raw_img) + + pre_colors = Colorize()(torch.from_numpy(result)) + 
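+    # blend the colorized prediction over the raw frame and return it as a
+    # base64-encoded PNG (a string form that can be embedded in a JSON response).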
pre_color_image = ToPILImage()(pre_colors[0]) # pre_colors.dtype = float64 + + image = raw_img.resize(pre_color_image.size, Image.BILINEAR) + image = image.convert('RGBA') + label = pre_color_image.convert('RGBA') + image = Image.blend(image, label, 0.6) + with tempfile.NamedTemporaryFile(suffix='.png') as f: + image.save(f.name) + + with open(f.name, 'rb') as open_file: + byte_content = open_file.read() + base64_bytes = base64.b64encode(byte_content) + base64_string = base64_bytes.decode('utf-8') + return base64_string + +def init_ll_job(): + estimator = Model() + inference_integrate = { + "method": "BBoxInferenceIntegrate" + } + unseen_task_allocation = { + "method": "UnseenTaskAllocationDefault" + } + unseen_sample_recognition = { + "method": "SampleRegonitionByRFNet" + } + + ll_job = LifelongLearning( + estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=None, + task_remodeling=None, + inference_integrate=inference_integrate, + task_update_decision=None, + unseen_task_allocation=unseen_task_allocation, + unseen_sample_recognition=unseen_sample_recognition, + unseen_sample_re_recognition=None) + + args = val_args() + args.weight_path = "./models/detection_model.pth" + args.num_class = 31 + + return ll_job, Validator(args, unseen_detection=True) + +def predict(ll_job, data=None, validator=None): + if data: + data = pre_data_process(data) + else: + data = fetch_data() + data.x = preprocess(data.x) + + res, is_unseen_task, _ = ll_job.inference( + data, validator=validator, initial=False) + return post_process(res, is_unseen_task) + +if __name__ == '__main__': + ll_job, validator = init_ll_job() + print("Inference result:", predict(ll_job, validator=validator)) diff --git a/examples/lifelong_learning/RFNet/sedna_train.py b/examples/lifelong_learning/RFNet/sedna_train.py new file mode 100644 index 000000000..1c99361aa --- /dev/null +++ b/examples/lifelong_learning/RFNet/sedna_train.py @@ -0,0 +1,78 @@ +import os +os.environ['BACKEND_TYPE'] = 'PYTORCH' +os.environ["OUTPUT_URL"] = "./cloud_kb/" +# os.environ['CLOUD_KB_INDEX'] = "./cloud_kb/index.pkl" +os.environ["TRAIN_DATASET_URL"] = "./data_txt/sedna_data.txt" +os.environ["KB_SERVER"] = "http://0.0.0.0:9020" +os.environ["HAS_COMPLETED_INITIAL_TRAINING"] = "false" + +from sedna.common.file_ops import FileOps +from sedna.datasources import IndexDataParse +from sedna.common.config import Context, BaseConfig +from sedna.core.lifelong_learning import LifelongLearning + +from basemodel import Model + +def _load_txt_dataset(dataset_url): + # use original dataset url + original_dataset_url = Context.get_parameters('original_dataset_url') + return os.path.join(os.path.dirname(original_dataset_url), dataset_url) + +def train(estimator, train_data): + task_definition = { + "method": "TaskDefinitionByOrigin", + "param": { + "origins": ["real", "sim"] + } + } + + task_allocation = { + "method": "TaskAllocationByOrigin", + "param": { + "origins": ["real", "sim"] + } + } + + ll_job = LifelongLearning(estimator, + task_definition=task_definition, + task_relationship_discovery=None, + task_allocation=task_allocation, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None + ) + + ll_job.train(train_data) + +def update(estimator, train_data): + ll_job = LifelongLearning(estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=None, + task_remodeling=None, + 
inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None + ) + + ll_job.update(train_data) + +def run(): + estimator = Model() + train_dataset_url = BaseConfig.train_dataset_url + train_data = IndexDataParse(data_type="train") + train_data.parse(train_dataset_url, use_raw=False) + + is_completed_initilization = str(Context.get_parameters("HAS_COMPLETED_INITIAL_TRAINING", "false")).lower() + if is_completed_initilization == "false": + train(estimator, train_data) + else: + update(estimator, train_data) + +if __name__ == '__main__': + run() diff --git a/examples/lifelong_learning/RFNet/test.py b/examples/lifelong_learning/RFNet/test.py new file mode 100644 index 000000000..fd9cd6573 --- /dev/null +++ b/examples/lifelong_learning/RFNet/test.py @@ -0,0 +1,52 @@ +import numpy as np +import seaborn as sns +import pandas as pd +import matplotlib.pyplot as plt + +CPA_results = np.load("./cpa_results.npy").T +ratios = [0.3, 0.5, 0.6, 0.7, 0.8, 0.9] +ratio_counts = np.zeros((len(CPA_results), len(ratios)), dtype=float) + +for i in range(len(CPA_results)): + for j in range(len(ratios)): + result = CPA_results[i] + result = result[result <= ratios[j]] + + ratio_counts[i][j] = len(result) / 275 + +plt.figure(figsize=(45, 10)) +ratio_counts = pd.DataFrame(data=ratio_counts.T, index=ratios) +sns.heatmap(data=ratio_counts, annot=True, cmap="YlGnBu", annot_kws={'fontsize': 15}) +plt.xticks(fontsize=20) +plt.yticks(fontsize=25) +plt.xlabel("Test images", fontsize=25) +plt.ylabel("Ratio of PA ranges", fontsize=25) +plt.savefig("./figs/ratio_count.png") +plt.show() + + +# data = pd.DataFrame(CPA_results.T) +# +# plt.figure(figsize=(30, 15)) +# cpa_result = pd.DataFrame(data=data) +# sns.heatmap(data=cpa_result) +# plt.savefig("./figs/heatmap_pa.png") +# plt.show() +# +# plt.figure(figsize=(30, 15)) +# cpa_result = pd.DataFrame(data=data[:15]) +# sns.heatmap(data=cpa_result) +# plt.savefig("./figs/door_heatmap_pa.png") +# plt.show() +# +# plt.figure(figsize=(30, 15)) +# cpa_result = pd.DataFrame(data=data[15:31]) +# sns.heatmap(data=cpa_result) +# plt.savefig("./figs/garden1_heatmap_pa.png") +# plt.show() +# +# plt.figure(figsize=(30, 15)) +# cpa_result = pd.DataFrame(data=data[31:]) +# sns.heatmap(data=cpa_result) +# plt.savefig("./figs/garden2_heatmap_pa.png") +# plt.show() diff --git a/examples/lifelong_learning/RFNet/train.py b/examples/lifelong_learning/RFNet/train.py new file mode 100644 index 000000000..ca6c21949 --- /dev/null +++ b/examples/lifelong_learning/RFNet/train.py @@ -0,0 +1,341 @@ +import argparse +import os +import numpy as np +from tqdm import tqdm +import torch + +from mypath import Path +from dataloaders import make_data_loader +from models.rfnet import RFNet +from models.resnet.resnet_single_scale_single_attention import * +from utils.loss import SegmentationLosses +from models.replicate import patch_replication_callback +from utils.calculate_weights import calculate_weigths_labels +from utils.lr_scheduler import LR_Scheduler +from utils.saver import Saver +from utils.summaries import TensorboardSummary +from utils.metrics import Evaluator +from sedna.datasources import BaseDataSource + +class Trainer(object): + def __init__(self, args, train_data=None, valid_data=None): + self.args = args + # Define Saver + self.saver = Saver(args) + self.saver.save_experiment_config() + # Define Tensorboard Summary + self.summary = TensorboardSummary(self.saver.experiment_dir) + self.writer = 
self.summary.create_summary() + # denormalize for detph image + self.mean_depth = torch.as_tensor(0.12176, dtype=torch.float32, device='cpu') + self.std_depth = torch.as_tensor(0.09752, dtype=torch.float32, device='cpu') + self.nclass = args.num_class + # Define Dataloader + kwargs = {'num_workers': args.workers, 'pin_memory': False} + self.train_loader, self.val_loader, self.test_loader, _ = make_data_loader(args, train_data=train_data, + valid_data=valid_data, **kwargs) + + # Define network + resnet = resnet18(pretrained=True, efficient=False, use_bn=True) + model = RFNet(resnet, num_classes=self.nclass, use_bn=True) + train_params = [{'params': model.random_init_params(), 'lr': args.lr}, + {'params': model.fine_tune_params(), 'lr': 0.1*args.lr, 'weight_decay':args.weight_decay}] + # Define Optimizer + optimizer = torch.optim.Adam(train_params, lr=args.lr, + weight_decay=args.weight_decay) + # Define Criterion + # whether to use class balanced weights + if args.use_balanced_weights: + classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy') + if os.path.isfile(classes_weights_path): + weight = np.load(classes_weights_path) + else: + weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass) + weight = torch.from_numpy(weight.astype(np.float32)) + else: + weight = None + # Define loss function + self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type) + self.model, self.optimizer = model, optimizer + # Define Evaluator + self.evaluator = Evaluator(self.nclass) + # # Define lr scheduler + self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)) + # Using cuda + if args.cuda: + self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) + patch_replication_callback(self.model) + self.model = self.model.cuda() + + # Resuming checkpoint + self.best_pred = 0.0 + if args.resume is not None: + if not os.path.isfile(args.resume): + raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) + print(f"Training: load model from {args.resume}") + checkpoint = torch.load(args.resume, map_location=torch.device("cpu")) + args.start_epoch = checkpoint['epoch'] + # if args.cuda: + # self.model.load_state_dict(checkpoint['state_dict']) + # else: + # self.model.load_state_dict(checkpoint['state_dict']) + self.model.load_state_dict(checkpoint['state_dict']) + if not args.ft: + self.optimizer.load_state_dict(checkpoint['optimizer']) + self.best_pred = checkpoint['best_pred'] + print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) + + # Clear start epoch if fine-tuning + if args.ft: + args.start_epoch = 0 + + def training(self, epoch): + train_loss = 0.0 + print(self.optimizer.state_dict()['param_groups'][0]['lr']) + self.model.train() + tbar = tqdm(self.train_loader) + num_img_tr = len(self.train_loader) + for i, sample in enumerate(tbar): + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + #print(target.shape) + else: + image, target = sample['image'], sample['label'] + print(image.shape) + if self.args.cuda: + image, target = image.cuda(), target.cuda() + if self.args.depth: + depth = depth.cuda() + self.scheduler(self.optimizer, i, epoch, self.best_pred) + self.optimizer.zero_grad() + if self.args.depth: + output = self.model(image, depth) + else: + output = self.model(image) + #print(target.max()) + #print(output.shape) + target[target > 
self.nclass-1] = 255 + loss = self.criterion(output, target) + loss.backward() + self.optimizer.step() + #print(self.optimizer.state_dict()['param_groups'][0]['lr']) + train_loss += loss.item() + tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) + self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) + # Show 10 * 3 inference results each epoch + if i % (num_img_tr // 10) == 0: + global_step = i + num_img_tr * epoch + if self.args.depth: + self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step) + + depth_display = depth[0].cpu().unsqueeze(0) + depth_display = depth_display.mul_(self.std_depth).add_(self.mean_depth) + depth_display = depth_display.numpy() + depth_display = depth_display*255 + depth_display = depth_display.astype(np.uint8) + self.writer.add_image('Depth', depth_display, global_step) + + else: + self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step) + + self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) + print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) + print('Loss: %.3f' % train_loss) + + # if self.args.no_val: + # # save checkpoint every epoch + # is_best = False + # checkpoint_path = self.saver.save_checkpoint({ + # 'epoch': epoch + 1, + # 'state_dict': self.model.state_dict(), + # 'optimizer': self.optimizer.state_dict(), + # 'best_pred': self.best_pred, + # }, is_best) + + def validation(self, epoch): + self.model.eval() + self.evaluator.reset() + tbar = tqdm(self.val_loader, desc='\r') + test_loss = 0.0 + for i, (sample, img_path) in enumerate(tbar): + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + # print(f"val image is {image}") + if self.args.cuda: + image, target = image.cuda(), target.cuda() + if self.args.depth: + depth = depth.cuda() + with torch.no_grad(): + if self.args.depth: + output = self.model(image, depth) + else: + output = self.model(image) + target[target > self.nclass-1] = 255 + loss = self.criterion(output, target) + test_loss += loss.item() + tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1))) + pred = output.data.cpu().numpy() + target = target.cpu().numpy() + pred = np.argmax(pred, axis=1) + # Add batch sample into evaluator + self.evaluator.add_batch(target, pred) + + # Fast test during the training + Acc = self.evaluator.Pixel_Accuracy() + Acc_class = self.evaluator.Pixel_Accuracy_Class() + mIoU = self.evaluator.Mean_Intersection_over_Union() + FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() + self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch) + self.writer.add_scalar('val/mIoU', mIoU, epoch) + self.writer.add_scalar('val/Acc', Acc, epoch) + self.writer.add_scalar('val/Acc_class', Acc_class, epoch) + self.writer.add_scalar('val/fwIoU', FWIoU, epoch) + print('Validation:') + print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) + print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) + print('Loss: %.3f' % test_loss) + + new_pred = mIoU + if new_pred > self.best_pred: + is_best = True + self.best_pred = new_pred + self.saver.save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'best_pred': self.best_pred, + }, is_best) + +def train(): + parser = 
argparse.ArgumentParser(description="PyTorch RFNet Training") + parser.add_argument('--depth', action="store_true", default=False, + help='training with depth image or not (default: False)') + parser.add_argument('--dataset', type=str, default='cityscapes', + choices=['citylostfound', 'cityscapes', 'cityrand', 'target', 'xrlab', 'e1', 'mapillary'], + help='dataset name (default: cityscapes)') + parser.add_argument('--workers', type=int, default=4, + metavar='N', help='dataloader threads') + parser.add_argument('--base-size', type=int, default=1024, + help='base image size') + parser.add_argument('--crop-size', type=int, default=768, + help='crop image size') + parser.add_argument('--loss-type', type=str, default='ce', + choices=['ce', 'focal'], + help='loss func type (default: ce)') + # training hyper params + # parser.add_argument('--epochs', type=int, default=None, metavar='N', + # help='number of epochs to train (default: auto)') + parser.add_argument('--epochs', type=int, default=1, metavar='N', + help='number of epochs to train (default: auto)') + parser.add_argument('--start_epoch', type=int, default=0, + metavar='N', help='start epochs (default:0)') + parser.add_argument('--batch-size', type=int, default=None, + metavar='N', help='input batch size for \ + training (default: auto)') + parser.add_argument('--val-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--test-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--use-balanced-weights', action='store_true', default=False, + help='whether to use balanced weights (default: True)') + + parser.add_argument('--num-class', type=int, default=24, + help='number of training classes (default: 24') + # optimizer params + parser.add_argument('--lr', type=float, default=1e-4, metavar='LR', + help='learning rate (default: auto)') + parser.add_argument('--lr-scheduler', type=str, default='cos', + choices=['poly', 'step', 'cos', 'inv'], + help='lr scheduler mode: (default: cos)') + parser.add_argument('--momentum', type=float, default=0.9, + metavar='M', help='momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=2.5e-5, + metavar='M', help='w-decay (default: 5e-4)') + # cuda, seed and logging + parser.add_argument('--no-cuda', action='store_true', default= + False, help='disables CUDA training') + parser.add_argument('--gpu-ids', type=str, default='0', + help='use which gpu to train, must be a \ + comma-separated list of integers only (default=0)') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + # checking point + parser.add_argument('--resume', type=str, default=None, + help='put the path to resuming file if needed') + parser.add_argument('--checkname', type=str, default=None, + help='set the checkpoint name') + # finetuning pre-trained models + parser.add_argument('--ft', action='store_true', default=True, + help='finetuning on a different dataset') + # evaluation option + parser.add_argument('--eval-interval', type=int, default=1, + help='evaluation interval (default: 1)') + parser.add_argument('--no-val', action='store_true', default=False, + help='skip validation during training') + + args = parser.parse_args() + args.cuda = not args.no_cuda and torch.cuda.is_available() + print(torch.cuda.is_available()) + if args.cuda: + try: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] + except 
ValueError: + raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') + + if args.epochs is None: + epoches = { + 'cityscapes': 200, + 'citylostfound': 200, + } + args.epochs = epoches[args.dataset.lower()] + + if args.batch_size is None: + args.batch_size = 4 * len(args.gpu_ids) + + if args.test_batch_size is None: + args.test_batch_size = args.batch_size + + if args.lr is None: + lrs = { + 'cityscapes': 0.0001, + 'citylostfound': 0.0001, + 'cityrand': 0.0001 + } + args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size + + + if args.checkname is None: + args.checkname = 'RFNet' + print(args) + torch.manual_seed(args.seed) + args.resume = "/home/lsq/ianvs_new/examples/robo/workspace/pcb-algorithm-test/test-algorithm/6441f8be-d809-11ec-ac65-3b30682caaa6/knowledgeable/1/seen_task/checkpoint_1653029539.6730478.pth" + trainer = Trainer(args, train_data=val_data, valid_data=val_data) + trainer.validation(0) + # print('Starting Epoch:', trainer.args.start_epoch) + # print('Total Epoches:', trainer.args.epochs) + # for epoch in range(trainer.args.start_epoch, trainer.args.epochs): + # if epoch == 0: + # trainer.validation(epoch) + # trainer.training(epoch) + # if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1): + # trainer.validation(epoch) + # + # trainer.writer.close() + +if __name__ == "__main__": + val_data = BaseDataSource(data_type="eval") + x, y = [], [] + with open("/home/lsq/ianvs_new/examples/robo/trainData_depth.txt", "r") as f: + for line in f.readlines()[-120:]: + lines = line.strip().split() + x.append(lines[:2]) + y.append(lines[-1]) + + val_data.x = x + val_data.y = y + + train() \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/utils/__init__.py b/examples/lifelong_learning/RFNet/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/lifelong_learning/RFNet/utils/calculate_weights.py b/examples/lifelong_learning/RFNet/utils/calculate_weights.py new file mode 100644 index 000000000..2c2c98211 --- /dev/null +++ b/examples/lifelong_learning/RFNet/utils/calculate_weights.py @@ -0,0 +1,29 @@ +import os +from tqdm import tqdm +import numpy as np +from mypath import Path + +def calculate_weigths_labels(dataset, dataloader, num_classes): + # Create an instance from the data loader + z = np.zeros((num_classes,)) + # Initialize tqdm + tqdm_batch = tqdm(dataloader) + print('Calculating classes weights') + for sample in tqdm_batch: + y = sample['label'] + y = y.detach().cpu().numpy() + mask = (y >= 0) & (y < num_classes) + labels = y[mask].astype(np.uint8) + count_l = np.bincount(labels, minlength=num_classes) + z += count_l + tqdm_batch.close() + total_frequency = np.sum(z) + class_weights = [] + for frequency in z: + class_weight = 1 / (np.log(1.02 + (frequency / total_frequency))) + class_weights.append(class_weight) + ret = np.array(class_weights) + classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy') + np.save(classes_weights_path, ret) + + return ret \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/utils/iouEval.py b/examples/lifelong_learning/RFNet/utils/iouEval.py new file mode 100644 index 000000000..de9558259 --- /dev/null +++ b/examples/lifelong_learning/RFNet/utils/iouEval.py @@ -0,0 +1,142 @@ +import torch + + +class iouEval: + + def __init__(self, nClasses, ignoreIndex=20): + + self.nClasses = nClasses + self.ignoreIndex = ignoreIndex if nClasses > ignoreIndex else 
-1 # if ignoreIndex is larger than nClasses, consider no ignoreIndex + self.reset() + + def reset(self): + classes = self.nClasses if self.ignoreIndex == -1 else self.nClasses - 1 + self.tp = torch.zeros(classes).double() + self.fp = torch.zeros(classes).double() + self.fn = torch.zeros(classes).double() + self.cdp_obstacle = torch.zeros(1).double() + self.tp_obstacle = torch.zeros(1).double() + self.idp_obstacle = torch.zeros(1).double() + self.tp_nonobstacle = torch.zeros(1).double() + # self.cdi = torch.zeros(1).double() + + def addBatch(self, x, y): # x=preds, y=targets + # sizes should be "batch_size x nClasses x H x W" + # cdi = 0 + + # print ("X is cuda: ", x.is_cuda) + # print ("Y is cuda: ", y.is_cuda) + + if (x.is_cuda or y.is_cuda): + x = x.cuda() + y = y.cuda() + + # if size is "batch_size x 1 x H x W" scatter to onehot + if (x.size(1) == 1): + x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3)) + if x.is_cuda: + x_onehot = x_onehot.cuda() + x_onehot.scatter_(1, x, 1).float() # dim index src 按照列用1替换0,索引为x + else: + x_onehot = x.float() + + if (y.size(1) == 1): + y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3)) + if y.is_cuda: + y_onehot = y_onehot.cuda() + y_onehot.scatter_(1, y, 1).float() + else: + y_onehot = y.float() + + if (self.ignoreIndex != -1): + ignores = y_onehot[:, self.ignoreIndex].unsqueeze(1) # 加一维 + x_onehot = x_onehot[:, :self.ignoreIndex] # ignoreIndex后的都不要 + y_onehot = y_onehot[:, :self.ignoreIndex] + else: + ignores = 0 + + + tpmult = x_onehot * y_onehot # times prediction and gt coincide is 1 + tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, + keepdim=True).squeeze() + fpmult = x_onehot * ( + 1 - y_onehot - ignores) # times prediction says its that class and gt says its not (subtracting cases when its ignore label!) 
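+        # tp/fp/fn are reduced over the batch and spatial dimensions, leaving one count
+        # per class; the counts accumulate across addBatch() calls until reset().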
+ fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, + keepdim=True).squeeze() + fnmult = (1 - x_onehot) * (y_onehot) # times prediction says its not that class and gt says it is + fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, + keepdim=True).squeeze() + + self.tp += tp.double().cpu() + self.fp += fp.double().cpu() + self.fn += fn.double().cpu() + + cdp_obstacle = tpmult[:, 19].sum() # obstacle index 19 + tp_obstacle = y_onehot[:, 19].sum() + + idp_obstacle = (x_onehot[:, 19] - tpmult[:, 19]).sum() + tp_nonobstacle = (-1*y_onehot+1).sum() + + # for i in range(0, x.size(0)): + # if tpmult[i].sum()/(y_onehot[i].sum() + 1e-15) >= 0.5: + # cdi += 1 + + + self.cdp_obstacle += cdp_obstacle.double().cpu() + self.tp_obstacle += tp_obstacle.double().cpu() + self.idp_obstacle += idp_obstacle.double().cpu() + self.tp_nonobstacle += tp_nonobstacle.double().cpu() + # self.cdi += cdi.double().cpu() + + + + def getIoU(self): + num = self.tp + den = self.tp + self.fp + self.fn + 1e-15 + iou = num / den + iou_not_zero = list(filter(lambda x: x != 0, iou)) + # print(len(iou_not_zero)) + iou_mean = sum(iou_not_zero) / len(iou_not_zero) + tfp = self.tp + self.fp + 1e-15 + acc = num / tfp + acc_not_zero = list(filter(lambda x: x != 0, acc)) + acc_mean = sum(acc_not_zero) / len(acc_not_zero) + + return iou_mean, iou, acc_mean, acc # returns "iou mean", "iou per class" + + def getObstacleEval(self): + + pdr_obstacle = self.cdp_obstacle / (self.tp_obstacle+1e-15) + pfp_obstacle = self.idp_obstacle / (self.tp_nonobstacle+1e-15) + + return pdr_obstacle, pfp_obstacle + + +# Class for colors +class colors: + RED = '\033[31;1m' + GREEN = '\033[32;1m' + YELLOW = '\033[33;1m' + BLUE = '\033[34;1m' + MAGENTA = '\033[35;1m' + CYAN = '\033[36;1m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + ENDC = '\033[0m' + + +# Colored value output if colorized flag is activated. 
+def getColorEntry(val): + if not isinstance(val, float): + return colors.ENDC + if (val < .20): + return colors.RED + elif (val < .40): + return colors.YELLOW + elif (val < .60): + return colors.BLUE + elif (val < .80): + return colors.CYAN + else: + return colors.GREEN + diff --git a/examples/lifelong_learning/RFNet/utils/loss.py b/examples/lifelong_learning/RFNet/utils/loss.py new file mode 100644 index 000000000..6cde9a175 --- /dev/null +++ b/examples/lifelong_learning/RFNet/utils/loss.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn + +class SegmentationLosses(object): + def __init__(self, weight=None, size_average=True, batch_average=True, ignore_index=255, cuda=False): # ignore_index=255 + self.ignore_index = ignore_index + self.weight = weight + self.size_average = size_average + self.batch_average = batch_average + self.cuda = cuda + + def build_loss(self, mode='ce'): + """Choices: ['ce' or 'focal']""" + if mode == 'ce': + return self.CrossEntropyLoss + elif mode == 'focal': + return self.FocalLoss + else: + raise NotImplementedError + + def CrossEntropyLoss(self, logit, target): + n, c, h, w = logit.size() + #criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, + #size_average=self.size_average) + criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=self.ignore_index) + if self.cuda: + criterion = criterion.cuda() + + loss = criterion(logit, target.long()) + + if self.batch_average: + loss /= n + + return loss + + def FocalLoss(self, logit, target, gamma=2, alpha=0.5): + n, c, h, w = logit.size() + criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, + size_average=self.size_average) + if self.cuda: + criterion = criterion.cuda() + + logpt = -criterion(logit, target.long()) + pt = torch.exp(logpt) + if alpha is not None: + logpt *= alpha + loss = -((1 - pt) ** gamma) * logpt + + if self.batch_average: + loss /= n + + return loss + +if __name__ == "__main__": + loss = SegmentationLosses(cuda=True) + a = torch.rand(1, 3, 7, 7).cuda() + b = torch.rand(1, 7, 7).cuda() + print(loss.CrossEntropyLoss(a, b).item()) + print(loss.FocalLoss(a, b, gamma=0, alpha=None).item()) + print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item()) + + + + diff --git a/examples/lifelong_learning/RFNet/utils/lr_scheduler.py b/examples/lifelong_learning/RFNet/utils/lr_scheduler.py new file mode 100644 index 000000000..471240282 --- /dev/null +++ b/examples/lifelong_learning/RFNet/utils/lr_scheduler.py @@ -0,0 +1,70 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## ECE Department, Rutgers University +## Email: zhang.hang@rutgers.edu +## Copyright (c) 2017 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import math + +class LR_Scheduler(object): + """Learning Rate Scheduler + + Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` + + Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` + + Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` + + Args: + args: + :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), + :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, + :attr:`args.lr_step` + + iters_per_epoch: number of iterations per epoch + """ + def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, + lr_step=0, warmup_epochs=0): + self.mode = mode 
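+        # NOTE: the per-iteration rate is computed in __call__ from the global iteration
+        # index T = epoch * iters_per_epoch + i; _adjust_learning_rate then applies 4x
+        # the rate to the first parameter group (the randomly initialised head) and the
+        # plain rate to the remaining fine-tuned backbone groups.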
+ print('Using {} LR Scheduler!'.format(self.mode)) + self.lr = base_lr + if mode == 'step': + assert lr_step + self.lr_step = lr_step + self.iters_per_epoch = iters_per_epoch + self.N = num_epochs * iters_per_epoch + self.epoch = -1 + self.warmup_iters = warmup_epochs * iters_per_epoch + + def __call__(self, optimizer, i, epoch, best_pred): + T = epoch * self.iters_per_epoch + i + if self.mode == 'cos': + lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) + elif self.mode == 'poly': + lr = self.lr * pow((1 - 1.0 * T / self.N), 2) + elif self.mode == 'step': + lr = self.lr * (0.1 ** (epoch // self.lr_step)) + else: + raise NotImplemented + # warm up lr schedule + if self.warmup_iters > 0 and T < self.warmup_iters: + lr = lr * 1.0 * T / self.warmup_iters + if epoch > self.epoch: + print('\n=>Epoches %i, learning rate = %.4f, \ + previous best = %.4f' % (epoch, lr, best_pred)) + self.epoch = epoch + assert lr >= 0 + self._adjust_learning_rate(optimizer, lr) + + def _adjust_learning_rate(self, optimizer, lr): + if len(optimizer.param_groups) == 1: + optimizer.param_groups[0]['lr'] = lr * 4 + else: + # enlarge the lr at the head + optimizer.param_groups[0]['lr'] = lr * 4 + for i in range(1, len(optimizer.param_groups)): + optimizer.param_groups[i]['lr'] = lr diff --git a/examples/lifelong_learning/RFNet/utils/metrics.py b/examples/lifelong_learning/RFNet/utils/metrics.py new file mode 100644 index 000000000..61cbab472 --- /dev/null +++ b/examples/lifelong_learning/RFNet/utils/metrics.py @@ -0,0 +1,145 @@ +import numpy as np + + +class Evaluator(object): + def __init__(self, num_class): + self.num_class = num_class + self.confusion_matrix = np.zeros((self.num_class,)*2) # shape:(num_class, num_class) + + def Pixel_Accuracy(self): + Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() + return Acc + + def Pixel_Accuracy_Class_Curb(self): + Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) + print('-----------Acc of each classes-----------') + print("road : %.6f" % (Acc[0] * 100.0), "%\t") + print("sidewalk : %.6f" % (Acc[1] * 100.0), "%\t") + Acc = np.nanmean(Acc[:2]) + return Acc + + + def Pixel_Accuracy_Class(self): + Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) + print('-----------Acc of each classes-----------') + print("road : %.6f" % (Acc[0] * 100.0), "%\t") + print("sidewalk : %.6f" % (Acc[1] * 100.0), "%\t") + print("building : %.6f" % (Acc[2] * 100.0), "%\t") + print("wall : %.6f" % (Acc[3] * 100.0), "%\t") + print("fence : %.6f" % (Acc[4] * 100.0), "%\t") + print("pole : %.6f" % (Acc[5] * 100.0), "%\t") + print("traffic light: %.6f" % (Acc[6] * 100.0), "%\t") + print("traffic sign : %.6f" % (Acc[7] * 100.0), "%\t") + print("vegetation : %.6f" % (Acc[8] * 100.0), "%\t") + print("terrain : %.6f" % (Acc[9] * 100.0), "%\t") + print("sky : %.6f" % (Acc[10] * 100.0), "%\t") + print("person : %.6f" % (Acc[11] * 100.0), "%\t") + print("rider : %.6f" % (Acc[12] * 100.0), "%\t") + print("car : %.6f" % (Acc[13] * 100.0), "%\t") + print("truck : %.6f" % (Acc[14] * 100.0), "%\t") + print("bus : %.6f" % (Acc[15] * 100.0), "%\t") + print("train : %.6f" % (Acc[16] * 100.0), "%\t") + print("motorcycle : %.6f" % (Acc[17] * 100.0), "%\t") + print("bicycle : %.6f" % (Acc[18] * 100.0), "%\t") + print("dynamic : %.6f" % (Acc[19] * 100.0), "%\t") + print("stair : %.6f" % (Acc[20] * 100.0), "%\t") + if self.num_class == 20: + print("small obstacles: %.6f" % (Acc[19] * 100.0), "%\t") + Acc = np.nanmean(Acc) + return Acc + + def 
Mean_Intersection_over_Union(self): + MIoU = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + + # print MIoU of each class + print('-----------IoU of each classes-----------') + print("road : %.6f" % (MIoU[0] * 100.0), "%\t") + print("sidewalk : %.6f" % (MIoU[1] * 100.0), "%\t") + print("building : %.6f" % (MIoU[2] * 100.0), "%\t") + print("wall : %.6f" % (MIoU[3] * 100.0), "%\t") + print("fence : %.6f" % (MIoU[4] * 100.0), "%\t") + print("pole : %.6f" % (MIoU[5] * 100.0), "%\t") + print("traffic light: %.6f" % (MIoU[6] * 100.0), "%\t") + print("traffic sign : %.6f" % (MIoU[7] * 100.0), "%\t") + print("vegetation : %.6f" % (MIoU[8] * 100.0), "%\t") + print("terrain : %.6f" % (MIoU[9] * 100.0), "%\t") + print("sky : %.6f" % (MIoU[10] * 100.0), "%\t") + print("person : %.6f" % (MIoU[11] * 100.0), "%\t") + print("rider : %.6f" % (MIoU[12] * 100.0), "%\t") + print("car : %.6f" % (MIoU[13] * 100.0), "%\t") + print("truck : %.6f" % (MIoU[14] * 100.0), "%\t") + print("bus : %.6f" % (MIoU[15] * 100.0), "%\t") + print("train : %.6f" % (MIoU[16] * 100.0), "%\t") + print("motorcycle : %.6f" % (MIoU[17] * 100.0), "%\t") + print("bicycle : %.6f" % (MIoU[18] * 100.0), "%\t") + print("dynamic : %.6f" % (MIoU[19] * 100.0), "%\t") + print("stair : %.6f" % (MIoU[20] * 100.0), "%\t") + if self.num_class == 20: + print("small obstacles: %.6f" % (MIoU[19] * 100.0), "%\t") + + MIoU = np.nanmean(MIoU) + return MIoU + + def Mean_Intersection_over_Union_Curb(self): + MIoU = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + + # print MIoU of each class + print('-----------IoU of each classes-----------') + print("road : %.6f" % (MIoU[0] * 100.0), "%\t") + print("sidewalk : %.6f" % (MIoU[1] * 100.0), "%\t") + + if self.num_class == 20: + print("small obstacles: %.6f" % (MIoU[19] * 100.0), "%\t") + + MIoU = np.nanmean(MIoU[:2]) + return MIoU + + def Frequency_Weighted_Intersection_over_Union(self): + freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) + iu = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + + FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() + CFWIoU = freq[freq > 0] * iu[freq > 0] + print('-----------FWIoU of each classes-----------') + print("road : %.6f" % (CFWIoU[0] * 100.0), "%\t") + print("sidewalk : %.6f" % (CFWIoU[1] * 100.0), "%\t") + + return FWIoU + + def Frequency_Weighted_Intersection_over_Union_Curb(self): + freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) + iu = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + + # FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() + CFWIoU = freq[freq > 0] * iu[freq > 0] + print('-----------FWIoU of each classes-----------') + print("road : %.6f" % (CFWIoU[0] * 100.0), "%\t") + print("sidewalk : %.6f" % (CFWIoU[1] * 100.0), "%\t") + + return np.nanmean(CFWIoU[:2]) + + def _generate_matrix(self, gt_image, pre_image): + mask = (gt_image >= 0) & (gt_image < self.num_class) + label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] + count = np.bincount(label, minlength=self.num_class**2) + confusion_matrix = count.reshape(self.num_class, self.num_class) + return confusion_matrix + + def 
add_batch(self, gt_image, pre_image): + assert gt_image.shape == pre_image.shape + self.confusion_matrix += self._generate_matrix(gt_image, pre_image) + + def reset(self): + self.confusion_matrix = np.zeros((self.num_class,) * 2) + + + + diff --git a/examples/lifelong_learning/RFNet/utils/saver.py b/examples/lifelong_learning/RFNet/utils/saver.py new file mode 100644 index 000000000..03866432e --- /dev/null +++ b/examples/lifelong_learning/RFNet/utils/saver.py @@ -0,0 +1,68 @@ +import os +import time +import shutil +import tempfile +import torch +from collections import OrderedDict +import glob + +class Saver(object): + + def __init__(self, args): + self.args = args + self.directory = os.path.join('/tmp', args.dataset, args.checkname) + self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) + run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 + + self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) + if not os.path.exists(self.experiment_dir): + os.makedirs(self.experiment_dir) + + def save_checkpoint(self, state, is_best): # filename from .pth.tar change to .pth? + """Saves checkpoint to disk""" + filename = f'checkpoint_{time.time()}.pth' + checkpoint_path = os.path.join(self.experiment_dir, filename) + torch.save(state, checkpoint_path) + if is_best: + best_pred = state['best_pred'] + with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f: + f.write(str(best_pred)) + if self.runs: + previous_miou = [0.0] + for run in self.runs: + run_id = run.split('_')[-1] + path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt') + if os.path.exists(path): + with open(path, 'r') as f: + miou = float(f.readline()) + previous_miou.append(miou) + else: + continue + max_miou = max(previous_miou) + if best_pred > max_miou: + checkpoint_path_best = os.path.join(self.directory, 'model_best.pth') + shutil.copyfile(checkpoint_path, checkpoint_path_best) + checkpoint_path = checkpoint_path_best + else: + checkpoint_path_best = os.path.join(self.directory, 'model_best.pth') + shutil.copyfile(checkpoint_path, checkpoint_path_best) + checkpoint_path = checkpoint_path_best + + return checkpoint_path + + def save_experiment_config(self): + logfile = os.path.join(self.experiment_dir, 'parameters.txt') + log_file = open(logfile, 'w') + p = OrderedDict() + p['datset'] = self.args.dataset + # p['out_stride'] = self.args.out_stride + p['lr'] = self.args.lr + p['lr_scheduler'] = self.args.lr_scheduler + p['loss_type'] = self.args.loss_type + p['epoch'] = self.args.epochs + p['base_size'] = self.args.base_size + p['crop_size'] = self.args.crop_size + + for key, val in p.items(): + log_file.write(key + ':' + str(val) + '\n') + log_file.close() \ No newline at end of file diff --git a/examples/lifelong_learning/RFNet/utils/summaries.py b/examples/lifelong_learning/RFNet/utils/summaries.py new file mode 100644 index 000000000..04bcdb822 --- /dev/null +++ b/examples/lifelong_learning/RFNet/utils/summaries.py @@ -0,0 +1,39 @@ +import os +import torch +from torchvision.utils import make_grid +# from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter +from dataloaders.utils import decode_seg_map_sequence + +class TensorboardSummary(object): + def __init__(self, directory): + self.directory = directory + + def create_summary(self): + writer = SummaryWriter(log_dir=os.path.join(self.directory)) + return writer + + def visualize_image(self, writer, dataset, image, target, output, 
global_step, depth=None): + if depth is None: + grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) + writer.add_image('Image', grid_image, global_step) + + grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), + dataset=dataset), 3, normalize=False, range=(0, 255)) + writer.add_image('Predicted label', grid_image, global_step) + grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), + dataset=dataset), 3, normalize=False, range=(0, 255)) + writer.add_image('Groundtruth label', grid_image, global_step) + else: + grid_image = make_grid(image[:3].clone().cpu().data, 4, normalize=True) + writer.add_image('Image', grid_image, global_step) + + grid_image = make_grid(depth[:3].clone().cpu().data, 4, normalize=True) # normalize=False? + writer.add_image('Depth', grid_image, global_step) + + grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), + dataset=dataset), 4, normalize=False, range=(0, 255)) + writer.add_image('Predicted label', grid_image, global_step) + grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), + dataset=dataset), 4, normalize=False, range=(0, 255)) + writer.add_image('Groundtruth label', grid_image, global_step) \ No newline at end of file diff --git a/lib/sedna/algorithms/__init__.py b/lib/sedna/algorithms/__init__.py index 12f6e7b1b..5b44f8023 100644 --- a/lib/sedna/algorithms/__init__.py +++ b/lib/sedna/algorithms/__init__.py @@ -17,3 +17,7 @@ from . import multi_task_learning from . import unseen_task_detect from . import reid +from . import knowledge_management +from . import seen_task_learning +from . import unseen_task_detection +from . import unseen_task_processing diff --git a/lib/sedna/algorithms/knowledge_management/__init__.py b/lib/sedna/algorithms/knowledge_management/__init__.py new file mode 100644 index 000000000..88dfba917 --- /dev/null +++ b/lib/sedna/algorithms/knowledge_management/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cloud_knowledge_management import CloudKnowledgeManagement +from .edge_knowledge_management import EdgeKnowledgeManagement +from . import task_update_decision +from . 
import task_evaluation \ No newline at end of file diff --git a/lib/sedna/algorithms/knowledge_management/cloud_knowledge_management.py b/lib/sedna/algorithms/knowledge_management/cloud_knowledge_management.py new file mode 100644 index 000000000..6dfe6d9a0 --- /dev/null +++ b/lib/sedna/algorithms/knowledge_management/cloud_knowledge_management.py @@ -0,0 +1,141 @@ +import os +import time +import tempfile + +from sedna.common.log import LOGGER +from sedna.common.class_factory import ClassType, ClassFactory +from sedna.common.file_ops import FileOps +from sedna.common.constant import KBResourceConstant + +__all__ = ('CloudKnowledgeManagement', ) + + +@ClassFactory.register(ClassType.KM) +class CloudKnowledgeManagement: + """ + Manage task processing, kb update and task deployment, etc., at cloud. + + Parameters: + ---------- + config: Dict + parameters to initialize an object + """ + + def __init__(self, config, **kwargs): + self.last_task_index = kwargs.get("last_task_index", None) + self.cloud_output_url = config.get( + "cloud_output_url", "/tmp") + self.task_index = FileOps.join_path( + self.cloud_output_url, config["task_index"]) + self.local_task_index_url = KBResourceConstant.KB_INDEX_NAME.value + + task_evaluation = kwargs.get("task_evaluation") or {} + self.task_evaluation = task_evaluation.get( + "method", "TaskEvaluationDefault") + self.task_evaluation_param = task_evaluation.get("param", {}) + + self.estimator = kwargs.get("estimator") or None + self.log = LOGGER + + self.seen_task_key = KBResourceConstant.SEEN_TASK.value + self.unseen_task_key = KBResourceConstant.UNSEEN_TASK.value + self.task_group_key = KBResourceConstant.TASK_GROUPS.value + self.extractor_key = KBResourceConstant.EXTRACTOR.value + + def update_kb(self, task_index, kb_server): + if isinstance(task_index, str): + task_index = FileOps.load(task_index) + + seen_task_index = task_index.get(self.seen_task_key) + unseen_task_index = task_index.get(self.unseen_task_key) + + seen_extractor, seen_task_groups = self._save_task_index( + seen_task_index, kb_server, task_type=self.seen_task_key) + unseen_extractor, unseen_task_groups = self._save_task_index( + unseen_task_index, kb_server, task_type=self.unseen_task_key) + + task_info = { + self.seen_task_key: { + self.task_group_key: seen_task_groups, + self.extractor_key: seen_extractor + }, + self.unseen_task_key: { + self.task_group_key: unseen_task_groups, + self.extractor_key: unseen_extractor + }, + "create_time": str(time.time()) + } + + fd, name = tempfile.mkstemp() + FileOps.dump(task_info, name) + + index_file = kb_server.update_db(name) + if not index_file: + self.log.error("KB update Fail !") + index_file = name + + return FileOps.upload(index_file, self.task_index) + + def _save_task_index(self, task_index, kb_server, task_type="seen_task"): + extractor = task_index[self.extractor_key] + if isinstance(extractor, str): + extractor = FileOps.load(extractor) + task_groups = task_index[self.task_group_key] + + model_upload_key = {} + for task_group in task_groups: + model_file = task_group.model.model + save_model = FileOps.join_path( + self.cloud_output_url, task_type, + os.path.basename(model_file) + ) + if model_file not in model_upload_key: + model_upload_key[model_file] = FileOps.upload( + model_file, save_model) + + model_file = model_upload_key[model_file] + + try: + model = kb_server.upload_file(save_model) + except Exception as err: + self.log.error( + f"Upload task model of {model_file} fail: {err}" + ) + model = FileOps.join_path( + 
self.cloud_output_url, + task_type, + os.path.basename(model_file)) + + task_group.model.model = model + + for _task in task_group.tasks: + _task.model = FileOps.join_path( + self.cloud_output_url, task_type, os.path.basename(model_file)) + sample_dir = FileOps.join_path( + self.cloud_output_url, task_type, + f"{_task.samples.data_type}_{_task.entry}.sample") + task_group.samples.save(sample_dir) + + try: + sample_dir = kb_server.upload_file(sample_dir) + except Exception as err: + self.log.error( + f"Upload task samples of {_task.entry} fail: {err}") + _task.samples.data_url = sample_dir + + save_extractor = FileOps.join_path( + self.cloud_output_url, task_type, + f"{task_type}_{KBResourceConstant.TASK_EXTRACTOR_NAME.value}" + ) + extractor = FileOps.dump(extractor, save_extractor) + try: + extractor = kb_server.upload_file(extractor) + except Exception as err: + self.log.error(f"Upload task extractor fail: {err}") + + return extractor, task_groups + + def evaluate_tasks(self, tasks_detail, **kwargs): + method_cls = ClassFactory.get_cls( + ClassType.KM, self.task_evaluation)(**self.task_evaluation_param) + return method_cls(tasks_detail, **kwargs) diff --git a/lib/sedna/algorithms/knowledge_management/edge_knowledge_management.py b/lib/sedna/algorithms/knowledge_management/edge_knowledge_management.py new file mode 100644 index 000000000..120412c16 --- /dev/null +++ b/lib/sedna/algorithms/knowledge_management/edge_knowledge_management.py @@ -0,0 +1,185 @@ +import os +import time +import tempfile +import threading + +from sedna.common.log import LOGGER +from sedna.common.config import Context +from sedna.common.class_factory import ClassType, ClassFactory +from sedna.common.file_ops import FileOps +from sedna.common.constant import KBResourceConstant, K8sResourceKindStatus + +__all__ = ('EdgeKnowledgeManagement', ) + + +@ClassFactory.register(ClassType.KM) +class EdgeKnowledgeManagement: + """ + Manage task processing at the edge. + + Parameters: + ---------- + config: Dict + parameters to initialize an object + estimator: Instance + An instance with the high-level API that greatly simplifies + machine learning programming. Estimators encapsulate training, + evaluation, prediction, and exporting for your model. + """ + + def __init__(self, config, estimator, **kwargs): + self.edge_output_url = config.get( + "edge_output_url") or "/tmp/edge_output_url" + self.task_index = FileOps.join_path( + self.edge_output_url, config["task_index"]) + self.estimator = estimator + self.log = LOGGER + + self.seen_task_key = KBResourceConstant.SEEN_TASK.value + self.unseen_task_key = KBResourceConstant.UNSEEN_TASK.value + self.task_group_key = KBResourceConstant.TASK_GROUPS.value + self.extractor_key = KBResourceConstant.EXTRACTOR.value + + ModelLoadingThread(self.task_index).start() + + def update_kb(self, task_index_url): + if isinstance(task_index_url, str): + try: + task_index = FileOps.load(task_index_url) + except Exception as err: + self.log.error(f"{err}") + self.log.error( + "Load task index failed. 
KB deployment to the edge failed.") + return None + else: + task_index = task_index_url + + FileOps.clean_folder(self.edge_output_url) + seen_task_index = task_index.get(self.seen_task_key) + unseen_task_index = task_index.get(self.unseen_task_key) + + seen_extractor, seen_task_groups = self._save_task_index( + seen_task_index, task_type=self.seen_task_key) + unseen_extractor, unseen_task_groups = self._save_task_index( + unseen_task_index, task_type=self.unseen_task_key) + + task_info = { + self.seen_task_key: { + self.task_group_key: seen_task_groups, + self.extractor_key: seen_extractor + }, + self.unseen_task_key: { + self.task_group_key: unseen_task_groups, + self.extractor_key: unseen_extractor + }, + "created_time": task_index.get("created_time", str(time.time())) + } + + fd, name = tempfile.mkstemp() + FileOps.dump(task_info, name) + return FileOps.upload(name, self.task_index) + + def _save_task_index(self, task_index, task_type="seen_task"): + extractor = task_index[self.extractor_key] + if isinstance(extractor, str): + extractor = FileOps.load(extractor) + task_groups = task_index[self.task_group_key] + + model_upload_key = {} + for task in task_groups: + model_file = task.model.model + save_model = FileOps.join_path( + self.edge_output_url, task_type, + os.path.basename(model_file) + ) + if model_file not in model_upload_key: + model_upload_key[model_file] = FileOps.download( + model_file, save_model) + model_file = model_upload_key[model_file] + + task.model.model = save_model + + for _task in task.tasks: + _task.model = FileOps.join_path( + self.edge_output_url, task_type, os.path.basename(model_file)) + sample_dir = FileOps.join_path( + self.edge_output_url, task_type, + f"{_task.samples.data_type}_{_task.entry}.sample") + _task.samples.data_url = FileOps.download( + _task.samples.data_url, sample_dir) + + save_extractor = FileOps.join_path( + self.edge_output_url, task_type, + KBResourceConstant.TASK_EXTRACTOR_NAME.value + ) + extractor = FileOps.dump(extractor, save_extractor) + + return extractor, task_groups + + def save_unseen_samples(self, samples, post_process): + # TODO: save unseen samples to specified directory. 
+ if callable(post_process): + samples = post_process(samples) + + fd, name = tempfile.mkstemp() + + FileOps.dump(samples, name) + unseen_save_url = FileOps.join_path( + Context.get_parameters( + "unseen_save_url", + self.edge_output_url), + f"unseen_samples_{time.time()}.pkl") + return FileOps.upload(name, unseen_save_url) + +class ModelLoadingThread(threading.Thread): + """Hot task index loading with multithread support""" + MODEL_MANIPULATION_SEM = threading.Semaphore(1) + + def __init__(self, + task_index, + callback=None + ): + self.run_flag = True + hot_update_task_index = Context.get_parameters("MODEL_URLS") + if not hot_update_task_index: + LOGGER.error("As `MODEL_URLS` unset a value, skipped") + self.run_flag = False + if not FileOps.exists(task_index): + LOGGER.error("As local task index has not been loaded, skipped") + self.run_flag = False + model_check_time = int(Context.get_parameters( + "MODEL_POLL_PERIOD_SECONDS", "60") + ) + if model_check_time < 1: + LOGGER.warning("Catch an abnormal value in " + "`MODEL_POLL_PERIOD_SECONDS`, fallback with 60") + model_check_time = 60 + self.hot_update_task_index = hot_update_task_index + self.check_time = model_check_time + self.task_index = task_index + self.callback = callback + super(ModelLoadingThread, self).__init__() + + def run(self): + while self.run_flag: + time.sleep(self.check_time) + tmp_task_index = FileOps.load(self.hot_update_task_index) + latest_version = tmp_task_index.get("create_time") + current_version = FileOps.load(self.task_index).get("create_time") + if latest_version == current_version: + continue + current_version = latest_version + with self.MODEL_MANIPULATION_SEM: + LOGGER.info(f"Update model start with version {current_version}") + try: + FileOps.dump(tmp_task_index, self.task_index) + status = K8sResourceKindStatus.COMPLETED.value + LOGGER.info(f"Update task index complete " + f"with version {self.version}") + except Exception as e: + LOGGER.error(f"fail to update task index: {e}") + status = K8sResourceKindStatus.FAILED.value + if self.callback: + self.callback( + task_info=None, status=status, kind="deploy" + ) diff --git a/lib/sedna/algorithms/knowledge_management/task_evaluation/__init__.py b/lib/sedna/algorithms/knowledge_management/task_evaluation/__init__.py new file mode 100644 index 000000000..61d5b5cd3 --- /dev/null +++ b/lib/sedna/algorithms/knowledge_management/task_evaluation/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . 
import task_evaluation \ No newline at end of file diff --git a/lib/sedna/algorithms/knowledge_management/task_evaluation/task_evaluation.py b/lib/sedna/algorithms/knowledge_management/task_evaluation/task_evaluation.py new file mode 100644 index 000000000..a3eb02ccd --- /dev/null +++ b/lib/sedna/algorithms/knowledge_management/task_evaluation/task_evaluation.py @@ -0,0 +1,80 @@ +from sedna.common.config import Context +from sedna.common.log import LOGGER +from sedna.common.class_factory import ClassType, ClassFactory + +__all__ = ('TaskEvaluationDefault', ) + + +@ClassFactory.register(ClassType.KM) +class TaskEvaluationDefault: + """ + Evaluated the performance of each seen task and filter seen tasks + based on defined rules. + + Parameters + ---------- + estimator: Instance + An instance with the high-level API that greatly simplifies + machine learning programming. Estimators encapsulate training, + evaluation, prediction, and exporting for your model. + """ + + def __init__(self, **kwargs): + self.log = LOGGER + + def __call__(self, tasks_detail, **kwargs): + """ + Parameters + ---------- + tasks_detail: List[Task] + output of module task_update_decision, consisting of results of evaluation. + metrics : function / str + Metrics to assess performance on the task by given prediction. + metrics_param : Dict + parameter for metrics function. + kwargs: Dict + parameters for `estimator` evaluate. + + Returns + ------- + drop_task: List[str] + names of the tasks which will not to be deployed to the edge. + """ + + self.model_filter_operator = Context.get_parameters("operator", ">") + self.model_threshold = float( + Context.get_parameters( + "model_threshold", 0.1)) + + drop_tasks = [] + + operator_map = { + ">": lambda x, y: x > y, + "<": lambda x, y: x < y, + "=": lambda x, y: x == y, + ">=": lambda x, y: x >= y, + "<=": lambda x, y: x <= y, + } + if self.model_filter_operator not in operator_map: + self.log.warn( + f"operator {self.model_filter_operator} use to " + f"compare is not allow, set to <" + ) + self.model_filter_operator = "<" + operator_func = operator_map[self.model_filter_operator] + + for detail in tasks_detail: + scores = detail.scores + entry = detail.entry + self.log.info(f"{entry} scores: {scores}") + if any(map(lambda x: operator_func(float(x), + self.model_threshold), + scores.values())): + self.log.warn( + f"{entry} will not be deploy because all " + f"scores {self.model_filter_operator} {self.model_threshold}") + drop_tasks.append(entry) + continue + drop_task = ",".join(drop_tasks) + + return drop_task diff --git a/lib/sedna/algorithms/knowledge_management/task_update_decision/__init__.py b/lib/sedna/algorithms/knowledge_management/task_update_decision/__init__.py new file mode 100644 index 000000000..4230f27b9 --- /dev/null +++ b/lib/sedna/algorithms/knowledge_management/task_update_decision/__init__.py @@ -0,0 +1,2 @@ + +from . import task_update_decision \ No newline at end of file diff --git a/lib/sedna/algorithms/knowledge_management/task_update_decision/task_update_decision.py b/lib/sedna/algorithms/knowledge_management/task_update_decision/task_update_decision.py new file mode 100644 index 000000000..882862cf4 --- /dev/null +++ b/lib/sedna/algorithms/knowledge_management/task_update_decision/task_update_decision.py @@ -0,0 +1,110 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Divide multiple tasks based on data + +Parameters +---------- +samples: Train data, see `sedna.datasources.BaseDataSource` for more detail. + +Returns +------- +tasks: All tasks based on training data. +task_extractor: Model with a method to predicting target tasks +""" + +import time + +from sedna.common.file_ops import FileOps +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassType, ClassFactory +from sedna.algorithms.seen_task_learning.artifact import Task + +__all__ = ('UpdateStrategyDefault', ) + + +@ClassFactory.register(ClassType.KM) +class UpdateStrategyDefault: + """ + Decide processing strategies for different tasks + + Parameters + ---------- + task_index: str or Dict + """ + + def __init__(self, task_index, **kwargs): + if isinstance(task_index, str): + task_index = FileOps.load(task_index) + self.task_index = task_index + + def __call__(self, samples, task_type): + """ + Parameters + ---------- + samples: BaseDataSource + seen task samples or unseen task samples to be processed. + task_type: str + "seen_task" or "unseen_task". + See sedna.common.constant.KBResourceConstant for more details. + + Returns + ------- + self.tasks: List[Task] + tasks to be processed. + task_update_strategies: Dict + strategies to process each task. + """ + + if task_type == "seen_task": + task_index = self.task_index["seen_task"] + else: + task_index = self.task_index["unseen_task"] + + self.extractor = task_index["extractor"] + task_groups = task_index["task_groups"] + + tasks = [task_group.tasks[0] for task_group in task_groups] + + task_update_strategies = {} + for task in tasks: + task_update_strategies[task.entry] = { + "raw_data_update": None, + "target_model_update": None, + "task_attr_update": None, + } + + x_data = samples.x + y_data = samples.y + d_type = samples.data_type + + for task in tasks: + origin = task.meta_attr + _x = [x for x in x_data if origin in x[0]] + _y = [y for y in y_data if origin in y] + + task_df = BaseDataSource(data_type=d_type) + task_df.x = _x + task_df.y = _y + + task.samples = task_df + + task_update_strategies[task.entry] = { + "raw_data_update": samples, + "target_model_update": samples, + "task_attr_update": samples + } + + return tasks, task_update_strategies diff --git a/lib/sedna/algorithms/seen_task_learning/__init__.py b/lib/sedna/algorithms/seen_task_learning/__init__.py new file mode 100644 index 000000000..3e0d18ac8 --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
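+# Package layout: task_definition and task_relation_discover cover the
+# training stage, task_allocation and task_remodeling cover the inference
+# stage, and inference_integrate merges the per-task predictions; the
+# SeenTaskLearning class imported below ties these stages together.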
+ +# train +from . import task_definition +from . import task_relation_discover + +# inference +from . import task_remodeling +from . import task_allocation + +# result integrate +from . import inference_integrate + +from .seen_task_learning import SeenTaskLearning diff --git a/lib/sedna/algorithms/seen_task_learning/artifact.py b/lib/sedna/algorithms/seen_task_learning/artifact.py new file mode 100644 index 000000000..fd7c1e9cd --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/artifact.py @@ -0,0 +1,45 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +__all__ = ('Task', 'TaskGroup', 'Model') + + +class Task: + def __init__(self, entry, samples, meta_attr=None): + self.entry = entry + self.samples = samples + self.meta_attr = meta_attr + self.test_samples = None # assign on task definition and use in TRD + self.model = None # assign on running + self.result = None # assign on running + + +class TaskGroup: + + def __init__(self, entry, tasks: List[Task]): + self.entry = entry + self.tasks = tasks + self.samples = None # assign with task_relation_discover algorithms + self.model = None # assign on train + + +class Model: + def __init__(self, index: int, entry, model, result): + self.index = index # integer + self.entry = entry + self.model = model + self.result = result + self.meta_attr = None # assign on running diff --git a/lib/sedna/algorithms/seen_task_learning/inference_integrate/__init__.py b/lib/sedna/algorithms/seen_task_learning/inference_integrate/__init__.py new file mode 100644 index 000000000..3ed1b4995 --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/inference_integrate/__init__.py @@ -0,0 +1,3 @@ + +from . import inference_integrate + diff --git a/lib/sedna/algorithms/seen_task_learning/inference_integrate/inference_integrate.py b/lib/sedna/algorithms/seen_task_learning/inference_integrate/inference_integrate.py new file mode 100644 index 000000000..949f8c8b1 --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/inference_integrate/inference_integrate.py @@ -0,0 +1,57 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
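+# A minimal usage sketch (variable names here are hypothetical): given the
+# per-task results produced by SeenTaskLearning.predict, the integrator
+# re-orders predictions back into the original sample order, e.g.
+#   integrate = DefaultInferenceIntegrate(models=seen_models)
+#   ordered_preds = integrate(tasks)  # numpy array aligned with input samples
+# Each task is expected to carry `samples.inx` (the original indices) and
+# `task.result` (that task's predictions), as used in __call__ below.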
+ +""" +Integrate the inference results of all related tasks +""" + +from typing import List + +import numpy as np + +from sedna.common.class_factory import ClassFactory, ClassType + +from ..artifact import Task + +__all__ = ('DefaultInferenceIntegrate', ) + + +@ClassFactory.register(ClassType.STP) +class DefaultInferenceIntegrate: + """ + Default calculation algorithm for inference integration + + Parameters + ---------- + models: All models used for sample inference + """ + + def __init__(self, models: list, **kwargs): + self.models = models + + def __call__(self, tasks: List[Task]): + """ + Parameters + ---------- + tasks: All tasks with sample result + + Returns + ------- + result: minimum result + """ + res = {} + for task in tasks: + res.update(dict(zip(task.samples.inx, task.result))) + return np.array([z[1] + for z in sorted(res.items(), key=lambda x: x[0])]) diff --git a/lib/sedna/algorithms/seen_task_learning/seen_task_learning.py b/lib/sedna/algorithms/seen_task_learning/seen_task_learning.py new file mode 100644 index 000000000..d90f00cdc --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/seen_task_learning.py @@ -0,0 +1,578 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiple task transfer learning algorithms""" + +import json +import time + +import pandas as pd +from sklearn import metrics as sk_metrics + +from sedna.datasources import BaseDataSource +from sedna.backend import set_backend +from sedna.common.log import LOGGER +from sedna.common.file_ops import FileOps +from sedna.common.config import Context +from sedna.common.constant import KBResourceConstant +from sedna.common.class_factory import ClassFactory, ClassType + +from .artifact import Model, Task, TaskGroup + +__all__ = ('SeenTaskLearning',) + + +class SeenTaskLearning: + """ + An auto machine learning framework for edge-cloud multitask learning + + See Also + -------- + Train: Data + Estimator -> Task Definition -> Task Relationship Discovery + -> Feature Engineering -> Training + Inference: Data -> Task Allocation -> Feature Engineering + -> Task Remodeling -> Inference + + Parameters + ---------- + estimator : Instance + An instance with the high-level API that greatly simplifies + machine learning programming. Estimators encapsulate training, + evaluation, prediction, and exporting for your model. + task_definition : Dict + Divide multiple tasks based on data, + see `task_jobs.task_definition` for more detail. + task_relationship_discovery : Dict + Discover relationships between all tasks, see + `task_jobs.task_relationship_discovery` for more detail. + seen_task_allocation : Dict + Mining tasks of inference sample, + see `task_jobs.task_mining` for more detail. + task_remodeling : Dict + Remodeling tasks based on their relationships, + see `task_jobs.task_remodeling` for more detail. + inference_integrate : Dict + Integrate the inference results of all related + tasks, see `task_jobs.inference_integrate` for more detail. 
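+
+    All of these configuration dictionaries are optional; when one is
+    omitted, the corresponding default method is used instead
+    (TaskDefinitionByDataAttr, DefaultTaskRelationDiscover,
+    TaskAllocationDefault, DefaultTaskRemodeling and
+    DefaultInferenceIntegrate, as assigned in ``__init__``).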
+ + Examples + -------- + >>> from xgboost import XGBClassifier + >>> from sedna.algorithms.multi_task_learning import MulTaskLearning + >>> estimator = XGBClassifier(objective="binary:logistic") + >>> task_definition = { + "method": "TaskDefinitionByDataAttr", + "param": {"attribute": ["season", "city"]} + } + >>> task_relationship_discovery = { + "method": "DefaultTaskRelationDiscover", "param": {} + } + >>> seen_task_allocation = { + "method": "TaskAllocationByDataAttr", + "param": {"attribute": ["season", "city"]} + } + >>> task_remodeling = None + >>> inference_integrate = { + "method": "DefaultInferenceIntegrate", "param": {} + } + >>> mul_task_instance = MulTaskLearning( + estimator=estimator, + task_definition=task_definition, + task_relationship_discovery=task_relationship_discovery, + seen_task_allocation=seen_task_allocation, + task_remodeling=task_remodeling, + inference_integrate=inference_integrate + ) + + Notes + ----- + All method defined under `task_jobs` and registered in `ClassFactory`. + """ + + _method_pair = { + 'TaskDefinitionBySVC': 'TaskMiningBySVC', + 'TaskDefinitionByDataAttr': 'TaskMiningByDataAttr', + } + + def __init__(self, + estimator=None, + task_definition=None, + task_relationship_discovery=None, + seen_task_allocation=None, + task_remodeling=None, + inference_integrate=None + ): + + self.task_definition = task_definition or { + "method": "TaskDefinitionByDataAttr" + } + self.task_relationship_discovery = task_relationship_discovery or { + "method": "DefaultTaskRelationDiscover" + } + self.seen_task_allocation = seen_task_allocation or { + "method": "TaskAllocationDefault" + } + self.task_remodeling = task_remodeling or { + "method": "DefaultTaskRemodeling" + } + self.inference_integrate = inference_integrate or { + "method": "DefaultInferenceIntegrate" + } + + self.seen_models = None + self.seen_extractor = None + self.base_model = estimator + self.seen_task_groups = None + + self.min_train_sample = int(Context.get_parameters( + "MIN_TRAIN_SAMPLE", KBResourceConstant.MIN_TRAIN_SAMPLE.value + )) + + self.seen_task_key = KBResourceConstant.SEEN_TASK.value + self.task_group_key = KBResourceConstant.TASK_GROUPS.value + self.extractor_key = KBResourceConstant.EXTRACTOR.value + + self.log = LOGGER + + @staticmethod + def _parse_param(param_str): + if not param_str: + return {} + if isinstance(param_str, dict): + return param_str + try: + raw_dict = json.loads(param_str, encoding="utf-8") + except json.JSONDecodeError: + raw_dict = {} + return raw_dict + + def _task_definition(self, samples, **kwargs): + """ + Task attribute extractor and multi-task definition + """ + method_name = self.task_definition.get( + "method", "TaskDefinitionByDataAttr" + ) + extend_param = self._parse_param( + self.task_definition.get("param") + ) + method_cls = ClassFactory.get_cls( + ClassType.STP, method_name)(**extend_param) + return method_cls(samples, **kwargs) + + def _task_relationship_discovery(self, tasks): + """ + Merge tasks from task_definition + """ + method_name = self.task_relationship_discovery.get("method") + extend_param = self._parse_param( + self.task_relationship_discovery.get("param") + ) + method_cls = ClassFactory.get_cls( + ClassType.STP, method_name)(**extend_param) + return method_cls(tasks) + + def _task_allocation(self, samples): + """ + Mining tasks of inference sample base on task attribute extractor + """ + method_name = self.seen_task_allocation.get("method") + extend_param = self._parse_param( + self.seen_task_allocation.get("param") + ) + + if not 
method_name: + task_definition = self.task_definition.get( + "method", "TaskDefinitionByDataAttr" + ) + method_name = self._method_pair.get(task_definition, + 'TaskAllocationByDataAttr') + extend_param = self._parse_param( + self.task_definition.get("param")) + + method_cls = ClassFactory.get_cls(ClassType.STP, method_name)( + task_extractor=self.seen_extractor, **extend_param + ) + return method_cls(samples=samples) + + def _task_remodeling(self, samples, mappings): + """ + Remodeling tasks from task mining + """ + method_name = self.task_remodeling.get("method") + extend_param = self._parse_param( + self.task_remodeling.get("param")) + method_cls = ClassFactory.get_cls(ClassType.STP, method_name)( + models=self.seen_models, **extend_param) + return method_cls(samples=samples, mappings=mappings) + + def _inference_integrate(self, tasks): + """ + Aggregate inference results from target models + """ + method_name = self.inference_integrate.get("method") + extend_param = self._parse_param( + self.inference_integrate.get("param")) + method_cls = ClassFactory.get_cls(ClassType.STP, method_name)( + models=self.seen_models, **extend_param) + return method_cls(tasks=tasks) if method_cls else tasks + + def _task_process( + self, + task_groups, + train_data=None, + valid_data=None, + callback=None, + **kwargs): + """ + Train seen task samples based on grouped tasks. + """ + feedback = {} + rare_task = [] + for i, task in enumerate(task_groups): + if not isinstance(task, TaskGroup): + rare_task.append(i) + self.seen_models.append(None) + self.seen_task_groups.append(None) + continue + if not (task.samples and len(task.samples) + >= self.min_train_sample): + self.seen_models.append(None) + self.seen_task_groups.append(None) + rare_task.append(i) + n = len(task.samples) + LOGGER.info(f"Sample {n} of {task.entry} will be merge") + continue + LOGGER.info(f"MTL Train start {i} : {task.entry}") + + model = None + for t in task.tasks: # if model has train in tasks + if not (t.model and t.result): + continue + if isinstance(t.model, str): + model_path = t.model + else: + model_path = t.model.save(model_name=f"{task.entry}.model") + + t.model = model_path + model = Model(index=i, entry=t.entry, + model=model_path, result=t.result) + model.meta_attr = t.meta_attr + break + if not model: + model_obj = set_backend(estimator=self.base_model) + res = model_obj.train(train_data=task.samples, **kwargs) + if callback: + res = callback(model_obj, res) + if isinstance(res, str): + model_path = res + else: + model_path = model_obj.save( + model_name=f"{task.entry}.model") + model = Model(index=i, entry=task.entry, + model=model_path, result=res) + + model.meta_attr = [t.meta_attr for t in task.tasks] + task.model = model + self.seen_models.append(model) + feedback[task.entry] = model.result + self.seen_task_groups.append(task) + + if len(rare_task): + model_obj = set_backend(estimator=self.base_model) + res = model_obj.train(train_data=train_data, **kwargs) + model_path = model_obj.save(model_name="global.model") + for i in rare_task: + task = task_groups[i] + entry = getattr(task, 'entry', "global") + if not isinstance(task, TaskGroup): + task = TaskGroup( + entry=entry, tasks=[] + ) + model = Model(index=i, entry=entry, + model=model_path, result=res) + model.meta_attr = [t.meta_attr for t in task.tasks] + task.model = model + task.samples = train_data + self.seen_models[i] = model + feedback[entry] = res + self.seen_task_groups[i] = task + + task_index = { + self.extractor_key: self.seen_extractor, + 
self.task_group_key: self.seen_task_groups + } + + if valid_data: + feedback, _ = self.evaluate(valid_data, **kwargs) + return feedback, task_index + + def train(self, train_data: BaseDataSource, + valid_data: BaseDataSource = None, + post_process=None, **kwargs): + """ + fit for update the knowledge based on training data. + + Parameters + ---------- + train_data : BaseDataSource + Train data, see `sedna.datasources.BaseDataSource` for more detail. + valid_data : BaseDataSource + Valid data, BaseDataSource or None. + post_process : function + function or a registered method, callback after `estimator` train. + kwargs : Dict + parameters for `estimator` training, Like: + `early_stopping_rounds` in Xgboost.XGBClassifier + + Returns + ------- + feedback : Dict + contain all training result in each tasks. + task_index_url : str + task extractor model path, used for task allocation. + """ + + tasks, task_extractor, train_data = self._task_definition( + train_data, model=self.base_model, **kwargs) + self.seen_extractor = task_extractor + task_groups = self._task_relationship_discovery(tasks) + self.seen_models = [] + callback = None + if isinstance(post_process, str): + callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() + self.seen_task_groups = [] + + return self._task_process( + task_groups, + train_data=train_data, + valid_data=valid_data, + callback=callback) + + def update(self, tasks, task_update_strategies, **kwargs): + """ + Parameters: + ---------- + tasks: List[Task] + from the output of module task_update_decision + task_update_strategies: object + from the output of module task_update_decision + + Returns + ------- + task_index : Dict + updated seen task index of knowledge base + """ + if not (self.seen_models and self.seen_extractor): + self.load(kwargs.get("task_index", None)) + + task_groups = self._task_relationship_discovery(tasks) + + # TODO: to fit retraining + self.seen_task_groups = [] + self.seen_models = [] + + feedback = {} + for i, task in enumerate(task_groups): + LOGGER.info(f"MTL Train start {i} : {task.entry}") + for _task in task.tasks: + model_obj = set_backend(estimator=self.base_model) + model_obj.load(_task.model, phase="train") + res = model_obj.train(train_data=task.samples) + if isinstance(res, str): + model_path = res + model = Model(index=i, entry=task.entry, + model=model_path, result={}) + else: + model_path = model_obj.save( + model_name=f"{task.entry}_{time.time()}.model") + model = Model(index=i, entry=task.entry, + model=model_path, result=res) + + break + + model.meta_attr = [t.meta_attr for t in task.tasks] + task.model = model + self.seen_models.append(model) + feedback[task.entry] = model.result + self.seen_task_groups.append(task) + + task_index = { + self.extractor_key: {"real": 0, "sim": 1}, + self.task_group_key: self.seen_task_groups + } + + return task_index + + def load(self, task_index): + """ + load task_detail (tasks/models etc ...) from task index file. + It'll automatically loaded during `inference` and `evaluation` phases. + + Parameters + ---------- + task_index : str or Dict + task index file path, default self.task_index_url. + """ + assert task_index, "Task index can't be None." 
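+        # Load flow: a string `task_index` is treated as a file path and
+        # deserialized with FileOps.load; the seen-task extractor is then
+        # resolved (loaded from disk if it is itself a path), the seen task
+        # groups are read from the index, and `seen_models` is rebuilt from
+        # each group's model.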
+ + if isinstance(task_index, str): + task_index = FileOps.load(task_index) + + self.seen_extractor = task_index[self.seen_task_key][self.extractor_key] + if isinstance(self.seen_extractor, str): + self.seen_extractor = FileOps.load(self.seen_extractor) + self.seen_task_groups = task_index[self.seen_task_key][self.task_group_key] + self.seen_models = [task.model for task in self.seen_task_groups] + + def predict(self, data: BaseDataSource, + post_process=None, **kwargs): + """ + predict the result for input data based on training knowledge. + + Parameters + ---------- + data : BaseDataSource + inference sample, see `sedna.datasources.BaseDataSource` for + more detail. + post_process: function + function or a registered method, effected after `estimator` + prediction, like: label transform. + kwargs: Dict + parameters for `estimator` predict, Like: + `ntree_limit` in Xgboost.XGBClassifier + + Returns + ------- + result : array_like + results array, contain all inference results in each sample. + tasks : List + tasks assigned to each sample. + """ + if not (self.seen_models and self.seen_extractor): + self.load(kwargs.get("task_index", None)) + + data, mappings = self._task_allocation(samples=data) + samples, models = self._task_remodeling(samples=data, + mappings=mappings + ) + + callback = None + if post_process: + callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() + + tasks = [] + for inx, df in enumerate(samples): + m = models[inx] + if not isinstance(m, Model): + continue + if isinstance(m.model, str): + evaluator = set_backend(estimator=self.base_model) + evaluator.load(m.model) + else: + evaluator = m.model + pred = evaluator.predict(df.x, **kwargs) + if callable(callback): + pred = callback(pred, df) + task = Task(entry=m.entry, samples=df) + task.result = pred + task.model = m + tasks.append(task) + res = self._inference_integrate(tasks) + return res, tasks + + def evaluate(self, data: BaseDataSource, + metrics=None, + metrics_param=None, + **kwargs): + """ + evaluated the performance of each task from training, filter tasks + based on the defined rules. + + Parameters + ---------- + data : BaseDataSource + valid data, see `sedna.datasources.BaseDataSource` for more detail. + metrics : function / str + Metrics to assess performance on the task by given prediction. + metrics_param : Dict + parameter for metrics function. + kwargs: Dict + parameters for `estimator` evaluate, Like: + `ntree_limit` in Xgboost.XGBClassifier + + Returns + ------- + task_eval_res : Dict + all metric results. + tasks_detail : List[Object] + all metric results in each task. 
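+
+        Notes
+        -----
+        If `metrics` is not supplied (or nothing callable can be resolved
+        from it), `sklearn.metrics.precision_score` with `average="micro"`
+        is used as the default metric.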
+ """ + result, tasks = self.predict(data, **kwargs) + m_dict = {} + + if metrics: + if callable(metrics): # if metrics is a function + m_name = getattr(metrics, '__name__', "mtl_eval") + m_dict = { + m_name: metrics + } + elif isinstance(metrics, (set, list)): # if metrics is multiple + for inx, m in enumerate(metrics): + m_name = getattr(m, '__name__', f"mtl_eval_{inx}") + if isinstance(m, str): + m = getattr(sk_metrics, m) + if not callable(m): + continue + m_dict[m_name] = m + elif isinstance(metrics, str): # if metrics is single + m_dict = { + metrics: getattr(sk_metrics, metrics, sk_metrics.log_loss) + } + elif isinstance(metrics, dict): # if metrics with name + for k, v in metrics.items(): + if isinstance(v, str): + v = getattr(sk_metrics, v) + if not callable(v): + continue + m_dict[k] = v + + if not len(m_dict): + m_dict = { + 'precision_score': sk_metrics.precision_score + } + metrics_param = {"average": "micro"} + + if isinstance(data.x, pd.DataFrame): + data.x['pred_y'] = result + data.x['real_y'] = data.y + if not metrics_param: + metrics_param = {} + elif isinstance(metrics_param, str): + metrics_param = self._parse_param(metrics_param) + tasks_detail = [] + for task in tasks: + sample = task.samples + pred = task.result + scores = { + name: metric(sample.y, pred, **metrics_param) + for name, metric in m_dict.items() + } + task.scores = scores + tasks_detail.append(task) + task_eval_res = { + name: metric(data.y, result, **metrics_param) + for name, metric in m_dict.items() + } + return task_eval_res, tasks_detail diff --git a/lib/sedna/algorithms/seen_task_learning/task_allocation/__init__.py b/lib/sedna/algorithms/seen_task_learning/task_allocation/__init__.py new file mode 100644 index 000000000..f318c3faa --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_allocation/__init__.py @@ -0,0 +1,4 @@ + +from . import task_allocation +from . import task_allocation_by_origin + diff --git a/lib/sedna/algorithms/seen_task_learning/task_allocation/task_allocation.py b/lib/sedna/algorithms/seen_task_learning/task_allocation/task_allocation.py new file mode 100644 index 000000000..08dec8ffe --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_allocation/task_allocation.py @@ -0,0 +1,116 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mining tasks of inference sample base on task attribute extractor + +Parameters +---------- +samples : infer sample, see `sedna.datasources.BaseDataSource` for more detail. 
+ +Returns +------- +allocations : tasks that assigned to each sample +""" + +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + + +__all__ = ( + 'TaskAllocationBySVC', + 'TaskAllocationByDataAttr', + 'TaskAllocationDefault', +) + + +@ClassFactory.register(ClassType.STP) +class TaskAllocationBySVC: + """ + Corresponding to `TaskDefinitionBySVC` + + Parameters + ---------- + task_extractor : Model + SVC Model used to predicting target tasks + """ + + def __init__(self, task_extractor, **kwargs): + self.task_extractor = task_extractor + + def __call__(self, samples: BaseDataSource): + df = samples.x + allocations = [0, ] * len(df) + legal = list( + filter(lambda col: df[col].dtype == 'float64', df.columns)) + if not len(legal): + return allocations + + allocations = list(self.task_extractor.predict(df[legal])) + return samples, allocations + + +@ClassFactory.register(ClassType.STP) +class TaskAllocationByDataAttr: + """ + Corresponding to `TaskDefinitionByDataAttr` + + Parameters + ---------- + task_extractor : Dict + used to match target tasks + attr_filed: List[Metadata] + metadata is usually a class feature + label with a finite values. + """ + + def __init__(self, task_extractor, **kwargs): + self.task_extractor = task_extractor + self.attr_filed = kwargs.get("attribute", []) + + def __call__(self, samples: BaseDataSource): + df = samples.x + meta_attr = df[self.attr_filed] + + allocations = meta_attr.apply( + lambda x: self.task_extractor.get( + "_".join( + map(lambda y: str(x[y]).replace("_", "-").replace(" ", ""), + self.attr_filed) + ), + 0), + axis=1).values.tolist() + samples.x = df.drop(self.attr_filed, axis=1) + samples.meta_attr = meta_attr + return samples, allocations + + +@ClassFactory.register(ClassType.STP) +class TaskAllocationDefault: + """ + Task allocation specifically for unstructured data + + Parameters + ---------- + task_extractor : Dict + used to match target tasks + """ + + def __init__(self, task_extractor, **kwargs): + self.task_extractor = task_extractor + + def __call__(self, samples: BaseDataSource): + allocations = [0] * len(samples) + + return samples, allocations diff --git a/lib/sedna/algorithms/seen_task_learning/task_allocation/task_allocation_by_origin.py b/lib/sedna/algorithms/seen_task_learning/task_allocation/task_allocation_by_origin.py new file mode 100644 index 000000000..7b86ebfbb --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_allocation/task_allocation_by_origin.py @@ -0,0 +1,35 @@ +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + +@ClassFactory.register(ClassType.STP) +class TaskAllocationByOrigin: + """ + Corresponding to `TaskDefinitionByOrigin` + + Parameters + ---------- + task_extractor : Dict + used to match target tasks + origins: List[Metadata] + metadata is usually a class feature + label with a finite values. 
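+    default: str, optional
+        fallback origin; when set, it is applied to every sample instead
+        of matching per-sample origins (read from kwargs as "default").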
+ """ + + def __init__(self, task_extractor, **kwargs): + self.task_extractor = task_extractor + self.origins = kwargs.get("origins", []) + self.default_origin = kwargs.get("default", None) + + def __call__(self, samples: BaseDataSource): + if self.default_origin: + return samples, [int(self.task_extractor.get(self.default_origin))] * len(samples.x) + + sample_origins = [] + for _x in samples.x: + for origin in self.origins: + if origin in _x[0]: + sample_origins.append(origin) + + allocations = [int(self.task_extractor.get(sample_origin)) for sample_origin in sample_origins] + + return samples, allocations \ No newline at end of file diff --git a/lib/sedna/algorithms/seen_task_learning/task_definition/__init__.py b/lib/sedna/algorithms/seen_task_learning/task_definition/__init__.py new file mode 100644 index 000000000..7d2bcb74e --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_definition/__init__.py @@ -0,0 +1,3 @@ + +from . import task_definition +from . import task_definition_by_origin diff --git a/lib/sedna/algorithms/seen_task_learning/task_definition/task_definition.py b/lib/sedna/algorithms/seen_task_learning/task_definition/task_definition.py new file mode 100644 index 000000000..3cadc3f7b --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_definition/task_definition.py @@ -0,0 +1,213 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Divide multiple tasks based on data + +Parameters +---------- +samples: Train data, see `sedna.datasources.BaseDataSource` for more detail. + +Returns +------- +tasks: All tasks based on training data. +task_extractor: Model with a method to predicting target tasks +""" + +from typing import List, Any, Tuple +import time +import numpy as np +import pandas as pd + +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassType, ClassFactory + +from ..artifact import Task + + +__all__ = ( + 'TaskDefinitionBySVC', + 'TaskDefinitionByDataAttr', + 'TaskDefinitionByCluster') + + +@ClassFactory.register(ClassType.STP) +class TaskDefinitionBySVC: + """ + Dividing datasets with `AgglomerativeClustering` based on kernel distance, + Using SVC to fit the clustering result. + + Parameters + ---------- + n_class: int or None + The number of clusters to find, default=2. 
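+
+    Notes
+    -----
+    Only float64 columns of the input DataFrame are used for clustering;
+    the returned task extractor is an SVC (gamma=0.01) fit on the
+    AgglomerativeClustering labels, as implemented in __call__ below.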
+ """ + + def __init__(self, **kwargs): + n_class = kwargs.get("n_class", "") + self.n_class = max(2, int(n_class)) if str(n_class).isdigit() else 2 + + def __call__(self, + samples: BaseDataSource) -> Tuple[List[Task], + Any, + BaseDataSource]: + from sklearn.svm import SVC + from sklearn.cluster import AgglomerativeClustering + + d_type = samples.data_type + x_data = samples.x + y_data = samples.y + if not isinstance(x_data, pd.DataFrame): + raise TypeError(f"{d_type} data should only be pd.DataFrame") + tasks = [] + legal = list( + filter(lambda col: x_data[col].dtype == 'float64', x_data.columns)) + + df = x_data[legal] + c1 = AgglomerativeClustering(n_clusters=self.n_class).fit_predict(df) + c2 = SVC(gamma=0.01) + c2.fit(df, c1) + + for task in range(self.n_class): + g_attr = f"svc_{task}" + task_df = BaseDataSource(data_type=d_type) + task_df.x = x_data.iloc[np.where(c1 == task)] + task_df.y = y_data.iloc[np.where(c1 == task)] + + task_obj = Task(entry=g_attr, samples=task_df) + tasks.append(task_obj) + samples.x = df + return tasks, c2, samples + + +@ClassFactory.register(ClassType.STP) +class TaskDefinitionByDataAttr: + """ + Dividing datasets based on the common attributes, + generally used for structured data. + + Parameters + ---------- + attribute: List[Metadata] + metadata is usually a class feature label with a finite values. + """ + + def __init__(self, **kwargs): + self.attr_filed = kwargs.get("attribute", []) + + def __call__(self, + samples: BaseDataSource, **kwargs) -> Tuple[List[Task], + Any, + BaseDataSource]: + tasks = [] + d_type = samples.data_type + x_data = samples.x + y_data = samples.y + if not isinstance(x_data, pd.DataFrame): + raise TypeError(f"{d_type} data should only be pd.DataFrame") + + _inx = 0 + task_index = {} + for meta_attr, df in x_data.groupby(self.attr_filed): + if isinstance(meta_attr, (list, tuple, set)): + g_attr = "_".join( + map(lambda x: str(x).replace("_", "-"), meta_attr)) + meta_attr = list(meta_attr) + else: + g_attr = str(meta_attr).replace("_", "-") + meta_attr = [meta_attr] + g_attr = g_attr.replace(" ", "") + if g_attr in task_index: + old_task = tasks[task_index[g_attr]] + old_task.x = pd.concat([old_task.x, df]) + old_task.y = pd.concat([old_task.y, y_data.iloc[df.index]]) + continue + task_index[g_attr] = _inx + + task_df = BaseDataSource(data_type=d_type) + task_df.x = df.drop(self.attr_filed, axis=1) + task_df.y = y_data.iloc[df.index] + + task_obj = Task(entry=g_attr, samples=task_df, meta_attr=meta_attr) + tasks.append(task_obj) + _inx += 1 + x_data.drop(self.attr_filed, axis=1, inplace=True) + samples = BaseDataSource(data_type=d_type) + samples.x = x_data + samples.y = y_data + return tasks, task_index, samples + + +@ClassFactory.register(ClassType.STP) +class TaskDefinitionByCluster: + """ + Dividing datasets with all sorts of clustering methods. + + Parameters + ---------- + n_class: int or None + The number of clusters to find, default=1. 
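+
+    Examples
+    --------
+    A minimal sketch with a stub estimator; ``_StubModel`` is purely
+    illustrative, any wrapper exposing ``train``, ``load`` and ``evaluate``
+    (such as the RFNet base model) can be passed via the ``model`` keyword:
+
+    >>> from sedna.datasources import BaseDataSource
+    >>> class _StubModel:
+    ...     def train(self, train_samples, valid_samples):
+    ...         return "/tmp/stub_model.pth"
+    ...     def load(self, model_url):
+    ...         pass
+    ...     def evaluate(self, samples, **kwargs):
+    ...         return {"mIoU": 0.0}
+    >>> samples = BaseDataSource(data_type="train")
+    >>> samples.x = list(range(10))
+    >>> samples.y = list(range(10))
+    >>> tasks, _, _ = TaskDefinitionByCluster(n_class=1)(samples, model=_StubModel())
+    >>> len(tasks)
+    1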
+ """ + + def __init__(self, **kwargs): + self.n_class = int(kwargs.get("n_class", 1)) + self.train_ratio = float(kwargs.get("train_ratio", 0.8)) + + def __call__(self, + samples: BaseDataSource, **kwargs) -> Tuple[List[Task], + Any, + BaseDataSource]: + from sklearn.svm import SVC + model = kwargs.get("model") + + tasks = [] + c2 = SVC(gamma=0.01) + partition_length = int(len(samples.x) / self.n_class) + + for i in range(self.n_class): + # sample = BaseDataSource() + # sample.x = samples.x[i * partition_length: (i + 1) * partition_length] + # sample.y = samples.y[i * partition_length: (i + 1) * partition_length] + # + # train_num = int(len(sample.x) * train_ratio) + # + # train_samples = BaseDataSource(data_type="train") + # train_samples.x = sample.x[:train_num] + # train_samples.y = sample.y[:train_num] + # + # test_samples = BaseDataSource(data_type="eval") + # test_samples.x = sample.x[train_num:] + # test_samples.y = sample.y[train_num:] + + train_num = int(len(samples.x) * self.train_ratio) + train_samples = BaseDataSource(data_type="train") + train_samples.x = samples.x[:train_num] + train_samples.y = samples.y[:train_num] + + test_samples = BaseDataSource(data_type="eval") + test_samples.x = samples.x[train_num:] + test_samples.y = samples.y[train_num:] + + model_url = model.train(train_samples, test_samples) + model.load(model_url) + result = model.evaluate(test_samples, checkpoint_path=model_url) + + g_attr = f"svc_{i}_{time.time()}" + task_obj = Task(entry=g_attr, samples=train_samples) + task_obj.test_samples = test_samples + task_obj.model = model_url + task_obj.result = result + tasks.append(task_obj) + + return tasks, c2, samples diff --git a/lib/sedna/algorithms/seen_task_learning/task_definition/task_definition_by_origin.py b/lib/sedna/algorithms/seen_task_learning/task_definition/task_definition_by_origin.py new file mode 100644 index 000000000..460aff5f9 --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_definition/task_definition_by_origin.py @@ -0,0 +1,49 @@ +from typing import List, Any, Tuple + +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassType, ClassFactory + +from ..artifact import Task + +@ClassFactory.register(ClassType.STP) +class TaskDefinitionByOrigin: + """ + Dividing datasets based on the their origins. + + Parameters + ---------- + attribute: List[Metadata] + metadata is usually a class feature label with a finite values. 
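+
+    Examples
+    --------
+    A minimal sketch; the origin names and file paths are illustrative only.
+    Each ``x`` entry is expected to start with a path containing its origin,
+    and each ``y`` entry to be a label path containing the same origin:
+
+    >>> from sedna.datasources import BaseDataSource
+    >>> samples = BaseDataSource(data_type="train")
+    >>> samples.x = [["/data/real/0001.png"], ["/data/sim/0002.png"]]
+    >>> samples.y = ["/data/real/0001_label.png", "/data/sim/0002_label.png"]
+    >>> definition = TaskDefinitionByOrigin(origins=["real", "sim"])
+    >>> tasks, task_index, _ = definition(samples)
+    >>> task_index
+    {'real': 0, 'sim': 1}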
+ """ + + def __init__(self, **kwargs): + self.origins = kwargs.get("origins", []) + + def __call__(self, + samples: BaseDataSource, **kwargs) -> Tuple[List[Task], + Any, + BaseDataSource]: + tasks = [] + d_type = samples.data_type + x_data = samples.x + y_data = samples.y + + task_index = dict(zip(self.origins, range(len(self.origins)))) + + for k, v in task_index.items(): + _x = [x for x in x_data if k in x[0]] + _y = [y for y in y_data if k in y] + + task_df = BaseDataSource(data_type=d_type) + task_df.x = _x + task_df.y = _y + + g_attr = f"{k}_semantic_segamentation_model" + task_obj = Task(entry=g_attr, samples=task_df, meta_attr=k) + tasks.append(task_obj) + + samples = BaseDataSource(data_type=d_type) + samples.x = x_data + samples.y = y_data + + return tasks, task_index, samples \ No newline at end of file diff --git a/lib/sedna/algorithms/seen_task_learning/task_relation_discover/__init__.py b/lib/sedna/algorithms/seen_task_learning/task_relation_discover/__init__.py new file mode 100644 index 000000000..229b95840 --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_relation_discover/__init__.py @@ -0,0 +1,3 @@ + +from . import task_relation_discover + diff --git a/lib/sedna/algorithms/seen_task_learning/task_relation_discover/task_relation_discover.py b/lib/sedna/algorithms/seen_task_learning/task_relation_discover/task_relation_discover.py new file mode 100644 index 000000000..bbbe8a061 --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_relation_discover/task_relation_discover.py @@ -0,0 +1,52 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Discover relationships between all tasks + +Parameters +---------- +tasks :all tasks form `task_definition` + +Returns +------- +task_groups : List of groups which including at least 1 task. +""" + +from typing import List + +from sedna.common.class_factory import ClassType, ClassFactory + +from ..artifact import Task, TaskGroup + + +__all__ = ('DefaultTaskRelationDiscover',) + + +@ClassFactory.register(ClassType.STP) +class DefaultTaskRelationDiscover: + """ + Assume that each task is independent of each other + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, tasks: List[Task]) -> List[TaskGroup]: + tgs = [] + for task in tasks: + tg_obj = TaskGroup(entry=task.entry, tasks=[task]) + tg_obj.samples = task.samples + tgs.append(tg_obj) + return tgs diff --git a/lib/sedna/algorithms/seen_task_learning/task_remodeling/__init__.py b/lib/sedna/algorithms/seen_task_learning/task_remodeling/__init__.py new file mode 100644 index 000000000..33df6b2cd --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_remodeling/__init__.py @@ -0,0 +1,2 @@ + +from . 
import task_remodeling diff --git a/lib/sedna/algorithms/seen_task_learning/task_remodeling/task_remodeling.py b/lib/sedna/algorithms/seen_task_learning/task_remodeling/task_remodeling.py new file mode 100644 index 000000000..7e7a92271 --- /dev/null +++ b/lib/sedna/algorithms/seen_task_learning/task_remodeling/task_remodeling.py @@ -0,0 +1,78 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Remodeling tasks based on their relationships + +Parameters +---------- +mappings :all assigned tasks get from the `task_mining` +samples : input samples + +Returns +------- +models : List of groups which including at least 1 task. +""" + +from typing import List + +import numpy as np +import pandas as pd + +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('DefaultTaskRemodeling',) + + +@ClassFactory.register(ClassType.STP) +class DefaultTaskRemodeling: + """ + Assume that each task is independent of each other + """ + + def __init__(self, models: list, **kwargs): + self.models = models + + def __call__(self, samples: BaseDataSource, mappings: List): + """ + Grouping based on assigned tasks + """ + mappings = np.array(mappings) + data, models = [], [] + d_type = samples.data_type + for m in np.unique(mappings): + task_df = BaseDataSource(data_type=d_type) + _inx = np.where(mappings == m) + if isinstance(samples.x, pd.DataFrame): + task_df.x = samples.x.iloc[_inx] + else: + task_df.x = np.array(samples.x)[_inx] + if d_type != "test": + if isinstance(samples.x, pd.DataFrame): + task_df.y = samples.y.iloc[_inx] + else: + task_df.y = np.array(samples.y)[_inx] + task_df.inx = _inx[0].tolist() + if samples.meta_attr is not None: + task_df.meta_attr = np.array(samples.meta_attr)[_inx] + data.append(task_df) + # TODO: if m is out of index + try: + model = self.models[m] + except Exception as err: + print(f"self.models[{m}] not exists. {err}") + model = self.models[0] + models.append(model) + return data, models diff --git a/lib/sedna/algorithms/unseen_task_detection/__init__.py b/lib/sedna/algorithms/unseen_task_detection/__init__.py new file mode 100644 index 000000000..c5c182528 --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_detection/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import unseen_sample_recognition +from . 
import unseen_sample_re_recognition diff --git a/lib/sedna/algorithms/unseen_task_detection/unseen_sample_re_recognition/__init__.py b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_re_recognition/__init__.py new file mode 100644 index 000000000..fd9d4f54f --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_re_recognition/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import unseen_sample_re_recognition diff --git a/lib/sedna/algorithms/unseen_task_detection/unseen_sample_re_recognition/unseen_sample_re_recognition.py b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_re_recognition/unseen_sample_re_recognition.py new file mode 100644 index 000000000..3222b8f20 --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_re_recognition/unseen_sample_re_recognition.py @@ -0,0 +1,43 @@ +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('SampleReRegonitionDefault', ) + + +@ClassFactory.register(ClassType.UTD) +class SampleReRegonitionDefault: + # TODO: to be completed + ''' + Divide inference samples into seen tasks and unseen tasks. + + Parameters + ---------- + task_index: str or Dict + ''' + + def __init__(self, task_index, **kwargs): + pass + + def __call__(self, samples: BaseDataSource): + ''' + Parameters + ---------- + samples: training samples + + Returns + ------- + seen_task_samples: BaseDataSource + unseen_task_samples: BaseDataSource + ''' + + sample_num = int(len(samples.x) / 2) + + seen_task_samples = BaseDataSource(data_type=samples.data_type) + seen_task_samples.x = samples.x[:sample_num] + seen_task_samples.y = samples.y[:sample_num] + + unseen_task_samples = BaseDataSource(data_type=samples.data_type) + unseen_task_samples.x = samples.x[sample_num:] + unseen_task_samples.y = samples.y[sample_num:] + + return seen_task_samples, unseen_task_samples diff --git a/lib/sedna/algorithms/unseen_task_detection/unseen_sample_recognition/__init__.py b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_recognition/__init__.py new file mode 100644 index 000000000..ab3fe7e57 --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_recognition/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . 
import unseen_sample_recognition diff --git a/lib/sedna/algorithms/unseen_task_detection/unseen_sample_recognition/unseen_sample_recognition.py b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_recognition/unseen_sample_recognition.py new file mode 100644 index 000000000..992ee4f0b --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_detection/unseen_sample_recognition/unseen_sample_recognition.py @@ -0,0 +1,95 @@ +from typing import Tuple + +from sedna.common.file_ops import FileOps +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('SampleRegonitionDefault', 'SampleRegonitionByRFNet') + + +@ClassFactory.register(ClassType.UTD) +class SampleRegonitionDefault: + ''' + Divide inference samples into seen samples and unseen samples + + Parameters + ---------- + task_index: str or dict + knowledge base index which includes indexes of tasks, samples and etc. + ''' + + def __init__(self, task_index, **kwargs): + if isinstance(task_index, str) and FileOps.exists(task_index): + self.task_index = FileOps.load(task_index) + else: + self.task_index = task_index + + def __call__(self, + samples: BaseDataSource) -> Tuple[BaseDataSource, + BaseDataSource]: + ''' + Parameters + ---------- + samples : BaseDataSource + inference samples + + Returns + ------- + seen_task_samples: BaseDataSource + unseen_task_samples: BaseDataSource + ''' + sample_num = int(len(samples.x) / 2) + + seen_task_samples = BaseDataSource(data_type=samples.data_type) + seen_task_samples.x = samples.x[sample_num:] + + unseen_task_samples = BaseDataSource(data_type=samples.data_type) + unseen_task_samples.x = samples.x[:sample_num] + + return seen_task_samples, unseen_task_samples + + +@ClassFactory.register(ClassType.UTD) +class SampleRegonitionByRFNet: + ''' + Divide inference samples into seen samples and unseen samples by confidence. + + Parameters + ---------- + task_index: str or dict + knowledge base index which includes indexes of tasks, samples and etc. + ''' + + def __init__(self, task_index, **kwargs): + if isinstance(task_index, str) and FileOps.exists(task_index): + self.task_index = FileOps.load(task_index) + else: + self.task_index = task_index + + self.validator = kwargs.get("validator") + + def __call__(self, samples: BaseDataSource, ** + kwargs) -> Tuple[BaseDataSource, BaseDataSource]: + ''' + Parameters + ---------- + samples : BaseDataSource + inference samples + + Returns + ------- + seen_task_samples: BaseDataSource + unseen_task_samples: BaseDataSource + ''' + from torch.utils.data import DataLoader + + self.validator.test_loader = DataLoader( + samples.x, batch_size=1, shuffle=False) + + seen_task_samples = BaseDataSource(data_type=samples.data_type) + unseen_task_samples = BaseDataSource(data_type=samples.data_type) + + seen_task_samples.x, unseen_task_samples.x = self.validator.task_divide() + + return seen_task_samples, unseen_task_samples + diff --git a/lib/sedna/algorithms/unseen_task_processing/__init__.py b/lib/sedna/algorithms/unseen_task_processing/__init__.py new file mode 100644 index 000000000..431955b7d --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_processing/__init__.py @@ -0,0 +1,2 @@ +from . 
import unseen_task_allocation +from .unseen_task_processing import UnseenTaskProcessing diff --git a/lib/sedna/algorithms/unseen_task_processing/unseen_task_allocation/__init__.py b/lib/sedna/algorithms/unseen_task_processing/unseen_task_allocation/__init__.py new file mode 100644 index 000000000..281958c9f --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_processing/unseen_task_allocation/__init__.py @@ -0,0 +1 @@ +from . import unseen_task_allocation diff --git a/lib/sedna/algorithms/unseen_task_processing/unseen_task_allocation/unseen_task_allocation.py b/lib/sedna/algorithms/unseen_task_processing/unseen_task_allocation/unseen_task_allocation.py new file mode 100644 index 000000000..d49a8f1b6 --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_processing/unseen_task_allocation/unseen_task_allocation.py @@ -0,0 +1,46 @@ +from sedna.common.log import LOGGER +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('UnseenTaskAllocationDefault', ) + + +@ClassFactory.register(ClassType.UTP) +class UnseenTaskAllocationDefault: + # TODO: to be completed + """ + Task allocation for unseen data + + Parameters + ---------- + task_extractor : Dict + used to match target tasks + """ + + def __init__(self, task_extractor, **kwargs): + self.task_extractor = task_extractor + self.log = LOGGER + + def __call__(self, samples: BaseDataSource): + ''' + Parameters + ---------- + samples: samples to be allocated + + Returns + ------- + samples: BaseDataSource + allocations: List + allocation decision for actual inference + ''' + + try: + allocations = [self.task_extractor.fit( + sample) for sample in samples.x] + except Exception as err: + self.log.exception(err) + + allocations = [0] * len(samples) + self.log.info("Use the first task to inference all the samples.") + + return samples, allocations diff --git a/lib/sedna/algorithms/unseen_task_processing/unseen_task_processing.py b/lib/sedna/algorithms/unseen_task_processing/unseen_task_processing.py new file mode 100644 index 000000000..73c761606 --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_processing/unseen_task_processing.py @@ -0,0 +1,157 @@ +import json + +from sedna.backend import set_backend +from sedna.common.file_ops import FileOps +from sedna.common.constant import KBResourceConstant +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('UnseenTaskProcessing', ) + + +class UnseenTaskProcessing: + ''' + Process unseen tasks given task update strategies + + Parameters: + ---------- + estimator: Instance + An instance with the high-level API that greatly simplifies + machine learning programming. Estimators encapsulate training, + evaluation, prediction, and exporting for your model. + cloud_knowledge_management: Instance of class CloudKnowledgeManagement + unseen_task_allocation: Dict + Mining tasks of unseen inference sample. 
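+    edge_knowledge_management: Instance of class EdgeKnowledgeManagement
+    config: Dict
+        configuration of the lifelong learning job, e.g. output urls and
+        the knowledge-base index name.
+
+    Examples
+    --------
+    Illustrative wiring only; ``estimator``, ``config`` and the two
+    knowledge-management instances are assumed to be created elsewhere,
+    e.g. inside `LifelongLearning`:
+
+    >>> utp = UnseenTaskProcessing(
+    ...     estimator, config, cloud_knowledge_management,
+    ...     edge_knowledge_management,
+    ...     unseen_task_allocation={"method": "UnseenTaskAllocationDefault"})
+    >>> res, task_index = utp.initialize()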
+ ''' + + def __init__(self, estimator, config, + cloud_knowledge_management, + edge_knowledge_management, + unseen_task_allocation, + **kwargs): + self.estimator = set_backend(estimator=estimator, config=config) + self.cloud_knowledge_management = cloud_knowledge_management + self.edge_knowledge_management = edge_knowledge_management + + self.unseen_task_allocation = unseen_task_allocation or { + "method": "UnseenTaskAllocationDefault" + } + self.unseen_models = None + self.unseen_extractor = None + self.unseen_task_groups = None + self.unseen_task_key = KBResourceConstant.UNSEEN_TASK.value + self.task_group_key = KBResourceConstant.TASK_GROUPS.value + self.extractor_key = KBResourceConstant.EXTRACTOR.value + + @staticmethod + def _parse_param(param_str): + if not param_str: + return {} + if isinstance(param_str, dict): + return param_str + try: + raw_dict = json.loads(param_str, encoding="utf-8") + except json.JSONDecodeError: + raw_dict = {} + return raw_dict + + def _unseen_task_allocation(self, samples): + """ + Mining unseen tasks of inference sample base on task attribute extractor + """ + method_name = self.unseen_task_allocation.get("method") + extend_param = self._parse_param( + self.unseen_task_allocation.get("param") + ) + + method_cls = ClassFactory.get_cls(ClassType.UTP, method_name)( + task_extractor=self.unseen_extractor, **extend_param + ) + return method_cls(samples=samples) + + def initialize(self): + """ + Intialize unseen task groups + + Returns: + res: Dict + evaluation result. + task_index: Dict or str + unseen task index which includes models, samples, extractor and etc. + """ + task_index = { + self.extractor_key: None, + self.task_group_key: [] + } + + res = {} + return res, task_index + + def update(self, tasks, task_update_strategies, **kwargs): + """ + Parameters: + ---------- + tasks: List[Task] + from the output of module task_update_decision + task_update_strategies: Dict + from the output of module task_update_decision + + Returns + ------- + task_index : Dict + updated unseen task index of knowledge base + """ + task_index = { + self.extractor_key: None, + self.task_group_key: [] + } + + return task_index + + def predict(self, data, post_process=None, **kwargs): + """ + Predict the result for unseen data. + + Parameters + ---------- + data : BaseDataSource + inference sample, see `sedna.datasources.BaseDataSource` for + more detail. + post_process: function + function or a registered method, effected after `estimator` + prediction, like: label transform. + + Returns + ------- + result : array_like + results array, contain all inference results in each sample. + tasks : List + tasks assigned to each sample. + """ + if not self.unseen_task_groups and not self.unseen_models: + self.load(self.edge_knowledge_management.task_index) + + samples, mappings = self._unseen_task_allocation(data) + result = {} + tasks = [] + + return result, tasks + + def load(self, task_index): + """ + load task_detail (tasks/models etc ...) from task index file. + It'll automatically loaded during `inference` phases. + + Parameters + ---------- + task_index_url : str + task index file path. + """ + assert task_index is not None, "task index url is None!!!" 
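+        # The index is expected to follow the layout produced by
+        # `initialize`/`update`, i.e. (keys from KBResourceConstant):
+        #   {"seen_task": {...},
+        #    "unseen_task": {"extractor": <dict or file path>,
+        #                    "task_groups": [TaskGroup, ...]}}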
+ if isinstance(task_index, str): + task_index = FileOps.load(task_index) + + self.unseen_extractor = task_index[self.unseen_task_key][self.extractor_key] + if isinstance(self.unseen_extractor, str): + self.unseen_extractor = FileOps.load(self.unseen_extractor) + self.unseen_task_groups = task_index[self.unseen_task_key][self.task_group_key] + self.unseen_models = [task.model for task in self.unseen_task_groups] diff --git a/lib/sedna/backend/base.py b/lib/sedna/backend/base.py index 72023b7eb..1a56873b4 100644 --- a/lib/sedna/backend/base.py +++ b/lib/sedna/backend/base.py @@ -58,6 +58,15 @@ def train(self, *args, **kwargs): varkw = self.parse_kwargs(fit_method, **kwargs) return fit_method(*args, **varkw) + def update(self, *args, **kwargs): + """update model by training.""" + if callable(self.estimator): + varkw = self.parse_kwargs(self.estimator, **kwargs) + self.estimator = self.estimator(**varkw) + fit_method = getattr(self.estimator, "fit", self.estimator.update) + varkw = self.parse_kwargs(fit_method, **kwargs) + return fit_method(*args, **varkw) + def predict(self, *args, **kwargs): """Inference model.""" varkw = self.parse_kwargs(self.estimator.predict, **kwargs) diff --git a/lib/sedna/common/class_factory.py b/lib/sedna/common/class_factory.py index 13b6ceca3..82b4b114e 100644 --- a/lib/sedna/common/class_factory.py +++ b/lib/sedna/common/class_factory.py @@ -37,6 +37,11 @@ class ClassType: DATASET = 'data_process' CALLBACK = 'post_process_callback' + # TODO + UTP = 'unseen_task_processing' + KM = 'knowledge_management' + STP = 'seen_task_processing' + class ClassFactory(object): """ diff --git a/lib/sedna/common/constant.py b/lib/sedna/common/constant.py index b32aed3b5..d84784fa0 100644 --- a/lib/sedna/common/constant.py +++ b/lib/sedna/common/constant.py @@ -48,3 +48,8 @@ class KBResourceConstant(Enum): MIN_TRAIN_SAMPLE = 10 KB_INDEX_NAME = "index.pkl" TASK_EXTRACTOR_NAME = "task_attr_extractor.pkl" + SEEN_TASK = "seen_task" + UNSEEN_TASK = "unseen_task" + TASK_GROUPS = "task_groups" + EXTRACTOR = "extractor" + EDGE_KB_DIR = "/var/lib/sedna/kb" diff --git a/lib/sedna/core/lifelong_learning/lifelong_learning.py b/lib/sedna/core/lifelong_learning/lifelong_learning.py index ff28cc847..1cf231d78 100644 --- a/lib/sedna/core/lifelong_learning/lifelong_learning.py +++ b/lib/sedna/core/lifelong_learning/lifelong_learning.py @@ -12,18 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import tempfile +import numpy as np -from sedna.backend import set_backend from sedna.core.base import JobBase from sedna.common.file_ops import FileOps from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus from sedna.common.constant import KBResourceConstant from sedna.common.config import Context +from sedna.datasources import BaseDataSource from sedna.common.class_factory import ClassType, ClassFactory -from sedna.algorithms.multi_task_learning import MulTaskLearning +from sedna.algorithms.seen_task_learning.seen_task_learning import SeenTaskLearning +from sedna.algorithms.unseen_task_processing import UnseenTaskProcessing from sedna.service.client import KBClient +from sedna.algorithms.knowledge_management.cloud_knowledge_management \ + import CloudKnowledgeManagement +from sedna.algorithms.knowledge_management.edge_knowledge_management \ + import EdgeKnowledgeManagement class LifelongLearning(JobBase): @@ -42,23 +46,31 @@ class LifelongLearning(JobBase): evaluation, prediction, and exporting for your model. 
     task_definition : Dict
         Divide multiple tasks based on data,
-        see `task_jobs.task_definition` for more detail.
+        see `task_definition.task_definition` for more detail.
     task_relationship_discovery : Dict
         Discover relationships between all tasks, see
-        `task_jobs.task_relationship_discovery` for more detail.
-    task_mining : Dict
-        Mining tasks of inference sample,
-        see `task_jobs.task_mining` for more detail.
+        `task_relationship_discovery.task_relationship_discovery`
+        for more detail.
+    task_allocation : Dict
+        Mining seen tasks of inference sample,
+        see `task_allocation.task_allocation` for more detail.
     task_remodeling : Dict
         Remodeling tasks based on their relationships,
-        see `task_jobs.task_remodeling` for more detail.
+        see `task_remodeling.task_remodeling` for more detail.
     inference_integrate : Dict
         Integrate the inference results of all related
-        tasks, see `task_jobs.inference_integrate` for more detail.
-    unseen_task_detect: Dict
-        unseen task detect algorithms with parameters which has registered to
-        ClassFactory, see `sedna.algorithms.unseen_task_detect` for more detail
-
+        tasks, see `inference_integrate.inference_integrate` for more detail.
+    task_update_decision : Dict
+        Algorithms for making task update strategies, see
+        `knowledge_management.task_update_decision.task_update_decision`
+        for more detail.
+    unseen_task_allocation : Dict
+        Mining unseen tasks of inference sample, see
+        `unseen_task_processing.unseen_task_allocation.unseen_task_allocation`
+        for more detail.
+    unseen_sample_recognition : Dict
+        Dividing inference samples into seen and unseen samples, see
+        `unseen_task_detection.unseen_sample_recognition.unseen_sample_recognition`
+        for more detail.
+    unseen_sample_re_recognition : Dict
+        Dividing unseen training samples into seen and unseen samples, see
+        `unseen_task_detection.unseen_sample_re_recognition.unseen_sample_re_recognition`
+        for more detail.
Examples -------- @@ -78,17 +90,29 @@ class LifelongLearning(JobBase): >>> inference_integrate = { "method": "DefaultInferenceIntegrate", "param": {} } - >>> unseen_task_detect = { - "method": "TaskAttrFilter", "param": {} + >>> task_update_decision = { + "method": "UpdateStrategyDefault", "param": {} + } + >>> unseen_task_allocation = { + "method": "UnseenTaskAllocationDefault", "param": {} + } + >>> unseen_sample_recognition = { + "method": "SampleRegonitionDefault", "param": {} + } + >>> unseen_sample_re_recognition = { + "method": "SampleReRegonitionDefault", "param": {} } >>> ll_jobs = LifelongLearning( - estimator=estimator, - task_definition=task_definition, - task_relationship_discovery=task_relationship_discovery, - task_mining=task_mining, - task_remodeling=task_remodeling, - inference_integrate=inference_integrate, - unseen_task_detect=unseen_task_detect + estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=None, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None, ) """ @@ -96,37 +120,72 @@ def __init__(self, estimator, task_definition=None, task_relationship_discovery=None, - task_mining=None, + task_allocation=None, task_remodeling=None, inference_integrate=None, - unseen_task_detect=None): - - if not task_definition: - task_definition = { - "method": "TaskDefinitionByDataAttr" - } - if not unseen_task_detect: - unseen_task_detect = { - "method": "TaskAttrFilter" - } - e = MulTaskLearning( + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None + ): + + e = SeenTaskLearning( estimator=estimator, task_definition=task_definition, task_relationship_discovery=task_relationship_discovery, - task_mining=task_mining, + seen_task_allocation=task_allocation, task_remodeling=task_remodeling, - inference_integrate=inference_integrate) - self.unseen_task_detect = unseen_task_detect.get("method", - "TaskAttrFilter") - self.unseen_task_detect_param = e._parse_param( - unseen_task_detect.get("param", {}) + inference_integrate=inference_integrate ) + + self.unseen_sample_recognition = unseen_sample_recognition or { + "method": "SampleRegonitionDefault" + } + self.unseen_sample_recognition_param = e._parse_param( + self.unseen_sample_recognition.get("param", {})) + + self.unseen_sample_re_recognition = unseen_sample_re_recognition or { + "method": "SampleReRegonitionDefault" + } + self.unseen_sample_re_recognition_param = e._parse_param( + self.unseen_sample_re_recognition.get("param", {})) + + self.task_update_decision = task_update_decision or { + "method": "UpdateStrategyDefault" + } + self.task_update_decision_param = e._parse_param( + self.task_update_decision.get("param", {}) + ) + config = dict( ll_kb_server=Context.get_parameters("KB_SERVER"), - output_url=Context.get_parameters("OUTPUT_URL", "/tmp") - ) + output_url=Context.get_parameters( + "OUTPUT_URL", + "/tmp"), + cloud_output_url=Context.get_parameters( + "OUTPUT_URL", + "/tmp"), + edge_output_url=Context.get_parameters( + "EDGE_OUTPUT_URL", + KBResourceConstant.EDGE_KB_DIR.value), + task_index=KBResourceConstant.KB_INDEX_NAME.value) + + self.cloud_knowledge_management = CloudKnowledgeManagement( + config, estimator=e) + + self.edge_knowledge_management = EdgeKnowledgeManagement( + config, estimator=e) + + self.unseen_task_processing = UnseenTaskProcessing( + estimator, + config, + 
self.cloud_knowledge_management, + self.edge_knowledge_management, + unseen_task_allocation) + task_index = FileOps.join_path(config['output_url'], KBResourceConstant.KB_INDEX_NAME.value) + config['task_index'] = task_index super(LifelongLearning, self).__init__( estimator=e, config=config @@ -165,95 +224,111 @@ def train(self, train_data, if post_process is not None: callback_func = ClassFactory.get_cls( ClassType.CALLBACK, post_process) - res, task_index_url = self.estimator.train( + res, seen_task_index = self.estimator.train( train_data=train_data, valid_data=valid_data, **kwargs ) # todo: Distinguishing incremental update and fully overwrite - if isinstance(task_index_url, str) and FileOps.exists(task_index_url): - task_index = FileOps.load(task_index_url) - else: - task_index = task_index_url + unseen_res, unseen_task_index = self.unseen_task_processing.initialize() - extractor = task_index['extractor'] - task_groups = task_index['task_groups'] + task_index = dict( + seen_task=seen_task_index, + unseen_task=unseen_task_index) + task_index_url = FileOps.dump( + task_index, self.cloud_knowledge_management.local_task_index_url) - model_upload_key = {} - for task in task_groups: - model_file = task.model.model - save_model = FileOps.join_path( - self.config.output_url, - os.path.basename(model_file) - ) - if model_file not in model_upload_key: - model_upload_key[model_file] = FileOps.upload(model_file, - save_model) - model_file = model_upload_key[model_file] - - try: - model = self.kb_server.upload_file(save_model) - except Exception as err: - self.log.error( - f"Upload task model of {model_file} fail: {err}" - ) - model = set_backend( - estimator=self.estimator.estimator.base_model - ) - model.load(model_file) - task.model.model = model - - for _task in task.tasks: - sample_dir = FileOps.join_path( - self.config.output_url, - f"{_task.samples.data_type}_{_task.entry}.sample") - task.samples.save(sample_dir) - try: - sample_dir = self.kb_server.upload_file(sample_dir) - except Exception as err: - self.log.error( - f"Upload task samples of {_task.entry} fail: {err}") - _task.samples.data_url = sample_dir - - save_extractor = FileOps.join_path( - self.config.output_url, - KBResourceConstant.TASK_EXTRACTOR_NAME.value - ) - extractor = FileOps.dump(extractor, save_extractor) - try: - extractor = self.kb_server.upload_file(extractor) - except Exception as err: - self.log.error(f"Upload task extractor fail: {err}") - task_info = { - "task_groups": task_groups, - "extractor": extractor - } - fd, name = tempfile.mkstemp() - FileOps.dump(task_info, name) - - index_file = self.kb_server.update_db(name) - if not index_file: - self.log.error(f"KB update Fail !") - index_file = name - FileOps.upload(index_file, self.config.task_index) + task_index = self.cloud_knowledge_management.update_kb( + task_index_url, self.kb_server) + res.update(unseen_res) task_info_res = self.estimator.model_info( - self.config.task_index, + task_index, relpath=self.config.data_path_prefix) self.report_task_info( None, K8sResourceKindStatus.COMPLETED.value, task_info_res) self.log.info(f"Lifelong learning Train task Finished, " - f"KB idnex save in {self.config.task_index}") + f"KB index save in {task_index}") return callback_func(self.estimator, res) if callback_func else res def update(self, train_data, valid_data=None, post_process=None, **kwargs): - return self.train( - train_data=train_data, - valid_data=valid_data, - post_process=post_process, - action="update", - **kwargs - ) + """ + fit for update the knowledge 
based on incremental data. + + Parameters + ---------- + train_data : BaseDataSource + Train data, see `sedna.datasources.BaseDataSource` for more detail. + valid_data : BaseDataSource + Valid data, BaseDataSource or None. + post_process : function + function or a registered method, callback after `estimator` train. + kwargs : Dict + parameters for `estimator` training, Like: + `early_stopping_rounds` in Xgboost.XGBClassifier + + Returns + ------- + train_history : object + """ + callback_func = None + if post_process is not None: + callback_func = ClassFactory.get_cls( + ClassType.CALLBACK, post_process) + + task_index_url = self.get_parameters( + "CLOUD_KB_INDEX", self.cloud_knowledge_management.task_index) + index_url = self.cloud_knowledge_management.local_task_index_url + FileOps.download(task_index_url, index_url) + + unseen_sample_re_recognition = ClassFactory.get_cls( + ClassType.UTD, self.unseen_sample_re_recognition["method"])( + index_url, **self.unseen_sample_re_recognition_param) + + seen_samples, unseen_samples = unseen_sample_re_recognition(train_data) + + # TODO: retrain temporarily + # historical_data = self._fetch_historical_data(index_url) + # seen_samples.x = np.concatenate( + # (historical_data.x, seen_samples.x, unseen_samples.x), axis=0) + # seen_samples.y = np.concatenate( + # (historical_data.y, seen_samples.y, unseen_samples.y), axis=0) + + seen_samples.x = np.concatenate((seen_samples.x, unseen_samples.x), axis=0) + seen_samples.y = np.concatenate((seen_samples.y, unseen_samples.y), axis=0) + + task_update_decision = ClassFactory.get_cls( + ClassType.KM, self.task_update_decision["method"])( + index_url, **self.task_update_decision_param) + + tasks, task_update_strategies = task_update_decision( + seen_samples, task_type="seen_task") + seen_task_index = self.cloud_knowledge_management.estimator.update( + tasks, task_update_strategies, task_index=index_url) + + tasks, task_update_strategies = task_update_decision( + unseen_samples, task_type="unseen_task") + unseen_task_index = self.unseen_task_processing.update( + tasks, task_update_strategies, task_index=index_url) + + task_index = { + "seen_task": seen_task_index, + "unseen_task": unseen_task_index, + } + + task_index = self.cloud_knowledge_management.update_kb( + task_index, self.kb_server) + + task_info_res = self.estimator.model_info( + task_index, + relpath=self.config.data_path_prefix) + + self.report_task_info( + None, K8sResourceKindStatus.COMPLETED.value, task_info_res) + self.log.info(f"Lifelong learning Update task Finished, " + f"KB index save in {task_index}") + return callback_func(self.estimator, + task_index) if callback_func else task_index def evaluate(self, data, post_process=None, **kwargs): """ @@ -275,55 +350,23 @@ def evaluate(self, data, post_process=None, **kwargs): elif post_process is not None: callback_func = ClassFactory.get_cls( ClassType.CALLBACK, post_process) - task_index_url = self.get_parameters( - "MODEL_URLS", self.config.task_index) - index_url = self.estimator.estimator.task_index_url + + task_index_url = Context.get_parameters( + "MODEL_URLS", self.cloud_knowledge_management.task_index) + index_url = self.cloud_knowledge_management.local_task_index_url self.log.info( f"Download kb index from {task_index_url} to {index_url}") FileOps.download(task_index_url, index_url) - res, tasks_detail = self.estimator.evaluate(data=data, **kwargs) - drop_tasks = [] - - model_filter_operator = self.get_parameters("operator", ">") - model_threshold = 
float(self.get_parameters('model_threshold', 0.1)) - - operator_map = { - ">": lambda x, y: x > y, - "<": lambda x, y: x < y, - "=": lambda x, y: x == y, - ">=": lambda x, y: x >= y, - "<=": lambda x, y: x <= y, - } - if model_filter_operator not in operator_map: - self.log.warn( - f"operator {model_filter_operator} use to " - f"compare is not allow, set to <" - ) - model_filter_operator = "<" - operator_func = operator_map[model_filter_operator] - - for detail in tasks_detail: - scores = detail.scores - entry = detail.entry - self.log.info(f"{entry} scores: {scores}") - if any(map(lambda x: operator_func(float(x), - model_threshold), - scores.values())): - self.log.warn( - f"{entry} will not be deploy because all " - f"scores {model_filter_operator} {model_threshold}") - drop_tasks.append(entry) - continue - drop_task = ",".join(drop_tasks) - index_file = self.kb_server.update_task_status(drop_task, new_status=0) - if not index_file: - self.log.error(f"KB update Fail !") - index_file = str(index_url) + + res, index_file = self._task_evaluation( + data, task_index=index_url, **kwargs) + self.log.info("Task evaluation finishes.") + + FileOps.upload(index_file, self.cloud_knowledge_management.task_index) self.log.info( - f"upload kb index from {index_file} to {self.config.task_index}") - FileOps.upload(index_file, self.config.task_index) + f"upload kb index from {index_file} to {self.cloud_knowledge_management.task_index}") task_info_res = self.estimator.model_info( - self.config.task_index, result=res, + self.cloud_knowledge_management.task_index, result=res, relpath=self.config.data_path_prefix) self.report_task_info( None, @@ -350,36 +393,79 @@ def inference(self, data=None, post_process=None, **kwargs): Returns ------- - result : array_like - results array, contain all inference results in each sample. - is_unseen_task : bool - `true` means detect an unseen task, `false` means not - tasks : List - tasks assigned to each sample. 
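+        seen_res
+            inference results of the samples recognized as seen samples,
+            produced by the seen-task models deployed at the edge;
+            `None` if no seen sample is detected.
+        unseen_res
+            inference results of the samples recognized as unseen samples,
+            produced by `UnseenTaskProcessing.predict`;
+            `None` if no unseen sample is detected.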
""" - task_index_url = self.get_parameters( - "MODEL_URLS", self.config.task_index) - index_url = self.estimator.estimator.task_index_url - FileOps.download(task_index_url, index_url) - res, tasks = self.estimator.predict( - data=data, post_process=post_process, **kwargs - ) + seen_res, unseen_res = None, None + task_index_url = Context.get_parameters( + "MODEL_URLS", self.cloud_knowledge_management.task_index) + index_url = self.edge_knowledge_management.task_index + if not FileOps.exists(index_url): + FileOps.download(task_index_url, index_url) + self.log.info( + f"Download kb index from {task_index_url} to {index_url}") + + self.edge_knowledge_management.update_kb(index_url) + self.log.info(f"Tasks are deployed at the edge.") + + unseen_sample_recognition = ClassFactory.get_cls( + ClassType.UTD, + self.unseen_sample_recognition["method"])( + self.edge_knowledge_management.task_index, + **self.unseen_sample_recognition_param) + + seen_samples, unseen_samples = unseen_sample_recognition(data, **kwargs) + if unseen_samples.x is not None and len(unseen_samples.x) > 0: + self.edge_knowledge_management.log.info( + f"Unseen task is detected.") + unseen_res, unseen_tasks = self.unseen_task_processing.predict( + unseen_samples) + + unseen_save_url = self.edge_knowledge_management.save_unseen_samples( + unseen_samples, post_process=post_process) + self.log.info( + f"Unseen samples are being uploaded to {unseen_save_url}.") + + if seen_samples.x is not None and len(seen_samples.x) > 0: + seen_res, seen_tasks = self.edge_knowledge_management.estimator.predict( + data=seen_samples, post_process=post_process, + task_index=index_url, + task_type="seen_task", + **kwargs + ) + + return seen_res, unseen_res - is_unseen_task = False - if self.unseen_task_detect: - - try: - if callable(self.unseen_task_detect): - unseen_task_detect_algorithm = self.unseen_task_detect() - else: - unseen_task_detect_algorithm = ClassFactory.get_cls( - ClassType.UTD, self.unseen_task_detect - )() - except ValueError as err: - self.log.error("Lifelong learning " - "Inference [UTD] : {}".format(err)) + def _task_evaluation(self, data, **kwargs): + res, tasks_detail = self.cloud_knowledge_management.estimator.evaluate( + data=data, **kwargs) + drop_task = self.cloud_knowledge_management.evaluate_tasks( + tasks_detail, **kwargs) + + index_file = self.kb_server.update_task_status(drop_task, new_status=0) + + if not index_file: + self.log.error(f"KB update Fail !") + index_file = str( + self.cloud_knowledge_management.local_task_index_url) + else: + self.log.info(f"Deploy {index_file} to the edge.") + + return res, index_file + + def _fetch_historical_data(self, task_index): + if isinstance(task_index, str): + task_index = FileOps.load(task_index) + + samples = BaseDataSource(data_type="train") + + for task_group in task_index["seen_task"]["task_groups"]: + if isinstance(task_group.samples, BaseDataSource): + _samples = task_group.samples else: - is_unseen_task = unseen_task_detect_algorithm( - tasks=tasks, result=res, **self.unseen_task_detect_param - ) - return res, is_unseen_task, tasks + _samples = FileOps.load(task_group.samples.data_url) + + samples.x = _samples.x if samples.x is None else np.concatenate( + (samples.x, _samples.x), axis=0) + samples.y = _samples.y if samples.y is None else np.concatenate( + (samples.y, _samples.y), axis=0) + + return samples diff --git a/lib/sedna/datasources/__init__.py b/lib/sedna/datasources/__init__.py index 8190b76f3..cc0d2a36c 100644 --- a/lib/sedna/datasources/__init__.py +++ 
b/lib/sedna/datasources/__init__.py @@ -106,7 +106,6 @@ def parse(self, *args, **kwargs): self.x = np.array(x_data) self.y = np.array(y_data) - class CSVDataParse(BaseDataSource, ABC): """ csv file which contain Structured Data parser @@ -150,3 +149,42 @@ def parse(self, *args, **kwargs): return self.x = pd.concat(x_data) self.y = pd.concat(y_data) + +class IndexDataParse(BaseDataSource, ABC): + """ + txt file which contain image list parser + """ + + def __init__(self, data_type, func=None): + super(IndexDataParse, self).__init__(data_type=data_type, func=func) + + def parse(self, *args, **kwargs): + x_data = [] + y_data = [] + use_raw = kwargs.get("use_raw") + for f in args: + if not (f and FileOps.exists(f)): + continue + with open(f) as fin: + if self.process_func: + res = [] + for line in fin.readlines(): + lines = line.strip().split() + lines = [self.process_func(data) for data in lines] + res.append(lines) + else: + res = [line.strip().split() for line in fin.readlines()] + for tup in res: + if not len(tup): + continue + if use_raw: + x_data.append(tup) + else: + x_data.append(tup[:-1]) + if not self.is_test_data: + if len(tup) > 1: + y_data.append(tup[1]) + else: + y_data.append(0) + self.x = np.array(x_data) + self.y = np.array(y_data) \ No newline at end of file diff --git a/lib/sedna/service/server/knowledgeBase/server.py b/lib/sedna/service/server/knowledgeBase/server.py index a489e0588..850ea0c2a 100644 --- a/lib/sedna/service/server/knowledgeBase/server.py +++ b/lib/sedna/service/server/knowledgeBase/server.py @@ -54,6 +54,10 @@ def __init__(self, host: str, http_port: int = 8080, self.save_dir = FileOps.clean_folder([save_dir], clean=False)[0] self.url = f"{self.url}/{servername}" self.kb_index = KBResourceConstant.KB_INDEX_NAME.value + self.seen_task_key = KBResourceConstant.SEEN_TASK.value + self.unseen_task_key = KBResourceConstant.UNSEEN_TASK.value + self.task_group_key = KBResourceConstant.TASK_GROUPS.value + self.extractor_key = KBResourceConstant.EXTRACTOR.value self.app = FastAPI( routes=[ APIRoute( @@ -120,7 +124,7 @@ async def file_upload(self, file: UploadFile = File(...)): return f"/file/download?files={filename}&name={filename}" def update_status(self, data: KBUpdateResult = Body(...)): - deploy = True if data.status else False + deploy = bool(data.status) tasks = data.tasks.split(",") if data.tasks else [] with Session(bind=engine) as session: session.query(TaskGrp).filter( @@ -131,18 +135,24 @@ def update_status(self, data: KBUpdateResult = Body(...)): # todo: get from kb _index_path = FileOps.join_path(self.save_dir, self.kb_index) - task_info = joblib.load(_index_path) + try: + task_info = joblib.load(_index_path) + except Exception as err: + print(f"{err} And return None.") + return None + new_task_group = [] - default_task = task_info["task_groups"][0] + # TODO: to fit seen tasks and unseen tasks + default_task = task_info[self.seen_task_key][self.task_group_key][0] # todo: get from transfer learning - for task_group in task_info["task_groups"]: + for task_group in task_info[self.seen_task_key][self.task_group_key]: if not ((task_group.entry in tasks) == deploy): new_task_group.append(default_task) continue new_task_group.append(task_group) - task_info["task_groups"] = new_task_group - _index_path = FileOps.join_path(self.save_dir, self.kb_index) + task_info[self.seen_task_key][self.task_group_key] = new_task_group + FileOps.dump(task_info, _index_path) return f"/file/download?files={self.kb_index}&name={self.kb_index}" @@ -153,9 +163,14 @@ def 
update(self, task: UploadFile = File(...)): fout.write(tasks) os.close(fd) upload_info = joblib.load(name) + # TODO: to adapt unseen tasks + task_groups = upload_info[self.seen_task_key][self.task_group_key] + task_groups.extend(upload_info[self.unseen_task_key][self.task_group_key]) with Session(bind=engine) as session: - for task_group in upload_info["task_groups"]: + # TODO: to adapt unseen tasks + # for task_group in upload_info["task_groups"]: + for task_group in task_groups: grp, g_create = get_or_create( session=session, model=TaskGrp, name=task_group.entry) if g_create: