Commit 68779be: clean group weight
albert100121 committed Nov 19, 2024 (1 parent: da77632)

Showing 4 changed files with 7 additions and 94 deletions.
configs/unifuse_st3d.cfg (5 changes: 2 additions & 3 deletions)
@@ -1,5 +1,6 @@
 [EXP_SETTING]
-id=unifuse_st3d_pilot
+# id=unifuse_st3d_pilot
+id=unifuse_st3d_pilot_group
 
 # #### Dataset & Augmentation setting ####
 root=/project/albert/datasets/planar/mp3d
Expand All @@ -15,8 +16,6 @@ unlabel_train_txt=/project/albert/datasets/planar/st3d/st3d_all.txt
zero_shot_root=/project/albert/datasets/planar/stanford2D3D
zero_shot_txt=/project/albert/datasets/planar/sf3d_txt/test.txt

# relative or metric
# relative=CUBE
median_align=True
dataset=mp3d
epochs=200
Expand Down
train.py (12 changes: 5 additions & 7 deletions)
@@ -320,15 +320,14 @@ def val(
         epoch: int,
         writer: SummaryWriter=None,
         evaluator: Affine_Inv_Evaluator=None,
-        mode='Valid',
-        save_log_flag: bool=True):
+        mode='Valid'):
     """Eval function."""
     model.eval()
     with torch.no_grad():
         pbar = tqdm(dataloader)
         pbar.set_description("{} Epoch_{}".format(mode, epoch))
         total_loss = defaultdict(float)
-        for batch_idx, inputs in enumerate(pbar):
+        for _, inputs in enumerate(pbar):
             for key, ipt in inputs.items():
                 if key not in ["rgb", "rgb_name"]:
                     inputs[key] = ipt.to(args.device)
@@ -352,11 +351,10 @@ def val(
                 total_loss[f'total_{k}'] += v.data.cpu().numpy() / len(dataloader)
 
         # Evaluator to shell and tensorboard
-        for i, key in enumerate(evaluator.metrics.keys()):
+        for _, key in enumerate(evaluator.metrics.keys()):
             total_loss[f'total_{key}'] = np.array(evaluator.metrics[key].avg.cpu())
         evaluator.print()
-        if save_log_flag:
-            save_log(writer, inputs, outputs, total_loss, args)
+        save_log(writer, inputs, outputs, total_loss, args)
 
 
 def main_joint_unlabel():
@@ -410,7 +408,7 @@ def main_joint_unlabel():
         val(args, model, test_loader, epoch, writer['test'], evaluator, mode='Test')
         if zeroshot_loader is not None:
             zeroshot_evaluator.reset_eval_metrics()
-            val(args, model, zeroshot_loader, epoch, writer['zeroshot'], zeroshot_evaluator, mode='zeroshot')
+            val(args, model, zeroshot_loader, epoch, writer['zeroshot'], zeroshot_evaluator, mode='Zeroshot')
 
         if (epoch + 1) % args.save_every == 0:
             save_model(model, optimizer, args)
utils/get_args.py (5 changes: 0 additions & 5 deletions)
@@ -1,9 +1,5 @@
 import os
-import sys
-import time
 import argparse
 import configparser
-from threading import Thread
-
 
 def force_config_value_type(val):
@@ -102,7 +98,6 @@ def parse_args():
     """
     Pre-processing setting
     """
-    # parser.add_argument('--base_height', default=512, type=int)
     parser.add_argument('--h', default=512, type=int, help='loader process height')
     parser.add_argument('--w', default=1024, type=int, help='loader process width')
     parser.add_argument('--rgb_mean', default=[0.485, 0.456, 0.406], nargs=3, type=float)
utils/utils.py (79 changes: 0 additions & 79 deletions)
@@ -1,8 +1,3 @@
 import os
-import pickle
-
-import torch
-import torch.nn as nn
 import numpy as np
-
 
Expand All @@ -16,13 +11,11 @@ def read_list(list_file):


def read_list_with_ndarray(list_file):
# rgb_depth_list = np.empty((0, 2))
rgb_list = np.empty(0)
depth_list = np.empty(0)
with open(list_file) as f:
lines = f.readlines()
for line in lines:
# rgb_depth_list = np.append(rgb_depth_list, [line.strip().split(" ")], 0)
rgb, depth = line.strip().split(" ")
rgb_list = np.append(rgb_list, rgb)
depth_list = np.append(depth_list, depth)
@@ -36,85 +29,13 @@ def read_list_with_ndarray(list_file):
 def group_weight(module):
     group_decay = []
     group_no_decay = []
-    # for m in module.modules():
-    #     if isinstance(m, nn.Linear):
-    #         group_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.modules.conv._ConvNd):
-    #         group_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.modules.batchnorm._BatchNorm):
-    #         if m.weight is not None:
-    #             group_no_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.GroupNorm):
-    #         if m.weight is not None:
-    #             group_no_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     elif isinstance(m, nn.LayerNorm):
-    #         if m.weight is not None:
-    #             group_no_decay.append(m.weight)
-    #         if m.bias is not None:
-    #             group_no_decay.append(m.bias)
-    #     else:
-    #         import pdb
-    #         # if m.bias is not None:
-    #         #     group_no_decay.append(m.bias)
 
     for name, param in module.named_parameters():
         if name.endswith('weight'):
             group_decay.append(param)
         elif name.endswith('bias'):
             group_no_decay.append(param)
-    # assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
     return [
         dict(params=[param for param in group_decay if param.requires_grad]),
         dict(params=[param for param in group_no_decay if param.requires_grad], weight_decay=.0)
     ]
-
-
-# Some recommendations for memory leak (not solution for this time)
-# https://github.com/Lightning-AI/pytorch-lightning/issues/17257
-# https://github.com/ppwwyyxx/RAM-multiprocess-dataloader/blob/79897b26a2c4185a3ed086f18be5ea300913d5b7/serialize.py#L40-L50
-class NumpySerializedList():
-    def __init__(self, lst: list):
-        def _serialize(data):
-            buffer = pickle.dumps(data, protocol=-1)
-            return np.frombuffer(buffer, dtype=np.uint8)
-
-        print(
-            "Serializing {} elements to byte tensors and concatenating them all ...".format(
-                len(lst)
-            )
-        )
-        self._lst = [_serialize(x) for x in lst]
-        self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
-        self._addr = np.cumsum(self._addr)
-        self._lst = np.concatenate(self._lst)
-        print("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
-
-    def __len__(self):
-        return len(self._addr)
-
-    def __getitem__(self, idx):
-        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
-        end_addr = self._addr[idx].item()
-        bytes = memoryview(self._lst[start_addr:end_addr])
-        return pickle.loads(bytes)
-
-
-class TorchSerializedList(NumpySerializedList):
-    def __init__(self, lst: list):
-        super().__init__(lst)
-        self._addr = torch.from_numpy(self._addr)
-        self._lst = torch.from_numpy(self._lst)
-
-    def __getitem__(self, idx):
-        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
-        end_addr = self._addr[idx].item()
-        bytes = memoryview(self._lst[start_addr:end_addr].numpy())
-        return pickle.loads(bytes)
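Note on group_weight (kept above, minus its dead code): it splits a model's
parameters into a decay group (names ending in 'weight') and a no-decay group
(names ending in 'bias'), so an optimizer can exempt biases from weight decay.
A minimal usage sketch follows; the optimizer choice and hyperparameters are
illustrative, not taken from this repo:

    import torch
    import torch.nn as nn
    from utils.utils import group_weight

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
    # The first returned group inherits the optimizer-level weight_decay;
    # the second group overrides it with weight_decay=0.0 for biases.
    optimizer = torch.optim.Adam(group_weight(model), lr=1e-4, weight_decay=1e-5)

One behavioral difference worth noting: the deleted module-based version routed
normalization-layer weights into the no-decay group, while the surviving
name-based check (name.endswith('weight')) leaves them in the decay group.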

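For context on the removed NumpySerializedList/TorchSerializedList helpers:
they implemented the pattern from the RAM-multiprocess-dataloader links above,
pickling every record into one flat byte buffer plus a cumulative-offset
array, so DataLoader worker processes stop inflating memory through
copy-on-write on Python object refcounts. A condensed sketch of the same idea
(not code from this repo):

    import pickle
    import numpy as np

    class SerializedList:
        def __init__(self, lst):
            # Pickle every record into a uint8 array, then concatenate.
            dumped = [np.frombuffer(pickle.dumps(x, protocol=-1), dtype=np.uint8)
                      for x in lst]
            self._addr = np.cumsum([len(d) for d in dumped], dtype=np.int64)
            self._buf = np.concatenate(dumped)  # one flat byte buffer

        def __len__(self):
            return len(self._addr)

        def __getitem__(self, idx):
            # Slice the record's byte range and unpickle on demand.
            start = 0 if idx == 0 else self._addr[idx - 1].item()
            end = self._addr[idx].item()
            return pickle.loads(memoryview(self._buf[start:end]))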