From 68779be9d847cd367027994fcfa17d393ddad81d Mon Sep 17 00:00:00 2001
From: Albert Wang
Date: Wed, 20 Nov 2024 01:03:51 +0800
Subject: [PATCH] clean group weight

---
 configs/unifuse_st3d.cfg |  5 +--
 train.py                 | 12 +++---
 utils/get_args.py        |  5 ---
 utils/utils.py           | 79 ----------------------------------------
 4 files changed, 7 insertions(+), 94 deletions(-)

diff --git a/configs/unifuse_st3d.cfg b/configs/unifuse_st3d.cfg
index 966e82d..33bebeb 100644
--- a/configs/unifuse_st3d.cfg
+++ b/configs/unifuse_st3d.cfg
@@ -1,5 +1,6 @@
 [EXP_SETTING]
-id=unifuse_st3d_pilot
+# id=unifuse_st3d_pilot
+id=unifuse_st3d_pilot_group
 
 # #### Dataset & Augmentation setting ####
 root=/project/albert/datasets/planar/mp3d
@@ -15,8 +16,6 @@ unlabel_train_txt=/project/albert/datasets/planar/st3d/st3d_all.txt
 zero_shot_root=/project/albert/datasets/planar/stanford2D3D
 zero_shot_txt=/project/albert/datasets/planar/sf3d_txt/test.txt
 
-# relative or metric
-# relative=CUBE
 median_align=True
 dataset=mp3d
 epochs=200
diff --git a/train.py b/train.py
index 716ab4d..b6c3e25 100644
--- a/train.py
+++ b/train.py
@@ -320,15 +320,14 @@ def val(
         epoch: int,
         writer: SummaryWriter=None,
         evaluator: Affine_Inv_Evaluator=None,
-        mode='Valid',
-        save_log_flag: bool=True):
+        mode='Valid'):
     """Eval function."""
     model.eval()
     with torch.no_grad():
         pbar = tqdm(dataloader)
         pbar.set_description("{} Epoch_{}".format(mode, epoch))
         total_loss = defaultdict(float)
-        for batch_idx, inputs in enumerate(pbar):
+        for _, inputs in enumerate(pbar):
             for key, ipt in inputs.items():
                 if key not in ["rgb", "rgb_name"]:
                     inputs[key] = ipt.to(args.device)
@@ -352,11 +351,10 @@ def val(
             total_loss[f'total_{k}'] += v.data.cpu().numpy() / len(dataloader)
 
         # Evaluator to shell and tensorboard
-        for i, key in enumerate(evaluator.metrics.keys()):
+        for _, key in enumerate(evaluator.metrics.keys()):
             total_loss[f'total_{key}'] = np.array(evaluator.metrics[key].avg.cpu())
         evaluator.print()
-        if save_log_flag:
-            save_log(writer, inputs, outputs, total_loss, args)
+        save_log(writer, inputs, outputs, total_loss, args)
 
 
 def main_joint_unlabel():
@@ -410,7 +408,7 @@ def main_joint_unlabel():
                 val(args, model, test_loader, epoch, writer['test'], evaluator, mode='Test')
             if zeroshot_loader is not None:
                 zeroshot_evaluator.reset_eval_metrics()
-                val(args, model, zeroshot_loader, epoch, writer['zeroshot'], zeroshot_evaluator, mode='zeroshot')
+                val(args, model, zeroshot_loader, epoch, writer['zeroshot'], zeroshot_evaluator, mode='Zeroshot')
 
         if (epoch + 1) % args.save_every == 0:
             save_model(model, optimizer, args)
diff --git a/utils/get_args.py b/utils/get_args.py
index b8fa5b4..81f7529 100755
--- a/utils/get_args.py
+++ b/utils/get_args.py
@@ -1,9 +1,5 @@
-import os
-import sys
-import time
 import argparse
 import configparser
-from threading import Thread
 
 
 def force_config_value_type(val):
@@ -102,7 +98,6 @@ def parse_args():
     """
     Pre-processing setting
     """
-    # parser.add_argument('--base_height', default=512, type=int)
     parser.add_argument('--h', default=512, type=int, help='loader process height')
     parser.add_argument('--w', default=1024, type=int, help='loader process width')
     parser.add_argument('--rgb_mean', default=[0.485, 0.456, 0.406], nargs=3, type=float)
diff --git a/utils/utils.py b/utils/utils.py
index b1a08c4..03f86ed 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,8 +1,3 @@
-import os
-import pickle
-
-import torch
-import torch.nn as nn
 import numpy as np
 
 
@@ -16,13 +11,11 @@ def read_list(list_file):
 
 
 def read_list_with_ndarray(list_file):
-    # rgb_depth_list = np.empty((0, 2))
     rgb_list = np.empty(0)
     depth_list = np.empty(0)
     with open(list_file) as f:
         lines = f.readlines()
         for line in lines:
-            # rgb_depth_list = np.append(rgb_depth_list, [line.strip().split(" ")], 0)
             rgb, depth = line.strip().split(" ")
             rgb_list = np.append(rgb_list, rgb)
             depth_list = np.append(depth_list, depth)
@@ -36,85 +29,13 @@ def read_list_with_ndarray(list_file):
 def group_weight(module):
     group_decay = []
     group_no_decay = []
-    # for m in module.modules():
-    #     if isinstance(m, nn.Linear):
-    #         group_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.modules.conv._ConvNd):
-    #         group_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.modules.batchnorm._BatchNorm):
-    #         if m.weight is not None:
-    #             group_no_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.GroupNorm):
-    #         if m.weight is not None:
-    #             group_no_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.LayerNorm):
-    #         if m.weight is not None:
-    #             group_no_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     else:
-    #         import pdb
-    #         # if m.bias is not None:
-    #         #     group_no_decay.append(m.bias)
     for name, param in module.named_parameters():
         if name.endswith('weight'):
             group_decay.append(param)
         elif name.endswith('bias'):
             group_no_decay.append(param)
 
-    # assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
     return [
         dict(params=[param for param in group_decay if param.requires_grad]),
         dict(params=[param for param in group_no_decay if param.requires_grad], weight_decay=.0)
     ]
-
-
-# Some recommendations for memory leak (not solution for this time)
-# https://github.com/Lightning-AI/pytorch-lightning/issues/17257
-# https://github.com/ppwwyyxx/RAM-multiprocess-dataloader/blob/79897b26a2c4185a3ed086f18be5ea300913d5b7/serialize.py#L40-L50
-class NumpySerializedList():
-    def __init__(self, lst: list):
-        def _serialize(data):
-            buffer = pickle.dumps(data, protocol=-1)
-            return np.frombuffer(buffer, dtype=np.uint8)
-
-        print(
-            "Serializing {} elements to byte tensors and concatenating them all ...".format(
-                len(lst)
-            )
-        )
-        self._lst = [_serialize(x) for x in lst]
-        self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
-        self._addr = np.cumsum(self._addr)
-        self._lst = np.concatenate(self._lst)
-        print("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
-
-    def __len__(self):
-        return len(self._addr)
-
-    def __getitem__(self, idx):
-        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
-        end_addr = self._addr[idx].item()
-        bytes = memoryview(self._lst[start_addr:end_addr])
-        return pickle.loads(bytes)
-
-
-class TorchSerializedList(NumpySerializedList):
-    def __init__(self, lst: list):
-        super().__init__(lst)
-        self._addr = torch.from_numpy(self._addr)
-        self._lst = torch.from_numpy(self._lst)
-
-    def __getitem__(self, idx):
-        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
-        end_addr = self._addr[idx].item()
-        bytes = memoryview(self._lst[start_addr:end_addr].numpy())
-        return pickle.loads(bytes)
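
Usage sketch (illustrative, not taken from the repository): the simplified group_weight helper kept by this patch returns two optimizer parameter groups, and the snippet below shows how they would typically be handed to a PyTorch optimizer. The utils.utils import path, the AdamW choice, and the hyperparameter values are assumptions for illustration; the 'bias'-suffixed group carries the explicit weight_decay=0.0 override, while the 'weight'-suffixed group inherits the optimizer default.

    import torch
    import torch.nn as nn

    from utils.utils import group_weight  # import path assumed from this repository layout

    # Toy model: every parameter name ends in 'weight' or 'bias', so each
    # trainable parameter lands in exactly one of the two groups.
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )

    # group_weight returns [decay_group, no_decay_group]; the second dict sets
    # weight_decay=0.0, the first uses the optimizer default given below.
    param_groups = group_weight(model)
    optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=1e-2)  # illustrative values

Note that, unlike the isinstance-based grouping removed by this patch, the name-suffix rule places normalization-layer weights (e.g. BatchNorm 'weight' parameters) in the decay group rather than the no-decay group.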