From d5eecd8dd34b203bccbe7245205a4df678990f35 Mon Sep 17 00:00:00 2001 From: GengDavid Date: Wed, 26 Jul 2023 11:26:13 +0800 Subject: [PATCH] first commit --- .DS_Store | Bin 0 -> 6148 bytes .gitignore | 4 + LICENSE | 21 ++ collect_env.py | 492 +++++++++++++++++++++++++ convs/__init__.py | 0 convs/cifar_resnet.py | 198 ++++++++++ convs/linears.py | 317 ++++++++++++++++ convs/resnet.py | 364 +++++++++++++++++++ convs/vits.py | 666 ++++++++++++++++++++++++++++++++++ evaluator.py | 103 ++++++ exps/slca_cars.json | 19 + exps/slca_cars_mocov3.json | 19 + exps/slca_cifar.json | 19 + exps/slca_cifar_mocov3.json | 19 + exps/slca_cub.json | 19 + exps/slca_cub_mocov3.json | 19 + exps/slca_imgnetr.json | 19 + exps/slca_imgnetr_mocov3.json | 19 + main.py | 33 ++ models/__init__.py | 0 models/base.py | 403 ++++++++++++++++++++ models/slca.py | 258 +++++++++++++ split_car.py | 21 ++ split_cub.py | 16 + train_all.sh | 10 + trainer.py | 115 ++++++ utils/__init__.py | 0 utils/buffer.py | 225 ++++++++++++ utils/cutmix.py | 40 ++ utils/data.py | 256 +++++++++++++ utils/data_manager.py | 245 +++++++++++++ utils/factory.py | 9 + utils/inc_net.py | 587 ++++++++++++++++++++++++++++++ utils/net_linear_wapper.py | 14 + utils/toolkit.py | 58 +++ 35 files changed, 4607 insertions(+) create mode 100644 .DS_Store create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 collect_env.py create mode 100644 convs/__init__.py create mode 100644 convs/cifar_resnet.py create mode 100644 convs/linears.py create mode 100644 convs/resnet.py create mode 100644 convs/vits.py create mode 100644 evaluator.py create mode 100644 exps/slca_cars.json create mode 100644 exps/slca_cars_mocov3.json create mode 100644 exps/slca_cifar.json create mode 100644 exps/slca_cifar_mocov3.json create mode 100644 exps/slca_cub.json create mode 100644 exps/slca_cub_mocov3.json create mode 100644 exps/slca_imgnetr.json create mode 100644 exps/slca_imgnetr_mocov3.json create mode 100644 main.py create mode 100644 models/__init__.py create mode 100644 models/base.py create mode 100644 models/slca.py create mode 100644 split_car.py create mode 100644 split_cub.py create mode 100644 train_all.sh create mode 100644 trainer.py create mode 100644 utils/__init__.py create mode 100644 utils/buffer.py create mode 100644 utils/cutmix.py create mode 100644 utils/data.py create mode 100644 utils/data_manager.py create mode 100644 utils/factory.py create mode 100644 utils/inc_net.py create mode 100644 utils/net_linear_wapper.py create mode 100644 utils/toolkit.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..45703fe73012b6f63822185c23da12dcf343603e GIT binary patch literal 6148 zcmeHKJ5Izv41IH+966y0k&knjK}t5 zUJ=ItWOMAVfi-{?RS_R*=A!$mGdT-SOBAbd=nwnhejJBc?2fa8}!)W z6}!`-4Ic20mpbwSPb|=LYp6F`99iJ;r+evT(w__@1Ia)#kPQ3-2H3MzR+o-xlYwL) z8Te*EzYm3~*ac3G_UWLp5rEjxZNg`pC5S~G#4d1hq=zD&O7v8T5kowk{t|Ir;NNxC4Acyq`*N-Q|Ax9S|JO;mN(PdFKgEDF+uQAiSBl;` yyPWpgLcODY31cms##%AaS}_;eimy)Miav9{3!EG+oqkIv=8u5#l9CMk1_PhE?kR=< literal 0 HcmV?d00001 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ee3716a --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +data/ +logs/ +__pycache__/ +*.pyc \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d1e9fc2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Fu-Yun Wang. 
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/collect_env.py b/collect_env.py
new file mode 100644
index 0000000..97bbc7b
--- /dev/null
+++ b/collect_env.py
@@ -0,0 +1,492 @@
+from __future__ import print_function
+
+# Unlike the rest of PyTorch, this file must be Python 2 compliant.
+# This script outputs relevant system environment info
+# Run it with `python collect_env.py`.
+import datetime
+import locale
+import re
+import subprocess
+import sys
+import os
+from collections import namedtuple
+
+
+try:
+    import torch
+    TORCH_AVAILABLE = True
+except (ImportError, NameError, AttributeError, OSError):
+    TORCH_AVAILABLE = False
+
+# System Environment Information
+SystemEnv = namedtuple('SystemEnv', [
+    'torch_version',
+    'is_debug_build',
+    'cuda_compiled_version',
+    'gcc_version',
+    'clang_version',
+    'cmake_version',
+    'os',
+    'libc_version',
+    'python_version',
+    'python_platform',
+    'is_cuda_available',
+    'cuda_runtime_version',
+    'nvidia_driver_version',
+    'nvidia_gpu_models',
+    'cudnn_version',
+    'pip_version',  # 'pip' or 'pip3'
+    'pip_packages',
+    'conda_packages',
+    'hip_compiled_version',
+    'hip_runtime_version',
+    'miopen_runtime_version',
+    'caching_allocator_config',
+    'is_xnnpack_available',
+])
+
+
+def run(command):
+    """Returns (return-code, stdout, stderr)"""
+    p = subprocess.Popen(command, stdout=subprocess.PIPE,
+                         stderr=subprocess.PIPE, shell=True)
+    raw_output, raw_err = p.communicate()
+    rc = p.returncode
+    if get_platform() == 'win32':
+        enc = 'oem'
+    else:
+        enc = locale.getpreferredencoding()
+    output = raw_output.decode(enc)
+    err = raw_err.decode(enc)
+    return rc, output.strip(), err.strip()
+
+
+def run_and_read_all(run_lambda, command):
+    """Runs command using run_lambda; reads and returns entire output if rc is 0"""
+    rc, out, _ = run_lambda(command)
+    if rc != 0:
+        return None
+    return out
+
+
+def run_and_parse_first_match(run_lambda, command, regex):
+    """Runs command using run_lambda, returns the first regex match if it exists"""
+    rc, out, _ = run_lambda(command)
+    if rc != 0:
+        return None
+    match = re.search(regex, out)
+    if match is None:
+        return None
+    return match.group(1)
+
+def run_and_return_first_line(run_lambda, command):
+    """Runs command using run_lambda and returns first line if output is not empty"""
+    rc, out, _ = run_lambda(command)
+    if rc != 0:
+        return None
+    return out.split('\n')[0]
+
+
+def get_conda_packages(run_lambda):
+    conda = os.environ.get('CONDA_EXE', 'conda')
+    out = run_and_read_all(run_lambda, "{} list".format(conda))
+    if out
is None: + return out + + return "\n".join( + line + for line in out.splitlines() + if not line.startswith("#") + and any( + name in line + for name in { + "torch", + "numpy", + "cudatoolkit", + "soumith", + "mkl", + "magma", + "mkl", + } + ) + ) + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + +def get_clang_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): + if TORCH_AVAILABLE and torch.cuda.is_available(): + return torch.cuda.get_device_name(None) + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + + +def get_cudnn_version(run_lambda): + """This will return a list of libcudnn.so; it's hard to tell which one is being used""" + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
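+        # List the cudnn dylibs bundled with the default /usr/local/cuda install.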
+ cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + l = os.environ.get('CUDNN_LIBRARY') + if l is not None and os.path.isfile(l): + return os.path.realpath(l) + return None + files_set = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = list(sorted(files_set)) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = '"{}"'.format(candidate_smi) + break + return smi + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + + +def get_windows_version(run_lambda): + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') + findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + from platform import machine + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'macOS {} ({})'.format(version, machine()) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + return '{} ({})'.format(platform, machine()) + + # Unknown platform + return platform + + +def get_python_platform(): + import platform + return platform.platform() + + +def get_libc_version(): + import platform + if get_platform() != 'linux': + return 'N/A' + return '-'.join(platform.libc_ver()) + + +def get_pip_packages(run_lambda): + """Returns `pip list` output. 
Note: will also find conda-installed pytorch
+    and numpy packages."""
+    # People generally have `pip` as `pip` or `pip3`
+    # But here it is invoked as `python -mpip`
+    def run_with_pip(pip):
+        out = run_and_read_all(run_lambda, "{} list --format=freeze".format(pip))
+        return "\n".join(
+            line
+            for line in out.splitlines()
+            if any(
+                name in line
+                for name in {
+                    "torch",
+                    "numpy",
+                    "mypy",
+                }
+            )
+        )
+
+    pip_version = 'pip3' if sys.version[0] == '3' else 'pip'
+    out = run_with_pip(sys.executable + ' -mpip')
+
+    return pip_version, out
+
+
+def get_cachingallocator_config():
+    ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
+    return ca_config
+
+def is_xnnpack_available():
+    if TORCH_AVAILABLE:
+        import torch.backends.xnnpack
+        return str(torch.backends.xnnpack.enabled)  # type: ignore[attr-defined]
+    else:
+        return "N/A"
+
+def get_env_info():
+    run_lambda = run
+    pip_version, pip_list_output = get_pip_packages(run_lambda)
+
+    if TORCH_AVAILABLE:
+        version_str = torch.__version__
+        debug_mode_str = str(torch.version.debug)
+        cuda_available_str = str(torch.cuda.is_available())
+        cuda_version_str = torch.version.cuda
+        if not hasattr(torch.version, 'hip') or torch.version.hip is None:  # cuda version
+            hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
+        else:  # HIP version
+            cfg = torch._C._show_config().split('\n')
+            hip_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'HIP Runtime' in s][0]
+            miopen_runtime_version = [s.rsplit(None, 1)[-1] for s in cfg if 'MIOpen' in s][0]
+            cuda_version_str = 'N/A'
+            hip_compiled_version = torch.version.hip
+    else:
+        version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A'
+        hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
+
+    sys_version = sys.version.replace("\n", " ")
+
+    return SystemEnv(
+        torch_version=version_str,
+        is_debug_build=debug_mode_str,
+        python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1),
+        python_platform=get_python_platform(),
+        is_cuda_available=cuda_available_str,
+        cuda_compiled_version=cuda_version_str,
+        cuda_runtime_version=get_running_cuda_version(run_lambda),
+        nvidia_gpu_models=get_gpu_info(run_lambda),
+        nvidia_driver_version=get_nvidia_driver_version(run_lambda),
+        cudnn_version=get_cudnn_version(run_lambda),
+        hip_compiled_version=hip_compiled_version,
+        hip_runtime_version=hip_runtime_version,
+        miopen_runtime_version=miopen_runtime_version,
+        pip_version=pip_version,
+        pip_packages=pip_list_output,
+        conda_packages=get_conda_packages(run_lambda),
+        os=get_os(run_lambda),
+        libc_version=get_libc_version(),
+        gcc_version=get_gcc_version(run_lambda),
+        clang_version=get_clang_version(run_lambda),
+        cmake_version=get_cmake_version(run_lambda),
+        caching_allocator_config=get_cachingallocator_config(),
+        is_xnnpack_available=is_xnnpack_available(),
+    )
+
+env_info_fmt = """
+PyTorch version: {torch_version}
+Is debug build: {is_debug_build}
+CUDA used to build PyTorch: {cuda_compiled_version}
+ROCM used to build PyTorch: {hip_compiled_version}
+
+OS: {os}
+GCC version: {gcc_version}
+Clang version: {clang_version}
+CMake version: {cmake_version}
+Libc version: {libc_version}
+
+Python version: {python_version}
+Python platform: {python_platform}
+Is CUDA available: {is_cuda_available}
+CUDA runtime version: {cuda_runtime_version}
+GPU models and configuration: {nvidia_gpu_models}
+Nvidia driver version: {nvidia_driver_version}
+cuDNN version: {cudnn_version}
+HIP runtime version:
{hip_runtime_version} +MIOpen runtime version: {miopen_runtime_version} +Is XNNPACK available: {is_xnnpack_available} + +Versions of relevant libraries: +{pip_packages} +{conda_packages} +""".strip() + + +def pretty_str(envinfo): + def replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, tag='[prepend]'): + lines = text.split('\n') + updated_lines = [tag + line for line in lines] + return '\n'.join(updated_lines) + + def replace_if_empty(text, replacement='No relevant packages'): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + if envinfo.cuda_compiled_version is None: + mutable_dict['cuda_compiled_version'] = 'None' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], + '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], + '[conda] ') + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') + msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ + "if this is related to your bug please include it when you file a report ***" + 
print(msg, file=sys.stderr) + + + +if __name__ == '__main__': + main() diff --git a/convs/__init__.py b/convs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/convs/cifar_resnet.py b/convs/cifar_resnet.py new file mode 100644 index 0000000..8ce7d4d --- /dev/null +++ b/convs/cifar_resnet.py @@ -0,0 +1,198 @@ +''' +Reference: +https://github.com/khurramjaved96/incremental-learning/blob/autoencoders/model/resnet32.py +''' +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DownsampleA(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleA, self).__init__() + assert stride == 2 + self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) + + def forward(self, x): + x = self.avg(x) + return torch.cat((x, x.mul(0)), 1) + + +class DownsampleB(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleB, self).__init__() + self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) + self.bn = nn.BatchNorm2d(nOut) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class DownsampleC(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleC, self).__init__() + assert stride != 1 or nIn != nOut + self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) + + def forward(self, x): + x = self.conv(x) + return x + + +class DownsampleD(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleD, self).__init__() + assert stride == 2 + self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False) + self.bn = nn.BatchNorm2d(nOut) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class ResNetBasicblock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(ResNetBasicblock, self).__init__() + + self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn_a = nn.BatchNorm2d(planes) + + self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_b = nn.BatchNorm2d(planes) + + self.downsample = downsample + + def forward(self, x): + residual = x + + basicblock = self.conv_a(x) + basicblock = self.bn_a(basicblock) + basicblock = F.relu(basicblock, inplace=True) + + basicblock = self.conv_b(basicblock) + basicblock = self.bn_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(x) + + return F.relu(residual + basicblock, inplace=True) + + +class CifarResNet(nn.Module): + """ + ResNet optimized for the Cifar Dataset, as specified in + https://arxiv.org/abs/1512.03385.pdf + """ + + def __init__(self, block, depth, channels=3): + super(CifarResNet, self).__init__() + + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' + layer_blocks = (depth - 2) // 6 + + self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_1 = nn.BatchNorm2d(16) + + self.inplanes = 16 + self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) + self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) + self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) + self.avgpool = nn.AvgPool2d(8) + self.out_dim = 64 * block.expansion + self.fc = nn.Linear(64*block.expansion, 10) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + 
m.weight.data.normal_(0, math.sqrt(2. / n)) + # m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv_1_3x3(x) # [bs, 16, 32, 32] + x = F.relu(self.bn_1(x), inplace=True) + + x_1 = self.stage_1(x) # [bs, 16, 32, 32] + x_2 = self.stage_2(x_1) # [bs, 32, 16, 16] + x_3 = self.stage_3(x_2) # [bs, 64, 8, 8] + + pooled = self.avgpool(x_3) # [bs, 64, 1, 1] + features = pooled.view(pooled.size(0), -1) # [bs, 64] + + return { + 'fmaps': [x_1, x_2, x_3], + 'features': features + } + + @property + def last_conv(self): + return self.stage_3[-1].conv_b + + +def resnet20mnist(): + """Constructs a ResNet-20 model for MNIST.""" + model = CifarResNet(ResNetBasicblock, 20, 1) + return model + + +def resnet32mnist(): + """Constructs a ResNet-32 model for MNIST.""" + model = CifarResNet(ResNetBasicblock, 32, 1) + return model + + +def resnet20(): + """Constructs a ResNet-20 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 20) + return model + + +def resnet32(): + """Constructs a ResNet-32 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 32) + return model + + +def resnet44(): + """Constructs a ResNet-44 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 44) + return model + + +def resnet56(): + """Constructs a ResNet-56 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 56) + return model + + +def resnet110(): + """Constructs a ResNet-110 model for CIFAR-10.""" + model = CifarResNet(ResNetBasicblock, 110) + return model diff --git a/convs/linears.py b/convs/linears.py new file mode 100644 index 0000000..514cb7c --- /dev/null +++ b/convs/linears.py @@ -0,0 +1,317 @@ +''' +Reference: +https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py +''' +import math +import torch +from torch import nn +from torch.nn import functional as F +from timm.models.layers.weight_init import trunc_normal_ +from timm.models.layers import Mlp +from copy import deepcopy + +class MyContinualClassifier(nn.Module): + def __init__(self, embed_dim, nb_old_classes, nb_new_classes): + super().__init__() + + self.embed_dim = embed_dim + self.nb_old_classes = nb_old_classes + heads = [] + if nb_old_classes>0: + heads.append(nn.Linear(embed_dim, nb_old_classes, bias=True)) + self.old_head = nn.Linear(embed_dim, nb_old_classes, bias=True) + heads.append(nn.Linear(embed_dim, nb_new_classes, bias=True)) + self.heads = nn.ModuleList(heads) + self.aux_head = nn.Linear(embed_dim, 1, bias=True) + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, nonlinearity='linear') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, with_aux=False): + assert len(x.size())==2 + out = [] + for ti in range(len(self.heads)): + out.append(self.heads[ti](x)) + out = {'logits': torch.cat(out, dim=1)} + if len(self.heads)>1: + out['old_logits'] = self.old_head(x) + 
if with_aux:
+            out['aux_logits'] = self.aux_head(x)
+        return out
+
+class MlpHead(nn.Module):
+    def __init__(self, dim, nb_classes, mlp_ratio=3., drop=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self._fc = nn.Linear(dim, nb_classes, bias=True)
+
+    def forward(self, x):
+        x = x + self.mlp(self.norm(x))
+        x = self._fc(x)
+        return x
+
+class TaskEmbed(nn.Module):
+    def __init__(self, embed_dim):
+        super().__init__()
+
+        self.task_token = nn.Parameter(torch.zeros(1, embed_dim))
+        trunc_normal_(self.task_token, std=.02)
+        self.merge_fc = nn.Linear(2*embed_dim, embed_dim)
+        trunc_normal_(self.merge_fc.weight, std=.02)
+
+    def forward(self, x):
+        x = F.gelu(self.merge_fc(torch.cat([x, self.task_token.repeat(x.size(0), 1)], 1)))+x
+        return x
+
+
+class SimpleContinualLinear(nn.Module):
+    def __init__(self, embed_dim, nb_classes, feat_expand=False, with_norm=False, with_mlp=False, with_task_embed=False, with_preproj=False):
+        super().__init__()
+
+        self.embed_dim = embed_dim
+        self.feat_expand = feat_expand
+        self.with_norm = with_norm
+        self.with_mlp = with_mlp
+        self.with_task_embed = with_task_embed
+        self.with_preproj = with_preproj
+        heads = []
+        single_head = []
+        if with_norm:
+            single_head.append(nn.LayerNorm(embed_dim))
+        if with_task_embed:
+            single_head.append(TaskEmbed(embed_dim))
+
+        single_head.append(nn.Linear(embed_dim, nb_classes, bias=True))
+        head = nn.Sequential(*single_head)
+
+        if with_mlp:
+            head = MlpHead(embed_dim, nb_classes)
+        heads.append(head)
+        self.heads = nn.ModuleList(heads)
+        if self.with_preproj:
+            self.preproj = nn.Sequential(*[nn.Linear(embed_dim, embed_dim, bias=True), nn.GELU()])
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+        if self.with_preproj:
+            for p in self.preproj.parameters():
+                p.requires_grad=False
+
+
+    def backup(self):
+        self.old_state_dict = deepcopy(self.state_dict())
+
+    def recall(self):
+        self.load_state_dict(self.old_state_dict)
+
+
+    def update(self, nb_classes, freeze_old=True):
+        single_head = []
+        if self.with_norm:
+            single_head.append(nn.LayerNorm(self.embed_dim))
+        if self.with_task_embed:
+            single_head.append(TaskEmbed(self.embed_dim))
+
+        _fc = nn.Linear(self.embed_dim, nb_classes, bias=True)
+        trunc_normal_(_fc.weight, std=.02)
+        nn.init.constant_(_fc.bias, 0)
+        single_head.append(_fc)
+        new_head = nn.Sequential(*single_head)
+
+        if self.with_mlp:
+            new_head = MlpHead(self.embed_dim, nb_classes)  # the MLP head replaces the plain linear head, as in __init__
+            for m in new_head.modules():
+                if isinstance(m, nn.Linear):
+                    trunc_normal_(m.weight, std=.02)
+                    if m.bias is not None:
+                        nn.init.constant_(m.bias, 0)
+
+        if freeze_old:
+            for p in self.heads.parameters():
+                p.requires_grad=False
+
+        if self.with_preproj:
+            for p in self.preproj.parameters():
+                p.requires_grad=False
+
+        self.heads.append(new_head)
+
+    def forward(self, x):
+        #assert len(x.size())==2
+        if self.with_preproj:
+            x = self.preproj(x)
+        out = []
+        for ti in range(len(self.heads)):
+            fc_inp = x[ti] if self.feat_expand else x
+            out.append(self.heads[ti](fc_inp))
+        out = {'logits': torch.cat(out, dim=1)}
+        return out
+
+class SimpleLinear(nn.Module):
+    '''
+    Reference:
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
+    '''
+    def __init__(self, in_features, out_features, bias=True, init_method='kaiming'):
+
super(SimpleLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters(init_method=init_method) + + def reset_parameters(self, init_method='kaiming'): + if init_method=='kaiming': + nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') + else: + trunc_normal_(self.weight, std=.02) + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, input): + return {'logits': F.linear(input, self.weight, self.bias)} + + +class CosineLinear(nn.Module): + def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True): + super(CosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features * nb_proxy + self.nb_proxy = nb_proxy + self.to_reduce = to_reduce + self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features)) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + else: + self.register_parameter('sigma', None) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.sigma is not None: + self.sigma.data.fill_(1) + + def forward(self, input): + out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) + + if self.to_reduce: + # Reduce_proxy + out = reduce_proxies(out, self.nb_proxy) + + if self.sigma is not None: + out = self.sigma * out + + return {'logits': out} + + +class SplitCosineLinear(nn.Module): + def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True): + super(SplitCosineLinear, self).__init__() + self.in_features = in_features + self.out_features = (out_features1 + out_features2) * nb_proxy + self.nb_proxy = nb_proxy + self.fc1 = CosineLinear(in_features, out_features1, nb_proxy, False, False) + self.fc2 = CosineLinear(in_features, out_features2, nb_proxy, False, False) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + self.sigma.data.fill_(1) + else: + self.register_parameter('sigma', None) + + def forward(self, x): + out1 = self.fc1(x) + out2 = self.fc2(x) + + out = torch.cat((out1['logits'], out2['logits']), dim=1) # concatenate along the channel + + # Reduce_proxy + out = reduce_proxies(out, self.nb_proxy) + + if self.sigma is not None: + out = self.sigma * out + + return { + 'old_scores': reduce_proxies(out1['logits'], self.nb_proxy), + 'new_scores': reduce_proxies(out2['logits'], self.nb_proxy), + 'logits': out + } + + +def reduce_proxies(out, nb_proxy): + if nb_proxy == 1: + return out + bs = out.shape[0] + nb_classes = out.shape[1] / nb_proxy + assert nb_classes.is_integer(), 'Shape error' + nb_classes = int(nb_classes) + + simi_per_class = out.view(bs, nb_classes, nb_proxy) + attentions = F.softmax(simi_per_class, dim=-1) + + return (attentions * simi_per_class).sum(-1) + + +''' +class CosineLinear(nn.Module): + def __init__(self, in_features, out_features, sigma=True): + super(CosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + else: + self.register_parameter('sigma', None) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. 
/ math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.sigma is not None: + self.sigma.data.fill_(1) + + def forward(self, input): + out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) + if self.sigma is not None: + out = self.sigma * out + return {'logits': out} + + +class SplitCosineLinear(nn.Module): + def __init__(self, in_features, out_features1, out_features2, sigma=True): + super(SplitCosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features1 + out_features2 + self.fc1 = CosineLinear(in_features, out_features1, False) + self.fc2 = CosineLinear(in_features, out_features2, False) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + self.sigma.data.fill_(1) + else: + self.register_parameter('sigma', None) + + def forward(self, x): + out1 = self.fc1(x) + out2 = self.fc2(x) + + out = torch.cat((out1['logits'], out2['logits']), dim=1) # concatenate along the channel + if self.sigma is not None: + out = self.sigma * out + + return { + 'old_scores': out1['logits'], + 'new_scores': out2['logits'], + 'logits': out + } +''' diff --git a/convs/resnet.py b/convs/resnet.py new file mode 100644 index 0000000..23af80a --- /dev/null +++ b/convs/resnet.py @@ -0,0 +1,364 @@ +''' +Reference: +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +''' +import torch +import torch.nn as nn +# from torchvision.models.utils import load_state_dict_from_url + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None, no_last_relu=False): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = 
nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.no_last_relu = no_last_relu
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        if not self.no_last_relu:
+            out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+    __constants__ = ['downsample']
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None, no_last_relu=False):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.no_last_relu = no_last_relu
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        if not self.no_last_relu:
+            out = self.relu(out)
+
+        return out
+
+
+
+
+# Modified ResNet implementation.
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, cifar=False, no_last_relu=False):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+        self.cifar = cifar
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        if self.cifar:
+            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        else:
+            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Removed in _forward_impl for cifar
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2], no_last_relu=no_last_relu)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.out_dim = 512 * block.expansion
+        # self.fc = nn.Linear(512 * block.expansion, num_classes)
# Removed in _forward_impl + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False, no_last_relu=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for bid in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer, no_last_relu=no_last_relu if bid==blocks-1 else False)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) # [bs, 64, 32, 32] + x = self.bn1(x) + x = self.relu(x) + if not self.cifar: + x = self.maxpool(x) + + x_1 = self.layer1(x) # [bs, 128, 32, 32] + x_2 = self.layer2(x_1) # [bs, 256, 16, 16] + x_3 = self.layer3(x_2) # [bs, 512, 8, 8] + x_4 = self.layer4(x_3) # [bs, 512, 4, 4] + + pooled = self.avgpool(x_4) # [bs, 512, 1, 1] + features = torch.flatten(pooled, 1) # [bs, 512] + # x = self.fc(x) + + return { + 'fmaps': [x_1, x_2, x_3, x_4], + 'features': features + } + + def forward(self, x): + return self._forward_impl(x) + + @property + def last_conv(self): + if hasattr(self.layer4[-1], 'conv3'): + return self.layer4[-1].conv3 + else: + return self.layer4[-1].conv2 + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + 
pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) diff --git a/convs/vits.py b/convs/vits.py new file mode 100644 index 0000000..56ffbcc --- /dev/null +++ b/convs/vits.py @@ -0,0 +1,666 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020, Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv +from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from timm.models.registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'vit_tiny_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_tiny_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_base_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + + 'vit_huge_patch14_224': _cfg(url=''), + 'vit_giant_patch14_224': _cfg(url=''), + 'vit_gigantic_patch14_224': _cfg(url=''), + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + #url='./B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_large_patch32_224_in21k': 
_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), + 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub='timm/vit_huge_patch14_224_in21k', + num_classes=21843), + + # SAM trained models (https://arxiv.org/abs/2106.01548) + 'vit_base_patch32_sam_224': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), + 'vit_base_patch16_sam_224': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), + + # deit models (FB weights) + 'deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), + 'deit_tiny_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_small_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, + classifier=('head', 'head_dist')), + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + 'vit_base_patch16_224_miil': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' + '/vit_base_patch16_224_1k_miil_84_4.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), +} + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, 
self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, + act_layer=None, weight_init='', with_adapter=False, global_pool=False): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.out_dim = embed_dim + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.with_adapter = with_adapter + self.global_pool = global_pool + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, 
embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + if self.with_adapter: + self.adp_layers = [] + for adp_i in range(4): + self.adp_layers.append(self.get_adapter(embed_dim)) + self.adp_layers = nn.ModuleList(self.adp_layers) + self.adp_norm = nn.LayerNorm(embed_dim) + self.extra_blocks = nn.ModuleList([]) + self.init_weights(weight_init) + if self.with_adapter: + for adp_i in range(4): + nn.init.constant_(self.adp_layers[adp_i][-2].bias, -2.19) + + def get_adapter(self, embed_dim): + return nn.Sequential( + nn.Linear(embed_dim, embed_dim*3, bias=False), + nn.LayerNorm(embed_dim*3), + nn.GELU(), + nn.Linear(embed_dim*3, embed_dim, bias=False), + nn.LayerNorm(embed_dim), + nn.GELU(), + nn.Linear(embed_dim, embed_dim, bias=True), + nn.Sigmoid() + ) + + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
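+        # A worked note on the 'nlhb' option above: for head modules, _init_vit_weights zeroes the +        # weight and sets the bias to head_bias, so every initial logit equals -log(num_classes), +        # the softmax is uniform at 1/num_classes, and the initial cross-entropy equals +        # log(num_classes) (about 9.99 for a 21843-class ImageNet-21k head).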
+ trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) + if mode.startswith('jax'): + # leave cls token as zeros to match jax impl + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) + else: + trunc_normal_(self.cls_token, std=.02) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, prompt=None, layer_feat=False): + img = x + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + prompt_length=0 + if self.dist_token is None and prompt is None: + x = torch.cat((cls_token, x), dim=1) + elif prompt is not None: + x = torch.cat((prompt, cls_token, x), dim=1) + prompt_length = prompt.size(1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x[:, prompt_length:] = self.pos_drop(x[:, prompt_length:] + self.pos_embed) + # x = self.blocks(x) + feats = [] + feats_l = [] + for b_id, block in enumerate(self.blocks): + x = block(x) + if self.with_adapter and (b_id+1) % (len(self.blocks)//4)==0: + feats.append(x) + if layer_feat: + feats_l.append(x) + if b_id == len(self.blocks)-2: + penultimate_feat = x.clone() + + if layer_feat: + return feats_l + + if len(self.extra_blocks)>0: + assert not self.with_adapter + outs = [self.norm(x)[:, 0]] + for extra_block in self.extra_blocks: + outs.append(extra_block(penultimate_feat)[:, 0]) + return outs + + if self.with_adapter and self.training: + adp_inp = feats[-1][:, 0].detach() + masks = [] + for adp_i, adp_layer in enumerate(self.adp_layers): + m_ = adp_layer(adp_inp) + #if adp_i==0: + # m_ = m_.mean(1) + # m_ = torch.sigmoid(m_) + adp_inp = m_ * feats[adp_i][:, 0] + feats[adp_i][:, 0].detach() + masks.append(m_) + return adp_inp, torch.cat(masks, dim=1) + #return self.adp_norm(adp_inp.unsqueeze(1)).squeeze(1) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + return self.norm(x) + + x = self.norm(x) + if self.dist_token is None: + if prompt is not None: + return x[:, :prompt_length].mean(dim=1) + return self.pre_logits(x[:, 0]) + else: + return x[:, 0] # , x[:, 1] + + def forward(self, x, prompt=None, layer_feat=False): + x = self.forward_features(x, prompt, layer_feat) + if self.with_adapter and self.training: + x = {'masks': x[1], 'features': x[0]} + else: + x = {'features': x} + #if self.head_dist is not None: + # x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + # if self.training and not torch.jit.is_scripting(): + # # during inference, return the average of both classifier predictions + # return x, x_dist + # else: + # return (x + x_dist) / 2 + #else: + # x = self.head(x) + return x + + +def 
_init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from 
pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + default_cfg=default_cfg, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in default_cfg['url'], + **kwargs) + return model + + + +@register_model +def vit_base_patch16_224_in21k(pretrained=False, adapter=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, with_adapter=adapter, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + del model.head + del model.norm + model.norm = nn.LayerNorm(768) + return model + +@register_model +def vit_base_patch16_224_mocov3(pretrained=False, adapter=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, with_adapter=adapter, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=False, **model_kwargs) + del model.head + ckpt = torch.load('mocov3-vit-base-300ep.pth', map_location='cpu')['model'] + state_dict = model.state_dict() + state_dict.update(ckpt) + model.load_state_dict(state_dict) + del model.norm + model.norm = nn.LayerNorm(768) + return model + + diff --git a/evaluator.py b/evaluator.py new file mode 100644 index 0000000..8aaea98 --- /dev/null +++ b/evaluator.py @@ -0,0 +1,103 @@ +import sys +import logging +import copy +import torch +from utils import factory +from utils.data_manager import DataManager +from utils.toolkit import count_parameters +import os +import numpy as np + + +def test(args): + seed_list = copy.deepcopy(args['seed']) + device = copy.deepcopy(args['device']) + + for seed in seed_list: + args['seed'] = seed + args['device'] = device + _test(args) + + +def _test(args): + logfilename = 'logs/{}/{}_test_{}_{}_{}_{}_{}_{}'.format(args['model_name'], args['prefix'], args['seed'], args['model_name'], args['convnet_type'], + args['dataset'], args['init_cls'], args['increment']) + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(filename)s] => %(message)s', + handlers=[ + logging.FileHandler(filename=logfilename + '.log'), + logging.StreamHandler(sys.stdout) + ] + ) + + _set_random() + _set_device(args) + print_args(args) + data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment']) + model = factory.get_model(args['model_name'], args) + + cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []} + for task in range(data_manager.nb_tasks): + logging.info('All params: {}'.format(count_parameters(model._network))) + # logging.info('Trainable params: {}'.format(count_parameters(model._network, True))) + # model.incremental_train(data_manager) + model.incremental_update(data_manager) + cnn_accy, nme_accy = model.eval_task() + model.after_task() + + if nme_accy is not None: + logging.info('CNN: {}'.format(cnn_accy['grouped'])) + logging.info('NME: {}'.format(nme_accy['grouped'])) + + cnn_curve['top1'].append(cnn_accy['top1']) + cnn_curve['top5'].append(cnn_accy['top5']) + + nme_curve['top1'].append(nme_accy['top1']) + nme_curve['top5'].append(nme_accy['top5']) + + logging.info('CNN top1 curve: {}'.format(cnn_curve['top1'])) + logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean())) + if 'task_acc' in cnn_accy.keys(): + logging.info('Task: {}'.format(cnn_accy['task_acc'])) + logging.info('CNN top5 curve: {}'.format(cnn_curve['top5'])) + logging.info('NME top1 curve: {}'.format(nme_curve['top1'])) + 
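# Note: NME accuracy (when enabled) classifies by distance to the L2-normalized class-mean +            # features (see _eval_nme in models/base.py), while CNN accuracy uses the linear head's logits.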
logging.info('NME top5 curve: {}\n'.format(nme_curve['top5'])) + else: + logging.info('No NME accuracy.') + logging.info('CNN: {}'.format(cnn_accy['grouped'])) + + cnn_curve['top1'].append(cnn_accy['top1']) + cnn_curve['top5'].append(cnn_accy['top5']) + + logging.info('CNN top1 curve: {}'.format(cnn_curve['top1'])) + logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5'])) + + +def _set_device(args): + device_type = args['device'] + gpus = [] + + for device in device_type: + if device == -1: + device = torch.device('cpu') + else: + device = torch.device('cuda:{}'.format(device)) + + gpus.append(device) + + args['device'] = gpus + + +def _set_random(): + torch.manual_seed(1) + torch.cuda.manual_seed(1) + torch.cuda.manual_seed_all(1) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def print_args(args): + for key, value in args.items(): + logging.info('{}: {}'.format(key, value)) + diff --git a/exps/slca_cars.json b/exps/slca_cars.json new file mode 100644 index 0000000..995730e --- /dev/null +++ b/exps/slca_cars.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "cars196_224", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "slca_cars", + "model_postfix": "50e", + "convnet_type": "vit-b-p16", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 50, + "ca_epochs": 5, + "ca_with_logit_norm": 0.05, + "milestones": [40] +} diff --git a/exps/slca_cars_mocov3.json b/exps/slca_cars_mocov3.json new file mode 100644 index 0000000..ba2b828 --- /dev/null +++ b/exps/slca_cars_mocov3.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "cars196_224", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "slca_cars_mocov3", + "model_postfix": "90e", + "convnet_type": "vit-b-p16-mocov3", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 90, + "ca_epochs": 5, + "ca_with_logit_norm": 0.05, + "milestones": [80] +} diff --git a/exps/slca_cifar.json b/exps/slca_cifar.json new file mode 100644 index 0000000..fcc6db9 --- /dev/null +++ b/exps/slca_cifar.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "cifar100_224", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 10, + "increment": 10, + "model_name": "slca_cifar", + "model_postfix": "20e", + "convnet_type": "vit-b-p16", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 20, + "ca_epochs": 5, + "ca_with_logit_norm": 0.1, + "milestones": [18] +} diff --git a/exps/slca_cifar_mocov3.json b/exps/slca_cifar_mocov3.json new file mode 100644 index 0000000..3c2647b --- /dev/null +++ b/exps/slca_cifar_mocov3.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "cifar100_224", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 10, + "increment": 10, + "model_name": "slca_cifar_mocov3", + "model_postfix": "90e", + "convnet_type": "vit-b-p16-mocov3", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 90, + "ca_epochs": 5, + "ca_with_logit_norm": 0.1, + "milestones": [80] +} diff --git a/exps/slca_cub.json b/exps/slca_cub.json new file mode 100644 index 0000000..b2b986c --- /dev/null +++ b/exps/slca_cub.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "cub200_224", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory":
false, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "slca_cub", + "model_postfix": "50e", + "convnet_type": "vit-b-p16", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 50, + "ca_epochs": 5, + "ca_with_logit_norm": 0.1, + "milestones": [40] +} diff --git a/exps/slca_cub_mocov3.json b/exps/slca_cub_mocov3.json new file mode 100644 index 0000000..9a38e74 --- /dev/null +++ b/exps/slca_cub_mocov3.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "cub200_224", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "slca_cub_mocov3", + "model_postfix": "90e", + "convnet_type": "vit-b-p16-mocov3", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 90, + "ca_epochs": 5, + "ca_with_logit_norm": 0.1, + "milestones": [80] +} diff --git a/exps/slca_imgnetr.json b/exps/slca_imgnetr.json new file mode 100644 index 0000000..e07c0f2 --- /dev/null +++ b/exps/slca_imgnetr.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "imagenet-r", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "slca_imgnetr", + "model_postfix": "50e", + "convnet_type": "vit-b-p16", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 50, + "ca_epochs": 5, + "ca_with_logit_norm": 0.1, + "milestones": [40] +} diff --git a/exps/slca_imgnetr_mocov3.json b/exps/slca_imgnetr_mocov3.json new file mode 100644 index 0000000..7a22cf3 --- /dev/null +++ b/exps/slca_imgnetr_mocov3.json @@ -0,0 +1,19 @@ +{ + "prefix": "reproduce", + "dataset": "imagenet-r", + "memory_size": 0, + "memory_per_class": 0, + "fixed_memory": false, + "shuffle": true, + "init_cls": 20, + "increment": 20, + "model_name": "slca_imgnetr_mocov3", + "model_postfix": "90e", + "convnet_type": "vit-b-p16-mocov3", + "device": ["0","1"], + "seed": [1993, 1996, 1997], + "epochs": 90, + "ca_epochs": 5, + "ca_with_logit_norm": 0.1, + "milestones": [80] +} diff --git a/main.py b/main.py new file mode 100644 index 0000000..e81c032 --- /dev/null +++ b/main.py @@ -0,0 +1,33 @@ +import json +import argparse +from trainer import train +from evaluator import test + +def main(): + args = setup_parser().parse_args() + param = load_json(args.config) + args = vars(args) # Converting argparse Namespace to a dict. 
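+    # Note: the json config takes precedence in the update below; any key it shares with the +    # parsed CLI namespace (e.g. a hypothetical "test_only" entry) would silently override the CLI value.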
+ args.update(param) # Add parameters from json + if args['test_only']: + test(args) + else: + train(args) + + +def load_json(settings_path): + with open(settings_path) as data_file: + param = json.load(data_file) + + return param + + +def setup_parser(): + parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.') + parser.add_argument('--config', type=str, default='./exps/finetune.json', + help='Json file of settings.') + parser.add_argument('--test_only', action='store_true') + return parser + + +if __name__ == '__main__': + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000..194f605 --- /dev/null +++ b/models/base.py @@ -0,0 +1,403 @@ +import copy +import logging +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader +from utils.toolkit import tensor2numpy, accuracy +from scipy.spatial.distance import cdist + +EPSILON = 1e-8 +batch_size = 64 + + +class BaseLearner(object): + def __init__(self, args): + self._cur_task = -1 + self._known_classes = 0 + self._total_classes = 0 + self._network = None + self._old_network = None + self._data_memory, self._targets_memory = np.array([]), np.array([]) + self.topk = 5 + + self._memory_size = args['memory_size'] + self._memory_per_class = args['memory_per_class'] + self._fixed_memory = args['fixed_memory'] + self._device = args['device'][0] + self._multiple_gpus = args['device'] + + @property + def exemplar_size(self): + assert len(self._data_memory) == len(self._targets_memory), 'Exemplar size error.' + return len(self._targets_memory) + + @property + def samples_per_class(self): + if self._fixed_memory: + return self._memory_per_class + else: + assert self._total_classes != 0, 'Total classes is 0' + return (self._memory_size // self._total_classes) + + @property + def feature_dim(self): + if isinstance(self._network, nn.DataParallel): + return self._network.module.feature_dim + else: + return self._network.feature_dim + + def build_rehearsal_memory(self, data_manager, per_class, mode='icarl'): + if self._fixed_memory: + self._construct_exemplar_unified(data_manager, per_class) + else: + self._reduce_exemplar(data_manager, per_class) + self._construct_exemplar(data_manager, per_class, mode=mode) + + def save_checkpoint(self, filename, head_only=False): + if hasattr(self._network, 'module'): + to_save = self._network.module + else: + to_save = self._network + + if head_only: + to_save = to_save.fc + + save_dict = { + 'tasks': self._cur_task, + 'model_state_dict': to_save.state_dict(), + } + torch.save(save_dict, '{}_{}.pth'.format(filename, self._cur_task)) + + def after_task(self): + pass + + def _evaluate(self, y_pred, y_true): + ret = {} + grouped = accuracy(y_pred.T[0], y_true, self._known_classes) + ret['grouped'] = grouped + ret['top1'] = grouped['total'] + ret['top{}'.format(5)] = np.around((y_pred.T == np.tile(y_true, (self.topk, 1))).sum()*100/len(y_true), + decimals=2) + + return ret + + def eval_task(self): + y_pred, y_true = self._eval_cnn(self.test_loader) + cnn_accy = self._evaluate(y_pred, y_true) + + if hasattr(self, '_class_means') and False: # TODO + y_pred, y_true = self._eval_nme(self.test_loader, self._class_means) + nme_accy = self._evaluate(y_pred, y_true) + else: + nme_accy = None + + return cnn_accy, nme_accy + + def incremental_train(self): + pass + + def _train(self): + pass + + def _get_memory(self): + if
len(self._data_memory) == 0: + return None + else: + return (self._data_memory, self._targets_memory) + + + def _inner_eval(self, model, loader): + model.eval() + y_pred, y_true = [], [] + for _, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = model(inputs)['logits'] + predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] # [bs, topk] + y_pred.append(predicts.cpu().numpy()) + y_true.append(targets.cpu().numpy()) + + y_pred, y_true = np.concatenate(y_pred), np.concatenate(y_true) # [N, topk] + + cnn_accy = self._evaluate(y_pred, y_true) + return cnn_accy + + def _compute_accuracy(self, model, loader): + model.eval() + correct, total = 0, 0 + for i, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = model(inputs)['logits'] + predicts = torch.max(outputs, dim=1)[1] + correct += (predicts.cpu() == targets).sum() + total += len(targets) + + return np.around(tensor2numpy(correct)*100 / total, decimals=2) + + def _eval_cnn(self, loader): + self._network.eval() + y_pred, y_true = [], [] + for _, (_, inputs, targets) in enumerate(loader): + inputs = inputs.to(self._device) + with torch.no_grad(): + outputs = self._network(inputs)['logits'] + predicts = torch.topk(outputs, k=self.topk, dim=1, largest=True, sorted=True)[1] # [bs, topk] + y_pred.append(predicts.cpu().numpy()) + y_true.append(targets.cpu().numpy()) + + return np.concatenate(y_pred), np.concatenate(y_true) # [N, topk] + + def _eval_nme(self, loader, class_means): + self._network.eval() + vectors, y_true = self._extract_vectors(loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + + norm_means = class_means / np.linalg.norm(class_means) + dists = cdist(norm_means, vectors, 'sqeuclidean') # [nb_classes, N] + scores = dists.T # [N, nb_classes], choose the one with the smallest distance + + return np.argsort(scores, axis=1)[:, :self.topk], y_true # [N, topk] + + def _extract_vectors(self, loader): + self._network.eval() + vectors, targets = [], [] + for _, _inputs, _targets in loader: + _targets = _targets.numpy() + if isinstance(self._network, nn.DataParallel): + _vectors = tensor2numpy(self._network.module.extract_vector(_inputs.to(self._device))) + else: + _vectors = tensor2numpy(self._network.extract_vector(_inputs.to(self._device))) + + vectors.append(_vectors) + targets.append(_targets) + + return np.concatenate(vectors), np.concatenate(targets) + + def _extract_vectors_aug(self, loader, repeat=2): + self._network.eval() + vectors, targets = [], [] + for _ in range(repeat): + for _, _inputs, _targets in loader: + _targets = _targets.numpy() + with torch.no_grad(): + if isinstance(self._network, nn.DataParallel): + _vectors = tensor2numpy(self._network.module.extract_vector(_inputs.to(self._device))) + else: + _vectors = tensor2numpy(self._network.extract_vector(_inputs.to(self._device))) + + vectors.append(_vectors) + targets.append(_targets) + + return np.concatenate(vectors), np.concatenate(targets) + + def _reduce_exemplar(self, data_manager, m): + logging.info('Reducing exemplars...({} per classes)'.format(m)) + dummy_data, dummy_targets = copy.deepcopy(self._data_memory), copy.deepcopy(self._targets_memory) + self._class_means = np.zeros((self._total_classes, self.feature_dim)) + self._data_memory, self._targets_memory = np.array([]), np.array([]) + + for class_idx in range(self._known_classes): + mask = np.where(dummy_targets == class_idx)[0] + 
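# dummy_data stores exemplars in the order they were selected by _construct_exemplar, so for +            # icarl-mode herding the first m per class are the best-ranked ones; the truncation below relies on this.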
dd, dt = dummy_data[mask][:m], dummy_targets[mask][:m] + self._data_memory = np.concatenate((self._data_memory, dd)) if len(self._data_memory) != 0 else dd + self._targets_memory = np.concatenate((self._targets_memory, dt)) if len(self._targets_memory) != 0 else dt + + # Exemplar mean + idx_dataset = data_manager.get_dataset([], source='train', mode='test', appendent=(dd, dt)) + idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + self._class_means[class_idx, :] = mean + + + + def _compute_class_mean(self, data_manager, check_diff=False, oracle=False): + if hasattr(self, '_class_means') and self._class_means is not None and not check_diff: + ori_classes = self._class_means.shape[0] + assert ori_classes==self._known_classes + new_class_means = np.zeros((self._total_classes, self.feature_dim)) + new_class_means[:self._known_classes] = self._class_means + self._class_means = new_class_means + # new_class_cov = np.zeros((self._total_classes, self.feature_dim, self.feature_dim)) + new_class_cov = torch.zeros((self._total_classes, self.feature_dim, self.feature_dim)) + new_class_cov[:self._known_classes] = self._class_covs + self._class_covs = new_class_cov + elif not check_diff: + self._class_means = np.zeros((self._total_classes, self.feature_dim)) + # self._class_covs = np.zeros((self._total_classes, self.feature_dim, self.feature_dim)) + self._class_covs = torch.zeros((self._total_classes, self.feature_dim, self.feature_dim)) + + # self._class_covs = [] + + if check_diff: + for class_idx in range(0, self._known_classes): + data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + # vectors, _ = self._extract_vectors_aug(idx_loader) + vectors, _ = self._extract_vectors(idx_loader) + class_mean = np.mean(vectors, axis=0) + # class_cov = np.cov(vectors.T) + class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T) + if check_diff: + log_info = "cls {} sim: {}".format(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)).item()) + logging.info(log_info) + np.save('task_{}_cls_{}_mean.npy'.format(self._cur_task, class_idx), class_mean) + # print(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0))) + + if oracle: + for class_idx in range(0, self._known_classes): + data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + + # vectors = np.concatenate([vectors_aug, vectors]) + + class_mean = np.mean(vectors, axis=0) + # class_cov = np.cov(vectors.T) + class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T)+torch.eye(class_mean.shape[-1])*1e-5 + self._class_means[class_idx, :] = class_mean + self._class_covs[class_idx, ...] 
= class_cov + + for class_idx in range(self._known_classes, self._total_classes): + # data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + # mode='train', ret_data=True) + # idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + # vectors_aug, _ = self._extract_vectors_aug(idx_loader) + + data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + + # vectors = np.concatenate([vectors_aug, vectors]) + + class_mean = np.mean(vectors, axis=0) + # class_cov = np.cov(vectors.T) + class_cov = torch.cov(torch.tensor(vectors, dtype=torch.float64).T)+torch.eye(class_mean.shape[-1])*1e-4 + if check_diff: + log_info = "cls {} sim: {}".format(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0)).item()) + logging.info(log_info) + np.save('task_{}_cls_{}_mean.npy'.format(self._cur_task, class_idx), class_mean) + np.save('task_{}_cls_{}_mean_beforetrain.npy'.format(self._cur_task, class_idx), self._class_means[class_idx, :]) + # print(class_idx, torch.cosine_similarity(torch.tensor(self._class_means[class_idx, :]).unsqueeze(0), torch.tensor(class_mean).unsqueeze(0))) + self._class_means[class_idx, :] = class_mean + self._class_covs[class_idx, ...] = class_cov + # self._class_covs.append(class_cov) + + + def _construct_exemplar(self, data_manager, m, mode='icarl'): + logging.info('Constructing exemplars...({} per classes)'.format(m)) + for class_idx in range(self._known_classes, self._total_classes): + data, targets, idx_dataset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + if mode == 'icarl': + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + m = min(m, vectors.shape[0]) + # Select + selected_exemplars = [] + exemplar_vectors = [] # [n, feature_dim] + for k in range(1, m+1): + S = np.sum(exemplar_vectors, axis=0) # [feature_dim] sum of selected exemplars vectors + mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + selected_exemplars.append(np.array(data[i])) # New object to avoid passing by inference + exemplar_vectors.append(np.array(vectors[i])) # New object to avoid passing by inference + + vectors = np.delete(vectors, i, axis=0) # Remove it to avoid duplicative selection + data = np.delete(data, i, axis=0) # Remove it to avoid duplicative selection + # uniques = np.unique(selected_exemplars, axis=0) + # print('Unique elements: {}'.format(len(uniques))) + selected_exemplars = np.array(selected_exemplars) + exemplar_targets = np.full(m, class_idx) + else: + selected_index = np.random.choice(len(data), (min(m, len(data)),), replace=False) + selected_exemplars = data[selected_index] + exemplar_targets = np.full(min(m, len(data)), class_idx) + self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \ + else selected_exemplars + self._targets_memory = np.concatenate((self._targets_memory, exemplar_targets)) if \ + 
len(self._targets_memory) != 0 else exemplar_targets + + # Exemplar mean + idx_dataset = data_manager.get_dataset([], source='train', mode='test', + appendent=(selected_exemplars, exemplar_targets)) + idx_loader = DataLoader(idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(idx_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + self._class_means[class_idx, :] = mean + + def _construct_exemplar_unified(self, data_manager, m): + logging.info('Constructing exemplars for new classes...({} per classes)'.format(m)) + _class_means = np.zeros((self._total_classes, self.feature_dim)) + + # Calculate the means of old classes with newly trained network + for class_idx in range(self._known_classes): + mask = np.where(self._targets_memory == class_idx)[0] + class_data, class_targets = self._data_memory[mask], self._targets_memory[mask] + + class_dset = data_manager.get_dataset([], source='train', mode='test', + appendent=(class_data, class_targets)) + class_loader = DataLoader(class_dset, batch_size=batch_size, shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(class_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + _class_means[class_idx, :] = mean + + # Construct exemplars for new classes and calculate the means + for class_idx in range(self._known_classes, self._total_classes): + data, targets, class_dset = data_manager.get_dataset(np.arange(class_idx, class_idx+1), source='train', + mode='test', ret_data=True) + class_loader = DataLoader(class_dset, batch_size=batch_size, shuffle=False, num_workers=4) + + vectors, _ = self._extract_vectors(class_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + class_mean = np.mean(vectors, axis=0) + + # Select + selected_exemplars = [] + exemplar_vectors = [] + for k in range(1, m+1): + S = np.sum(exemplar_vectors, axis=0) # [feature_dim] sum of selected exemplars vectors + mu_p = (vectors + S) / k # [n, feature_dim] sum to all vectors + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + + selected_exemplars.append(np.array(data[i])) # New object to avoid passing by inference + exemplar_vectors.append(np.array(vectors[i])) # New object to avoid passing by inference + + vectors = np.delete(vectors, i, axis=0) # Remove it to avoid duplicative selection + data = np.delete(data, i, axis=0) # Remove it to avoid duplicative selection + + selected_exemplars = np.array(selected_exemplars) + exemplar_targets = np.full(m, class_idx) + self._data_memory = np.concatenate((self._data_memory, selected_exemplars)) if len(self._data_memory) != 0 \ + else selected_exemplars + self._targets_memory = np.concatenate((self._targets_memory, exemplar_targets)) if \ + len(self._targets_memory) != 0 else exemplar_targets + + # Exemplar mean + exemplar_dset = data_manager.get_dataset([], source='train', mode='test', + appendent=(selected_exemplars, exemplar_targets)) + exemplar_loader = DataLoader(exemplar_dset, batch_size=batch_size, shuffle=False, num_workers=4) + vectors, _ = self._extract_vectors(exemplar_loader) + vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T + mean = np.mean(vectors, axis=0) + mean = mean / np.linalg.norm(mean) + + _class_means[class_idx, :] = mean + + self._class_means = _class_means diff --git a/models/slca.py 
b/models/slca.py new file mode 100644 index 0000000..b32a7b9 --- /dev/null +++ b/models/slca.py @@ -0,0 +1,258 @@ +import logging +import numpy as np +import torch +from torch import nn +from torch import optim +from torch.nn import functional as F +from torch.utils.data import DataLoader +from models.base import BaseLearner +from utils.inc_net import FinetuneIncrementalNet +from torchvision import transforms +from torch.distributions.multivariate_normal import MultivariateNormal +import random +from utils.toolkit import tensor2numpy, accuracy +import copy +import os + +epochs = 20 +lrate = 0.01 +milestones = [60,100,140] +lrate_decay = 0.1 +batch_size = 128 +split_ratio = 0.1 +T = 2 +weight_decay = 5e-4 +num_workers = 8 +ca_epochs = 5 + + +class SLCA(BaseLearner): + def __init__(self, args): + super().__init__(args) + self._network = FinetuneIncrementalNet(args['convnet_type'], pretrained=True) + self.log_path = "logs/{}_{}".format(args['model_name'], args['model_postfix']) + self.model_prefix = args['prefix'] + if 'epochs' in args.keys(): + global epochs + epochs = args['epochs'] + if 'milestones' in args.keys(): + global milestones + milestones = args['milestones'] + if 'lr' in args.keys(): + global lrate + lrate = args['lr'] + print('set lr to ', lrate) + if 'bcb_lrscale' in args.keys(): + self.bcb_lrscale = args['bcb_lrscale'] + else: + self.bcb_lrscale = 1.0/100 + if self.bcb_lrscale == 0: + self.fix_bcb = True + else: + self.fix_bcb = False + print('fix_bcb', self.fix_bcb) + + + + if 'save_before_ca' in args.keys() and args['save_before_ca']: + self.save_before_ca = True + else: + self.save_before_ca = False + + if 'ca_epochs' in args.keys(): + global ca_epochs + ca_epochs = args['ca_epochs'] + + if 'ca_with_logit_norm' in args.keys() and args['ca_with_logit_norm']>0: + self.logit_norm = args['ca_with_logit_norm'] + else: + self.logit_norm = None + + self.run_id = args['run_id'] + self.seed = args['seed'] + self.task_sizes = [] + + def after_task(self): + self._known_classes = self._total_classes + logging.info('Exemplar size: {}'.format(self.exemplar_size)) + self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}'.format(self.seed), head_only=self.fix_bcb) + self._network.fc.recall() + + def incremental_train(self, data_manager): + self._cur_task += 1 + task_size = data_manager.get_task_size(self._cur_task) + self.task_sizes.append(task_size) + self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task) + self.topk = self._total_classes if self._total_classes<5 else 5 + self._network.update_fc(data_manager.get_task_size(self._cur_task)) + logging.info('Learning on {}-{}'.format(self._known_classes, self._total_classes)) + + self._network.to(self._device) + + train_dset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), + source='train', mode='train', + appendent=[], with_raw=False) + test_dset = data_manager.get_dataset(np.arange(0, self._total_classes), source='test', mode='test') + dset_name = data_manager.dataset_name.lower() + + self.train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + self.test_loader = DataLoader(test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + self._stage1_training(self.train_loader, self.test_loader) + + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + # CA + self._network.fc.backup() + if self.save_before_ca: +
self.save_checkpoint(self.log_path+'/'+self.model_prefix+'_seed{}_before_ca'.format(self.seed), head_only=self.fix_bcb) + + self._compute_class_mean(data_manager, check_diff=False, oracle=False) + if self._cur_task>0 and ca_epochs>0: + self._stage2_compact_classifier(task_size) + if len(self._multiple_gpus) > 1: + self._network = self._network.module + + + def _run(self, train_loader, test_loader, optimizer, scheduler): + run_epochs = epochs + for epoch in range(1, run_epochs+1): + self._network.train() + losses = 0. + for i, (_, inputs, targets) in enumerate(train_loader): + inputs, targets = inputs.to(self._device), targets.to(self._device) + + logits = self._network(inputs, bcb_no_grad=self.fix_bcb)['logits'] + cur_targets = torch.where(targets-self._known_classes>=0,targets-self._known_classes,-100) + loss = F.cross_entropy(logits[:, self._known_classes:], cur_targets) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + scheduler.step() + if epoch%5==0: + train_acc = self._compute_accuracy(self._network, train_loader) + test_acc = self._compute_accuracy(self._network, test_loader) + info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}'.format( + self._cur_task, epoch, epochs, losses/len(train_loader), train_acc, test_acc) + else: + info = 'Task {}, Epoch {}/{} => Loss {:.3f}'.format( + self._cur_task, epoch, epochs, losses/len(train_loader)) + logging.info(info) + + def _stage1_training(self, train_loader, test_loader): + ''' + if self._cur_task == 0: + loaded_dict = torch.load('./dict_0.pkl') + self._network.load_state_dict(loaded_dict['model_state_dict']) + self._network.to(self._device) + return + ''' + base_params = self._network.convnet.parameters() + base_fc_params = [p for p in self._network.fc.parameters() if p.requires_grad==True] + head_scale = 1. if 'moco' in self.log_path else 1. + if not self.fix_bcb: + base_params = {'params': base_params, 'lr': lrate*self.bcb_lrscale, 'weight_decay': weight_decay} + base_fc_params = {'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay} + network_params = [base_params, base_fc_params] + else: + for p in base_params: + p.requires_grad = False + network_params = [{'params': base_fc_params, 'lr': lrate*head_scale, 'weight_decay': weight_decay}] + optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=lrate_decay) + + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + + self._run(train_loader, test_loader, optimizer, scheduler) + + + def _stage2_compact_classifier(self, task_size): + for p in self._network.fc.parameters(): + p.requires_grad=True + + run_epochs = ca_epochs + crct_num = self._total_classes + param_list = [p for p in self._network.fc.parameters() if p.requires_grad] + network_params = [{'params': param_list, 'lr': lrate, + 'weight_decay': weight_decay}] + optimizer = optim.SGD(network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay) + # scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[4], gamma=lrate_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=run_epochs) + + self._network.to(self._device) + if len(self._multiple_gpus) > 1: + self._network = nn.DataParallel(self._network, self._multiple_gpus) + + self._network.eval() + for epoch in range(run_epochs): + losses = 0. 
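+            # Classifier alignment: for each class seen so far, num_sampled_pcls pseudo-features are +            # drawn from a Gaussian built from the stored class mean and covariance, with the mean +            # scaled by (0.9 + decay) where decay = (t_id+1)/(self._cur_task+1)*0.1, i.e. from just +            # above 0.9 for the oldest task up to 1.0 for the current one; only the fc head is trained.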
+ + sampled_data = [] + sampled_label = [] + num_sampled_pcls = 256 + + for c_id in range(crct_num): + t_id = c_id//task_size + decay = (t_id+1)/(self._cur_task+1)*0.1 + cls_mean = torch.tensor(self._class_means[c_id], dtype=torch.float64).to(self._device)*(0.9+decay) # torch.from_numpy(self._class_means[c_id]).to(self._device) + cls_cov = self._class_covs[c_id].to(self._device) + + m = MultivariateNormal(cls_mean.float(), cls_cov.float()) + + sampled_data_single = m.sample(sample_shape=(num_sampled_pcls,)) + sampled_data.append(sampled_data_single) + sampled_label.extend([c_id]*num_sampled_pcls) + + sampled_data = torch.cat(sampled_data, dim=0).float().to(self._device) + sampled_label = torch.tensor(sampled_label).long().to(self._device) + + inputs = sampled_data + targets= sampled_label + + sf_indexes = torch.randperm(inputs.size(0)) + inputs = inputs[sf_indexes] + targets = targets[sf_indexes] + + + for _iter in range(crct_num): + inp = inputs[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls] + tgt = targets[_iter*num_sampled_pcls:(_iter+1)*num_sampled_pcls] + outputs = self._network(inp, bcb_no_grad=True, fc_only=True) + logits = outputs['logits'] + + if self.logit_norm is not None: + per_task_norm = [] + prev_t_size = 0 + cur_t_size = 0 + for _ti in range(self._cur_task+1): + cur_t_size += self.task_sizes[_ti] + temp_norm = torch.norm(logits[:, prev_t_size:cur_t_size], p=2, dim=-1, keepdim=True) + 1e-7 + per_task_norm.append(temp_norm) + prev_t_size += self.task_sizes[_ti] + per_task_norm = torch.cat(per_task_norm, dim=-1) + norms = per_task_norm.mean(dim=-1, keepdim=True) + + norms_all = torch.norm(logits[:, :crct_num], p=2, dim=-1, keepdim=True) + 1e-7 + decoupled_logits = torch.div(logits[:, :crct_num], norms) / self.logit_norm + loss = F.cross_entropy(decoupled_logits, tgt) + + else: + loss = F.cross_entropy(logits[:, :crct_num], tgt) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses += loss.item() + + scheduler.step() + test_acc = self._compute_accuracy(self._network, self.test_loader) + info = 'CA Task {} => Loss {:.3f}, Test_accy {:.3f}'.format( + self._cur_task, losses/self._total_classes, test_acc) + logging.info(info) + + diff --git a/split_car.py b/split_car.py new file mode 100644 index 0000000..3db6bdc --- /dev/null +++ b/split_car.py @@ -0,0 +1,21 @@ +import numpy as np +import os +import os.path as osp +import shutil +from tqdm import tqdm + +img_list = np.genfromtxt('mat2txt.txt', dtype=str) # [num, img, cls, istest] +class_mappings = np.genfromtxt('label_map.txt', dtype=str) +class_mappings = {a[0]: a[1] for a in class_mappings} + +for item in tqdm(img_list): + if bool(int(item[-1])): + cls_folder = osp.join('cars196', 'train', class_mappings[item[2]]) + if not os.path.exists(cls_folder): + os.mkdir(cls_folder) + shutil.copy(item[1], osp.join(cls_folder, item[1].split('/')[-1])) + else: + cls_folder = osp.join('cars196', 'val', class_mappings[item[2]]) + if not os.path.exists(cls_folder): + os.mkdir(cls_folder) + shutil.copy(item[1], osp.join(cls_folder, item[1].split('/')[-1])) diff --git a/split_cub.py b/split_cub.py new file mode 100644 index 0000000..944b484 --- /dev/null +++ b/split_cub.py @@ -0,0 +1,16 @@ +import numpy as np +import os +import os.path as osp +import shutil +from tqdm import tqdm + +train_val_list = np.genfromtxt('train_test_split.txt', dtype='str') +img_list = np.genfromtxt('images.txt', dtype='str') + +img_id_mapping = {a[0]: a[1] for a in img_list} +for img, is_train in tqdm(train_val_list): + if bool(int(is_train)): + # 
print(osp.join('CUB200', 'val', img_id_mapping[img])) + os.remove(osp.join('CUB200', 'val', img_id_mapping[img])) + else: + os.remove(osp.join('CUB200', 'train', img_id_mapping[img])) diff --git a/train_all.sh b/train_all.sh new file mode 100644 index 0000000..8768233 --- /dev/null +++ b/train_all.sh @@ -0,0 +1,10 @@ +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_cifar.json +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_imgnetr.json +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_cub.json +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_cars.json + +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_cifar_mocov3.json +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_imgnetr_mocov3.json +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_cub_mocov3.json +CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=exps/slca_cars_mocov3.json + diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000..d656e06 --- /dev/null +++ b/trainer.py @@ -0,0 +1,115 @@ +import sys +import logging +import copy +import torch +from utils import factory +from utils.data_manager import DataManager +from utils.toolkit import count_parameters +import os +import numpy as np + +def train(args): + seed_list = copy.deepcopy(args['seed']) + device = copy.deepcopy(args['device']) + + res_finals, res_avgs = [], [] + for run_id, seed in enumerate(seed_list): + args['seed'] = seed + args['run_id'] = run_id + args['device'] = device + res_final, res_avg = _train(args) + res_finals.append(res_final) + res_avgs.append(res_avg) + logging.info('final accs: {}'.format(res_finals)) + logging.info('avg accs: {}'.format(res_avgs)) + + + +def _train(args): + try: + os.mkdir("logs/{}_{}".format(args['model_name'], args['model_postfix'])) + except: + pass + logfilename = 'logs/{}_{}/{}_{}_{}_{}_{}_{}_{}'.format(args['model_name'], args['model_postfix'], args['prefix'], args['seed'], args['model_name'], args['convnet_type'], + args['dataset'], args['init_cls'], args['increment']) + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(filename)s] => %(message)s', + handlers=[ + logging.FileHandler(filename=logfilename + '.log'), + logging.StreamHandler(sys.stdout) + ] + ) + + _set_random() + _set_device(args) + print_args(args) + data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment']) + model = factory.get_model(args['model_name'], args) + + cnn_curve, nme_curve = {'top1': [], 'top5': []}, {'top1': [], 'top5': []} + for task in range(data_manager.nb_tasks): + logging.info('All params: {}'.format(count_parameters(model._network))) + logging.info('Trainable params: {}'.format(count_parameters(model._network, True))) + model.incremental_train(data_manager) + + cnn_accy, nme_accy = model.eval_task() + model.after_task() + + + if nme_accy is not None: + logging.info('CNN: {}'.format(cnn_accy['grouped'])) + logging.info('NME: {}'.format(nme_accy['grouped'])) + + cnn_curve['top1'].append(cnn_accy['top1']) + cnn_curve['top5'].append(cnn_accy['top5']) + + nme_curve['top1'].append(nme_accy['top1']) + nme_curve['top5'].append(nme_accy['top5']) + + logging.info('CNN top1 curve: {}'.format(cnn_curve['top1'])) + logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean())) + if 'task_acc' in cnn_accy.keys(): + logging.info('Task: {}'.format(cnn_accy['task_acc'])) + logging.info('CNN top5 curve: {}'.format(cnn_curve['top5'])) + logging.info('NME top1 curve: 
{}'.format(nme_curve['top1']))
+            logging.info('NME top5 curve: {}\n'.format(nme_curve['top5']))
+        else:
+            logging.info('No NME accuracy.')
+            logging.info('CNN: {}'.format(cnn_accy['grouped']))
+
+            cnn_curve['top1'].append(cnn_accy['top1'])
+            cnn_curve['top5'].append(cnn_accy['top5'])
+
+            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
+            logging.info('CNN top1 avg: {}'.format(np.array(cnn_curve['top1']).mean()))
+            logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5']))
+
+    return (cnn_curve['top1'][-1], np.array(cnn_curve['top1']).mean())
+
+def _set_device(args):
+    device_type = args['device']
+    gpus = []
+
+    for device in device_type:
+        if device == -1:
+            device = torch.device('cpu')
+        else:
+            device = torch.device('cuda:{}'.format(device))
+
+        gpus.append(device)
+
+    args['device'] = gpus
+
+
+def _set_random():
+    torch.manual_seed(1)
+    torch.cuda.manual_seed(1)
+    torch.cuda.manual_seed_all(1)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def print_args(args):
+    for key, value in args.items():
+        logging.info('{}: {}'.format(key, value))
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/buffer.py b/utils/buffer.py
new file mode 100644
index 0000000..da2f211
--- /dev/null
+++ b/utils/buffer.py
@@ -0,0 +1,232 @@
+# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+from typing import Tuple
+from torchvision import transforms
+from copy import deepcopy
+
+def icarl_replay(self, dataset, val_set_split=0):
+    """
+    Merge the replay buffer with the current task data.
+    Optionally split the replay buffer into a validation set.
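+    A fraction ``val_set_split`` of the buffer is held out for validation,
+    together with an equal number of samples drawn from the current task's
+    training set, so the validation loader mixes old and new classes.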
+ + :param self: the model instance + :param dataset: the dataset + :param val_set_split: the fraction of the replay buffer to be used as validation set + """ + + if self.task > 0: + buff_val_mask = torch.rand(len(self.buffer)) < val_set_split + val_train_mask = torch.zeros(len(dataset.train_loader.dataset.data)).bool() + val_train_mask[torch.randperm(len(dataset.train_loader.dataset.data))[:buff_val_mask.sum()]] = True + + if val_set_split > 0: + self.val_loader = deepcopy(dataset.train_loader) + + data_concatenate = torch.cat if type(dataset.train_loader.dataset.data) == torch.Tensor else np.concatenate + need_aug = hasattr(dataset.train_loader.dataset, 'not_aug_transform') + if not need_aug: + refold_transform = lambda x: x.cpu() + else: + data_shape = len(dataset.train_loader.dataset.data[0].shape) + if data_shape == 3: + refold_transform = lambda x: (x.cpu()*255).permute([0, 2, 3, 1]).numpy().astype(np.uint8) + elif data_shape == 2: + refold_transform = lambda x: (x.cpu()*255).squeeze(1).type(torch.uint8) + + # REDUCE AND MERGE TRAINING SET + dataset.train_loader.dataset.targets = np.concatenate([ + dataset.train_loader.dataset.targets[~val_train_mask], + self.buffer.labels.cpu().numpy()[:len(self.buffer)][~buff_val_mask] + ]) + dataset.train_loader.dataset.data = data_concatenate([ + dataset.train_loader.dataset.data[~val_train_mask], + refold_transform((self.buffer.examples)[:len(self.buffer)][~buff_val_mask]) + ]) + + if val_set_split > 0: + # REDUCE AND MERGE VALIDATION SET + self.val_loader.dataset.targets = np.concatenate([ + self.val_loader.dataset.targets[val_train_mask], + self.buffer.labels.cpu().numpy()[:len(self.buffer)][buff_val_mask] + ]) + self.val_loader.dataset.data = data_concatenate([ + self.val_loader.dataset.data[val_train_mask], + refold_transform((self.buffer.examples)[:len(self.buffer)][buff_val_mask]) + ]) + +def reservoir(num_seen_examples: int, buffer_size: int) -> int: + """ + Reservoir sampling algorithm. + :param num_seen_examples: the number of seen examples + :param buffer_size: the maximum buffer size + :return: the target index if the current image is sampled, else -1 + """ + if num_seen_examples < buffer_size: + return num_seen_examples + + rand = np.random.randint(0, num_seen_examples + 1) + if rand < buffer_size: + return rand + else: + return -1 + + +def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int: + return num_seen_examples % buffer_portion_size + task * buffer_portion_size + + +class Buffer: + """ + The memory buffer of rehearsal method. + """ + def __init__(self, buffer_size, device, n_tasks=None, mode='reservoir'): + assert mode in ['ring', 'reservoir'] + self.buffer_size = buffer_size + self.device = device + self.num_seen_examples = 0 + self.functional_index = eval(mode) + if mode == 'ring': + assert n_tasks is not None + self.task_number = n_tasks + self.buffer_portion_size = buffer_size // n_tasks + self.attributes = ['examples', 'labels', 'logits', 'task_labels'] + + def to(self, device): + self.device = device + for attr_str in self.attributes: + if hasattr(self, attr_str): + setattr(self, attr_str, getattr(self, attr_str).to(device)) + return self + + def __len__(self): + return min(self.num_seen_examples, self.buffer_size) + + + def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, + logits: torch.Tensor, task_labels: torch.Tensor) -> None: + """ + Initializes just the required tensors. 
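+        Tensors are allocated lazily with shape (buffer_size, *attr.shape[1:]);
+        'labels' and 'task_labels' are stored as int64, the others as float32.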
+ :param examples: tensor containing the images + :param labels: tensor containing the labels + :param logits: tensor containing the outputs of the network + :param task_labels: tensor containing the task labels + """ + for attr_str in self.attributes: + attr = eval(attr_str) + if attr is not None and not hasattr(self, attr_str): + typ = torch.int64 if attr_str.endswith('els') else torch.float32 + setattr(self, attr_str, torch.zeros((self.buffer_size, + *attr.shape[1:]), dtype=typ, device=self.device)) + + def add_data(self, examples, labels=None, logits=None, task_labels=None): + """ + Adds the data to the memory buffer according to the reservoir strategy. + :param examples: tensor containing the images + :param labels: tensor containing the labels + :param logits: tensor containing the outputs of the network + :param task_labels: tensor containing the task labels + :return: + """ + if not hasattr(self, 'examples'): + self.init_tensors(examples, labels, logits, task_labels) + + for i in range(examples.shape[0]): + index = reservoir(self.num_seen_examples, self.buffer_size) + self.num_seen_examples += 1 + if index >= 0: + self.examples[index] = examples[i].to(self.device) + if labels is not None: + self.labels[index] = labels[i].to(self.device) + if logits is not None: + self.logits[index] = logits[i].to(self.device) + if task_labels is not None: + self.task_labels[index] = task_labels[i].to(self.device) + + def get_data(self, size: int, transform: transforms=None, return_index=False) -> Tuple: + """ + Random samples a batch of size items. + :param size: the number of requested items + :param transform: the transformation to be applied (data augmentation) + :return: + """ + if size > min(self.num_seen_examples, self.examples.shape[0]): + size = min(self.num_seen_examples, self.examples.shape[0]) + + choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]), + size=size, replace=False) + if transform is None: transform = lambda x: x + ret_tuple = (torch.stack([transform(ee.cpu()) + for ee in self.examples[choice]]).to(self.device),) + for attr_str in self.attributes[1:]: + if hasattr(self, attr_str): + attr = getattr(self, attr_str) + ret_tuple += (attr[choice],) + + if not return_index: + return ret_tuple + else: + return (torch.tensor(choice).to(self.device), ) + ret_tuple + + return ret_tuple + + def get_data_by_index(self, indexes, transform: transforms=None) -> Tuple: + """ + Returns the data by the given index. + :param index: the index of the item + :param transform: the transformation to be applied (data augmentation) + :return: + """ + if transform is None: transform = lambda x: x + ret_tuple = (torch.stack([transform(ee.cpu()) + for ee in self.examples[indexes]]).to(self.device),) + for attr_str in self.attributes[1:]: + if hasattr(self, attr_str): + attr = getattr(self, attr_str).to(self.device) + ret_tuple += (attr[indexes],) + return ret_tuple + + + def is_empty(self) -> bool: + """ + Returns true if the buffer is empty, false otherwise. + """ + if self.num_seen_examples == 0: + return True + else: + return False + + def get_all_data(self, transform: transforms=None) -> Tuple: + """ + Return all the items in the memory buffer. 
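+        No subsampling is performed; while the buffer is not yet full, slots
+        past num_seen_examples still hold their zero-initialized placeholders.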
+        :param transform: the transformation to be applied (data augmentation)
+        :return: a tuple with all the items in the memory buffer
+        """
+        if transform is None: transform = lambda x: x
+        ret_tuple = (torch.stack([transform(ee.cpu())
+                                  for ee in self.examples]).to(self.device),)
+        for attr_str in self.attributes[1:]:
+            if hasattr(self, attr_str):
+                attr = getattr(self, attr_str)
+                ret_tuple += (attr,)
+        return ret_tuple
+
+    def empty(self) -> None:
+        """
+        Set all the tensors to None.
+        """
+        for attr_str in self.attributes:
+            if hasattr(self, attr_str):
+                delattr(self, attr_str)
+        self.num_seen_examples = 0
diff --git a/utils/cutmix.py b/utils/cutmix.py
new file mode 100644
index 0000000..ed12ec0
--- /dev/null
+++ b/utils/cutmix.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+def rand_bbox(size, lam):
+    W = size[2]
+    H = size[3]
+    cut_rat = np.sqrt(1. - lam)
+    cut_w = int(W * cut_rat)  # builtin int; np.int was removed in NumPy 1.24
+    cut_h = int(H * cut_rat)
+
+    # uniform
+    cx = np.random.randint(W)
+    cy = np.random.randint(H)
+
+    bbx1 = np.clip(cx - cut_w // 2, 0, W)
+    bby1 = np.clip(cy - cut_h // 2, 0, H)
+    bbx2 = np.clip(cx + cut_w // 2, 0, W)
+    bby2 = np.clip(cy + cut_h // 2, 0, H)
+
+    return bbx1, bby1, bbx2, bby2
+
+def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5):
+    assert(alpha > 0)
+    # generate mixed sample
+    lam = np.random.beta(alpha, alpha)
+
+    batch_size = x.size()[0]
+    index = torch.randperm(batch_size)
+
+    if torch.cuda.is_available():
+        index = index.cuda()
+
+    y_a, y_b = y, y[index]
+    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
+    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
+
+    # adjust lambda to exactly match pixel ratio
+    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
+    return x, y_a, y_b, lam
diff --git a/utils/data.py b/utils/data.py
new file mode 100644
index 0000000..d316c44
--- /dev/null
+++ b/utils/data.py
@@ -0,0 +1,259 @@
+import numpy as np
+from torchvision import datasets, transforms
+from utils.toolkit import split_images_labels
+
+
+class iData(object):
+    train_trsf = []
+    test_trsf = []
+    common_trsf = []
+    class_order = None
+
+
+class iCIFAR10(iData):
+    use_path = False
+    train_trsf = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(p=0.5),
+        transforms.ColorJitter(brightness=63/255)
+    ]
+    test_trsf = []
+    common_trsf = [
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
+    ]
+
+    class_order = np.arange(10).tolist()
+
+    def download_data(self):
+        train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True)
+        test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True)
+        self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
+        self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)
+
+
+class iCIFAR100(iData):
+    use_path = False
+    train_trsf = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ColorJitter(brightness=63/255)
+    ]
+    test_trsf = []
+    common_trsf = [
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
+    ]
+
+    class_order = np.arange(100).tolist()
+
+    def download_data(self):
+        train_dataset = datasets.cifar.CIFAR100('./data', train=True, download=True)
+        test_dataset = datasets.cifar.CIFAR100('./data', train=False, download=True)
+        self.train_data, self.train_targets = train_dataset.data, 
np.array(train_dataset.targets) + self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets) + +class iCIFAR100_224(iCIFAR100): + train_trsf = [ + transforms.RandomResizedCrop(224, interpolation=3), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63/255) + ] + test_trsf = [ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + ] + +class iImageNet1000(iData): + use_path = True + train_trsf = [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63/255) + ] + test_trsf = [ + transforms.Resize(256), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + + class_order = np.arange(1000).tolist() + + def download_data(self): + assert 0,"You should specify the folder of your dataset" + train_dir = '[DATA-PATH]/train/' + test_dir = '[DATA-PATH]/val/' + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + +class iImageNet100(iData): + use_path = True + train_trsf = [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + ] + test_trsf = [ + transforms.Resize(256), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + + class_order = np.arange(1000).tolist() + + def download_data(self): + train_dir = 'data/imagenet100/train/' + test_dir = 'data/imagenet100/val/' + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + + +class iImageNetR(iData): + use_path = True + train_trsf = [ + transforms.RandomResizedCrop(224, interpolation=3), + transforms.RandomHorizontalFlip(), + ] + test_trsf = [ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + + class_order = np.arange(1000).tolist() + + def download_data(self): + train_dir = 'data/imagenet-r/train/' + test_dir = 'data/imagenet-r/val/' + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + +class iCUB200_224(iData): + use_path = True + train_trsf = [ + transforms.Resize((300, 300), interpolation=3), + transforms.RandomCrop((224, 224)), + transforms.RandomHorizontalFlip(), + ] + test_trsf = [ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + class_order = np.arange(1000).tolist() + + def download_data(self): + train_dir = 'data/cub_200/train/' + test_dir = 'data/cub_200/val/' + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + 
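+# The fine-grained 224x224 datasets below (CARS196, Resisc45, Sketch345) reuse
+# the recipe of iCUB200_224 above: resize to 300x300, random 224 crop and a
+# horizontal flip for training, plus the ImageNet normalization statistics.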
+class iCARS196_224(iData): + use_path = True + train_trsf = [ + transforms.Resize((300, 300), interpolation=3), + transforms.RandomCrop((224, 224)), + transforms.RandomHorizontalFlip(), + ] + test_trsf = [ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + class_order = np.arange(1000).tolist() + + def download_data(self): + train_dir = 'data/cars196/train/' + test_dir = 'data/cars196/val/' + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + + +class iResisc45_224(iData): + use_path = True + train_trsf = [ + transforms.Resize((300, 300), interpolation=3), + transforms.RandomCrop((224, 224)), + transforms.RandomHorizontalFlip(), + ] + test_trsf = [ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + class_order = np.arange(1000).tolist() + + def download_data(self): + train_dir = 'data/resisc45/train/' + test_dir = 'data/resisc45/val/' + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) + + + +class iSketch345_224(iData): + use_path = True + train_trsf = [ + transforms.Resize((300, 300), interpolation=3), + transforms.RandomCrop((224, 224)), + transforms.RandomHorizontalFlip(), + ] + test_trsf = [ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + ] + common_trsf = [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + class_order = np.arange(1000).tolist() + + def download_data(self): + train_dir = 'data/sketch345/train/' + test_dir = 'data/sketch345/val/' + + train_dset = datasets.ImageFolder(train_dir) + test_dset = datasets.ImageFolder(test_dir) + self.train_data, self.train_targets = split_images_labels(train_dset.imgs) + self.test_data, self.test_targets = split_images_labels(test_dset.imgs) \ No newline at end of file diff --git a/utils/data_manager.py b/utils/data_manager.py new file mode 100644 index 0000000..1d3d740 --- /dev/null +++ b/utils/data_manager.py @@ -0,0 +1,245 @@ +import logging +import numpy as np +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000, iCIFAR100_224, iImageNetR, iCUB200_224, iResisc45_224, iCARS196_224, iSketch345_224 +from copy import deepcopy +import random + +class DataManager(object): + def __init__(self, dataset_name, shuffle, seed, init_cls, increment): + self.dataset_name = dataset_name + self._setup_data(dataset_name, shuffle, seed) + assert init_cls <= len(self._class_order), 'No enough classes.' 
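+        # Build the per-task class counts: e.g. init_cls=10, increment=10 on a
+        # 100-class dataset yields ten tasks of 10 classes; any remainder
+        # becomes one final, smaller task.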
+ self._increments = [init_cls] + while sum(self._increments) + increment < len(self._class_order): + self._increments.append(increment) + offset = len(self._class_order) - sum(self._increments) + if offset > 0: + self._increments.append(offset) + + @property + def nb_tasks(self): + return len(self._increments) + + def get_task_size(self, task): + return self._increments[task] + + def get_dataset(self, indices, source, mode, appendent=None, ret_data=False, with_raw=False, with_noise=False): + if source == 'train': + x, y = self._train_data, self._train_targets + elif source == 'test': + x, y = self._test_data, self._test_targets + else: + raise ValueError('Unknown data source {}.'.format(source)) + + if mode == 'train': + trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) + elif mode == 'flip': + trsf = transforms.Compose([*self._test_trsf, transforms.RandomHorizontalFlip(p=1.), *self._common_trsf]) + elif mode == 'test': + trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) + else: + raise ValueError('Unknown mode {}.'.format(mode)) + + data, targets = [], [] + for idx in indices: + class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1) + data.append(class_data) + targets.append(class_targets) + + if appendent is not None and len(appendent) != 0: + appendent_data, appendent_targets = appendent + data.append(appendent_data) + targets.append(appendent_targets) + + data, targets = np.concatenate(data), np.concatenate(targets) + + if ret_data: + return data, targets, DummyDataset(data, targets, trsf, self.use_path, with_raw, with_noise) + else: + return DummyDataset(data, targets, trsf, self.use_path, with_raw, with_noise) + + def get_dataset_with_split(self, indices, source, mode, appendent=None, val_samples_per_class=0): + if source == 'train': + x, y = self._train_data, self._train_targets + elif source == 'test': + x, y = self._test_data, self._test_targets + else: + raise ValueError('Unknown data source {}.'.format(source)) + + if mode == 'train': + trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) + elif mode == 'test': + trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) + else: + raise ValueError('Unknown mode {}.'.format(mode)) + + train_data, train_targets = [], [] + val_data, val_targets = [], [] + for idx in indices: + class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1) + val_indx = np.random.choice(len(class_data), val_samples_per_class, replace=False) + train_indx = list(set(np.arange(len(class_data))) - set(val_indx)) + val_data.append(class_data[val_indx]) + val_targets.append(class_targets[val_indx]) + train_data.append(class_data[train_indx]) + train_targets.append(class_targets[train_indx]) + + if appendent is not None: + appendent_data, appendent_targets = appendent + for idx in range(0, int(np.max(appendent_targets))+1): + append_data, append_targets = self._select(appendent_data, appendent_targets, + low_range=idx, high_range=idx+1) + val_indx = np.random.choice(len(append_data), val_samples_per_class, replace=False) + train_indx = list(set(np.arange(len(append_data))) - set(val_indx)) + val_data.append(append_data[val_indx]) + val_targets.append(append_targets[val_indx]) + train_data.append(append_data[train_indx]) + train_targets.append(append_targets[train_indx]) + + train_data, train_targets = np.concatenate(train_data), np.concatenate(train_targets) + val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets) + + return 
DummyDataset(train_data, train_targets, trsf, self.use_path), \ + DummyDataset(val_data, val_targets, trsf, self.use_path) + + def _setup_data(self, dataset_name, shuffle, seed): + idata = _get_idata(dataset_name) + idata.download_data() + + # Data + self._train_data, self._train_targets = idata.train_data, idata.train_targets + self._test_data, self._test_targets = idata.test_data, idata.test_targets + self.use_path = idata.use_path + + # Transforms + self._train_trsf = idata.train_trsf + self._test_trsf = idata.test_trsf + self._common_trsf = idata.common_trsf + + # Order + order = [i for i in range(len(np.unique(self._train_targets)))] + if shuffle: + np.random.seed(seed) + order = np.random.permutation(len(order)).tolist() + else: + order = idata.class_order + self._class_order = order + logging.info(self._class_order) + + # Map indices + self._train_targets = _map_new_class_index(self._train_targets, self._class_order) + self._test_targets = _map_new_class_index(self._test_targets, self._class_order) + + def _select(self, x, y, low_range, high_range): + idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0] + return x[idxes], y[idxes] + + +class DummyDataset(Dataset): + def __init__(self, images, labels, trsf, use_path=False, with_raw=False, with_noise=False): + assert len(images) == len(labels), 'Data size error!' + self.images = images + self.labels = labels + self.trsf = trsf + self.use_path = use_path + self.with_raw = with_raw + if use_path and with_raw: + self.raw_trsf = transforms.Compose([transforms.Resize((500, 500)), transforms.ToTensor()]) + else: + self.raw_trsf = transforms.Compose([transforms.ToTensor()]) + if with_noise: + class_list = np.unique(self.labels) + self.ori_labels = deepcopy(labels) + for cls in class_list: + random_target = class_list.tolist() + random_target.remove(cls) + tindx = [i for i, x in enumerate(self.ori_labels) if x == cls] + for i in tindx[:round(len(tindx)*0.2)]: + self.labels[i] = random.choice(random_target) + + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + if self.use_path: + load_image = pil_loader(self.images[idx]) + image = self.trsf(load_image) + else: + load_image = Image.fromarray(self.images[idx]) + image = self.trsf(load_image) + label = self.labels[idx] + if self.with_raw: + return idx, image, label, self.raw_trsf(load_image) + return idx, image, label + + +def _map_new_class_index(y, order): + return np.array(list(map(lambda x: order.index(x), y))) + + +def _get_idata(dataset_name): + name = dataset_name.lower() + if name == 'cifar10': + return iCIFAR10() + elif name == 'cifar100': + return iCIFAR100() + elif name == 'cifar100_224': + return iCIFAR100_224() + elif name == 'imagenet1000': + return iImageNet1000() + elif name == "imagenet100": + return iImageNet100() + elif name == "imagenet-r": + return iImageNetR() + elif name == 'cub200_224': + return iCUB200_224() + elif name == 'resisc45': + return iResisc45_224() + elif name == 'cars196_224': + return iCARS196_224() + elif name == 'sketch345_224': + return iSketch345_224() + else: + raise NotImplementedError('Unknown dataset {}.'.format(dataset_name)) + + +def pil_loader(path): + ''' + Ref: + https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder + ''' + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +def accimage_loader(path): + ''' + Ref: + 
https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder + accimage is an accelerated Image loader and preprocessor leveraging Intel IPP. + accimage is available on conda-forge. + ''' + import accimage + try: + return accimage.Image(path) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + return pil_loader(path) + + +def default_loader(path): + ''' + Ref: + https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder + ''' + from torchvision import get_image_backend + if get_image_backend() == 'accimage': + return accimage_loader(path) + else: + return pil_loader(path) diff --git a/utils/factory.py b/utils/factory.py new file mode 100644 index 0000000..1535055 --- /dev/null +++ b/utils/factory.py @@ -0,0 +1,9 @@ +import torch +from models.slca import SLCA + +def get_model(model_name, args): + name = model_name.lower() + if 'slca' in name: + return SLCA(args) + else: + assert 0 diff --git a/utils/inc_net.py b/utils/inc_net.py new file mode 100644 index 0000000..5350e8b --- /dev/null +++ b/utils/inc_net.py @@ -0,0 +1,587 @@ +import copy +import torch +from torch import nn +from convs.cifar_resnet import resnet32 +from convs.resnet import resnet18, resnet34, resnet50 +from convs.linears import SimpleLinear, SplitCosineLinear, CosineLinear, MyContinualClassifier, SimpleContinualLinear +from convs.vits import vit_base_patch16_224_in21k, vit_base_patch16_224_mocov3 +import torch.nn.functional as F + +def get_convnet(convnet_type, pretrained=False): + name = convnet_type.lower() + if name == 'resnet32': + return resnet32() + elif name == 'resnet18': + return resnet18(pretrained=pretrained) + elif name == 'resnet18_cifar': + return resnet18(pretrained=pretrained, cifar=True) + elif name == 'resnet18_cifar_cos': + return resnet18(pretrained=pretrained, cifar=True, no_last_relu=True) + elif name == 'resnet34': + return resnet34(pretrained=pretrained) + elif name == 'resnet50': + return resnet50(pretrained=pretrained) + elif name == 'vit-b-p16': + return vit_base_patch16_224_in21k(pretrained=pretrained) + elif name == 'vit-b-p16-mocov3': + return vit_base_patch16_224_mocov3(pretrained=True) + else: + raise NotImplementedError('Unknown type {}'.format(convnet_type)) + + +class BaseNet(nn.Module): + + def __init__(self, convnet_type, pretrained): + super(BaseNet, self).__init__() + + self.convnet = get_convnet(convnet_type, pretrained) + self.fc = None + + @property + def feature_dim(self): + return self.convnet.out_dim + + def extract_vector(self, x): + return self.convnet(x)['features'] + + def forward(self, x): + x = self.convnet(x) + out = self.fc(x['features']) + ''' + { + 'fmaps': [x_1, x_2, ..., x_n], + 'features': features + 'logits': logits + } + ''' + out.update(x) + + return out + + def update_fc(self, nb_classes): + pass + + def generate_fc(self, in_dim, out_dim): + pass + + def copy(self): + return copy.deepcopy(self) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + self.eval() + + return self + + +class IncrementalNet(BaseNet): + + def __init__(self, convnet_type, pretrained, gradcam=False, use_aux=False): + super().__init__(convnet_type, pretrained) + self.gradcam = gradcam + if hasattr(self, 'gradcam') and self.gradcam: + self._gradcam_hooks = [None, None] + self.set_gradcam_hook() + self.use_aux = use_aux + self.aux_fc = None + + def update_fc(self, nb_classes, new_task_size=-1): + if self.use_aux: + assert new_task_size!=-1 + old_tasks = 
(nb_classes-new_task_size)//new_task_size # WARNING: does not handle the case of different task sizes
+            #aux_fc=self.generate_fc(self.feature_dim, new_task_size+1, bias=True)
+            aux_fc=self.generate_fc(self.feature_dim, new_task_size+old_tasks, bias=True)
+            if self.aux_fc is not None:
+                old_aux_w = copy.deepcopy(self.aux_fc.weight.data)
+                aux_fc.weight.data[:old_tasks-1] = old_aux_w[:old_tasks-1]
+                compress_old_w = old_aux_w[old_tasks-1:].mean(0)
+                aux_fc.weight.data[old_tasks-1] = compress_old_w
+                old_bias = copy.deepcopy(self.aux_fc.bias.data)
+                aux_fc.bias.data[:old_tasks-1] = old_bias[:old_tasks-1]
+                compress_old_b = old_bias[old_tasks-1:].mean(0)
+                aux_fc.bias.data[old_tasks-1] = compress_old_b
+            del self.aux_fc
+            self.aux_fc = aux_fc
+
+        nb_old_classes = nb_classes-new_task_size
+        fc = MyContinualClassifier(self.feature_dim, nb_old_classes, new_task_size)
+        if self.fc is not None:
+            weights = [copy.deepcopy(self.fc.heads[hi].weight.data) for hi in range(len(self.fc.heads))]
+            weights = torch.cat(weights, dim=0)
+            fc.heads[0].weight.data = weights
+            fc.old_head.weight.data = weights
+            bias = [copy.deepcopy(self.fc.heads[hi].bias.data) for hi in range(len(self.fc.heads))]
+            bias = torch.cat(bias, dim=0)
+            fc.heads[0].bias.data = bias
+            fc.old_head.bias.data = bias
+            fc.old_head.weight.requires_grad=False
+            fc.old_head.bias.requires_grad=False
+        #fc = self.generate_fc(self.feature_dim, nb_classes)
+        # if self.fc is not None:
+        #     nb_output = self.fc.out_features
+        #     weight = copy.deepcopy(self.fc.weight.data)
+        #     fc.weight.data[:nb_output] = weight
+        #     if self.fc.bias is not None:
+        #         bias = copy.deepcopy(self.fc.bias.data)
+        #         fc.bias.data[:nb_output] = bias
+        del self.fc
+        self.fc = fc

+    def weight_align(self, increment, align_avg=False):
+        #weights=self.fc.weight.data.detach()
+        #newnorm=(torch.norm(weights[-increment:,:],p=2,dim=1))
+        #oldnorm=(torch.norm(weights[:-increment,:],p=2,dim=1))
+        newnorm=(torch.norm(self.fc.heads[1].weight,p=2,dim=1))
+        oldnorm=(torch.norm(self.fc.heads[0].weight,p=2,dim=1))
+        #oldnorm=(torch.norm(self.fc.old_head.weight,p=2,dim=1))
+        meannew=torch.mean(newnorm)
+        meanold=torch.mean(oldnorm)
+        if align_avg:
+            avgnorm = (meannew+meanold)/2
+            #gamma1 = avgnorm/meanold
+            gamma2 = avgnorm/meannew
+            #self.fc.weight.data[:-increment,:]*=gamma1
+            self.fc.weight.data[-increment:,:]*=gamma2
+            #return [gamma1, gamma2]
+            return gamma2
+        gamma=meanold/meannew
+        #gamma = 0.9
+        self.fc.heads[1].weight.data*=gamma
+        #self.fc.heads[0].weight.data = self.fc.old_head.weight.data
+        #self.fc.heads[0].bias.data = self.fc.old_head.bias.data
+        return gamma
+
+    def generate_fc(self, in_dim, out_dim, bias=True):
+        fc = SimpleLinear(in_dim, out_dim, bias=bias)
+
+        return fc
+
+    def forward_head(self, x):
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+        return x
+
+    def forward(self, x, backbone_only=False, with_aux=False, fc_only=False):
+        if fc_only:
+            return self.forward_head(x)
+        x = self.convnet(x)
+        if backbone_only:
+            return x['fmaps'][-1]
+        out = self.fc(x['features'], with_aux=with_aux)
+        out.update(x)
+        if hasattr(self, 'gradcam') and self.gradcam:
+            out['gradcam_gradients'] = self._gradcam_gradients
+            out['gradcam_activations'] = self._gradcam_activations
+        #if self.use_aux:
+        #    out_aux = self.aux_fc(x['features'])["logits"]
+        #    out.update({"aux_logits": out_aux})
+        return out
+
+    def unset_gradcam_hook(self):
+        self._gradcam_hooks[0].remove()
+        self._gradcam_hooks[1].remove()
+        self._gradcam_hooks[0] = None
+        self._gradcam_hooks[1] 
= None + self._gradcam_gradients, self._gradcam_activations = [None], [None] + + def set_gradcam_hook(self): + self._gradcam_gradients, self._gradcam_activations = [None], [None] + + def backward_hook(module, grad_input, grad_output): + self._gradcam_gradients[0] = grad_output[0] + return None + + def forward_hook(module, input, output): + self._gradcam_activations[0] = output + return None + + self._gradcam_hooks[0] = self.convnet.last_conv.register_backward_hook(backward_hook) + self._gradcam_hooks[1] = self.convnet.last_conv.register_forward_hook(forward_hook) + +class DerppIncrementalNet(BaseNet): + + def __init__(self, convnet_type, pretrained): + super().__init__(convnet_type, pretrained) + + def update_fc(self, nb_classes): + #fc = self.generate_fc(self.feature_dim, nb_classes) + #if self.fc is not None: + # nb_output = self.fc.out_features + # weight = copy.deepcopy(self.fc.weight.data) + # bias = copy.deepcopy(self.fc.bias.data) + # fc.weight.data[:nb_output] = weight + # fc.bias.data[:nb_output] = bias + + #del self.fc + #self.fc = fc + if self.fc is None: + self.fc = self.generate_fc(self.feature_dim, nb_classes) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + + return fc + + def forward(self, x): + x = self.convnet(x) + out = self.fc(x['features']) + out.update(x) + + return out + + +class CosineIncrementalNet(BaseNet): + + def __init__(self, convnet_type, pretrained, nb_proxy=1): + super().__init__(convnet_type, pretrained) + self.nb_proxy = nb_proxy + + def update_fc(self, nb_classes, task_num): + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + if task_num == 1: + #fc.fc1.weight.data = self.fc.weight.data + fc.weight.data[:self.fc.weight.data.size(0)] = self.fc.weight.data + fc.sigma.data[:self.fc.weight.data.size(0)] = self.fc.sigma.data + else: + prev_out_features1 = self.fc.fc1.out_features + fc.fc1.weight.data[:prev_out_features1] = self.fc.fc1.weight.data + fc.fc1.weight.data[prev_out_features1:] = self.fc.fc2.weight.data + fc.sigma.data = self.fc.sigma.data + + del self.fc + self.fc = fc + + def generate_fc(self, in_dim, out_dim): + #if self.fc is None: + # fc = CosineLinear(in_dim, out_dim, self.nb_proxy, to_reduce=True) + #else: + # prev_out_features = self.fc.out_features // self.nb_proxy + # # prev_out_features = self.fc.out_features + # fc = SplitCosineLinear(in_dim, prev_out_features, out_dim - prev_out_features, self.nb_proxy) + fc = CosineLinear(in_dim, out_dim, self.nb_proxy, to_reduce=True) + return fc + +class GdumbIncrementalNet(BaseNet): + def __init__(self, convnet_type, pretrained, new_norm=False): + super().__init__(convnet_type, pretrained) + self.init_convnet = copy.deepcopy(self.convnet) + if new_norm: + self.fc_norm = nn.LayerNorm(768) + else: + self.fc_norm = nn.Identity() + + def forward(self, x, bcb_no_grad=False): + if bcb_no_grad: + with torch.no_grad(): + x = self.convnet(x) + else: + x = self.convnet(x) + out = self.fc(self.fc_norm(x['features'])) + out.update(x) + + return out + + def update_fc(self, nb_classes): + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + del self.fc + self.fc = fc + self.convnet = copy.deepcopy(self.init_convnet) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim, init_method='normal') + + return fc + + + +class BiasLayer(nn.Module): + def __init__(self): + super(BiasLayer, self).__init__() + self.alpha = nn.Parameter(torch.ones(1, requires_grad=True)) + self.beta = 
nn.Parameter(torch.zeros(1, requires_grad=True)) + + def forward(self, x, low_range, high_range): + ret_x = x.clone() + ret_x[:, low_range:high_range] = self.alpha * x[:, low_range:high_range] + self.beta + return ret_x + + def get_params(self): + return (self.alpha.item(), self.beta.item()) + + +class IncrementalNetWithBias(BaseNet): + def __init__(self, convnet_type, pretrained, bias_correction=False, feat_expand=False, fc_with_ln=False): + super().__init__(convnet_type, pretrained) + + # Bias layer + self.bias_correction = bias_correction + self.feat_expand = feat_expand + self.bias_layers = nn.ModuleList([]) + self.task_sizes = [] + if feat_expand: + self.expand_base_block = copy.deepcopy(self.convnet.blocks[-1]) + self.fc_with_ln = fc_with_ln + + def forward(self, x, bcb_no_grad=False, fc_only=False): + if fc_only: + fc_out = self.fc(x) + if self.bias_correction: + logits = fc_out['logits'] + for i, layer in enumerate(self.bias_layers): + logits = layer(logits, sum(self.task_sizes[:i]), sum(self.task_sizes[:i+1])) + fc_out['logits'] = logits + return fc_out + if bcb_no_grad: + with torch.no_grad(): + x = self.convnet(x) + else: + x = self.convnet(x) + if self.feat_expand and not isinstance(x['features'], list): + x['features'] = [x['features']] + out = self.fc(x['features']) + if self.bias_correction: + logits = out['logits'] + if bcb_no_grad: + logits = logits.detach() + for i, layer in enumerate(self.bias_layers): + logits = layer(logits, sum(self.task_sizes[:i]), sum(self.task_sizes[:i+1])) + out['logits'] = logits + + out.update(x) + + return out + + def update_fc(self, nb_classes, freeze_old=False): + #fc = self.generate_fc(self.feature_dim, nb_classes) + #if self.fc is not None: + # nb_output = self.fc.out_features + # weight = copy.deepcopy(self.fc.weight.data) + # bias = copy.deepcopy(self.fc.bias.data) + # fc.weight.data[:nb_output] = weight + # fc.bias.data[:nb_output] = bias + + #del self.fc + #self.fc = fc + + if self.fc is None: + self.fc = self.generate_fc(self.feature_dim, nb_classes) + else: + self.fc.update(nb_classes, freeze_old=freeze_old) + if self.feat_expand: + for p in self.convnet.parameters(): + p.requires_grad=False + self.convnet.extra_blocks.append(nn.Sequential(copy.deepcopy(self.expand_base_block), nn.LayerNorm(self.feature_dim))) + + #new_task_size = nb_classes - sum(self.task_sizes) + new_task_size = nb_classes + self.task_sizes.append(new_task_size) + self.bias_layers.append(BiasLayer()) + + def generate_fc(self, in_dim, out_dim): + #fc = SimpleLinear(in_dim, out_dim, init_method='normal') + fc = SimpleContinualLinear(in_dim, out_dim, feat_expand=self.feat_expand, with_norm=self.fc_with_ln) + + return fc + + def extract_vector(self, x): + features = self.convnet(x)['features'] + if isinstance(features, list): + features = torch.stack(features, 0).mean(0) + + return features + + def extract_layerwise_vector(self, x): + with torch.no_grad(): + features = self.convnet(x, layer_feat=True)['features'] + for f_i in range(len(features)): + features[f_i] = features[f_i].mean(1).cpu().numpy() + return features + + def get_bias_params(self): + params = [] + for layer in self.bias_layers: + params.append(layer.get_params()) + + return params + + def unfreeze(self): + for param in self.parameters(): + param.requires_grad = True + + +class DERNet(nn.Module): + def __init__(self, convnet_type, pretrained): + super(DERNet,self).__init__() + self.convnet_type=convnet_type + self.convnets = nn.ModuleList() + self.pretrained=pretrained + self.out_dim=None + self.fc = 
None + self.aux_fc=None + self.task_sizes = [] + + @property + def feature_dim(self): + if self.out_dim is None: + return 0 + return self.out_dim*len(self.convnets) + + def extract_vector(self, x): + features = [convnet(x)['features'] for convnet in self.convnets] + features = torch.cat(features, 1) + return features + def forward(self, x): + features = [convnet(x)['features'] for convnet in self.convnets] + features = torch.cat(features, 1) + + out=self.fc(features) #{logics: self.fc(features)} + + aux_logits=self.aux_fc(features[:,-self.out_dim:])["logits"] + + out.update({"aux_logits":aux_logits,"features":features}) + return out + ''' + { + 'features': features + 'logits': logits + 'aux_logits':aux_logits + } + ''' + + def update_fc(self, nb_classes): + if len(self.convnets)==0: + self.convnets.append(get_convnet(self.convnet_type)) + else: + self.convnets.append(get_convnet(self.convnet_type)) + self.convnets[-1].load_state_dict(self.convnets[-2].state_dict()) + + if self.out_dim is None: + self.out_dim=self.convnets[-1].out_dim + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + bias = copy.deepcopy(self.fc.bias.data) + fc.weight.data[:nb_output,:self.feature_dim-self.out_dim] = weight + fc.bias.data[:nb_output] = bias + + del self.fc + self.fc = fc + + new_task_size = nb_classes - sum(self.task_sizes) + self.task_sizes.append(new_task_size) + + self.aux_fc=self.generate_fc(self.out_dim,new_task_size+1) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + + return fc + + def copy(self): + return copy.deepcopy(self) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + self.eval() + + return self + def freeze_conv(self): + for param in self.convnets.parameters(): + param.requires_grad = False + self.convnets.eval() + def weight_align(self, increment): + weights=self.fc.weight.data + newnorm=(torch.norm(weights[-increment:,:],p=2,dim=1)) + oldnorm=(torch.norm(weights[:-increment,:],p=2,dim=1)) + meannew=torch.mean(newnorm) + meanold=torch.mean(oldnorm) + gamma=meanold/meannew + print('alignweights,gamma=',gamma) + self.fc.weight.data[-increment:,:]*=gamma + +class SimpleCosineIncrementalNet(BaseNet): + + def __init__(self, convnet_type, pretrained): + super().__init__(convnet_type, pretrained) + + def update_fc(self, nb_classes, nextperiod_initialization): + fc = self.generate_fc(self.feature_dim, nb_classes).cuda() + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + fc.sigma.data=self.fc.sigma.data + if nextperiod_initialization is not None: + + weight=torch.cat([weight,nextperiod_initialization]) + fc.weight=nn.Parameter(weight) + del self.fc + self.fc = fc + + + def generate_fc(self, in_dim, out_dim): + fc = CosineLinear(in_dim, out_dim) + return fc + +class FinetuneIncrementalNet(BaseNet): + + def __init__(self, convnet_type, pretrained, fc_with_ln=False, fc_with_mlp=False, with_task_embed=False, fc_with_preproj=False): + super().__init__(convnet_type, pretrained) + self.old_fc = None + self.fc_with_ln = fc_with_ln + self.fc_with_mlp = fc_with_mlp + self.with_task_embed = with_task_embed + self.fc_with_preproj = fc_with_preproj + + + def extract_layerwise_vector(self, x, pool=True): + with torch.no_grad(): + features = self.convnet(x, layer_feat=True)['features'] + for f_i in range(len(features)): + if pool: + features[f_i] = 
features[f_i].mean(1).cpu().numpy()
+                else:
+                    features[f_i] = features[f_i][:, 0].cpu().numpy()
+        return features
+
+
+    def update_fc(self, nb_classes, freeze_old=True):
+        if self.fc is None:
+            self.fc = self.generate_fc(self.feature_dim, nb_classes)
+        else:
+            self.fc.update(nb_classes, freeze_old=freeze_old)
+
+    def save_old_fc(self):
+        if self.old_fc is None:
+            self.old_fc = copy.deepcopy(self.fc)
+        else:
+            self.old_fc.heads.append(copy.deepcopy(self.fc.heads[-1]))
+
+    def generate_fc(self, in_dim, out_dim):
+        fc = SimpleContinualLinear(in_dim, out_dim, with_norm=self.fc_with_ln, with_mlp=self.fc_with_mlp, with_task_embed=self.with_task_embed, with_preproj=self.fc_with_preproj)
+
+        return fc
+
+    def forward(self, x, bcb_no_grad=False, fc_only=False):
+        if fc_only:
+            fc_out = self.fc(x)
+            if self.old_fc is not None:
+                old_fc_logits = self.old_fc(x)['logits']
+                fc_out['old_logits'] = old_fc_logits
+            return fc_out
+        if bcb_no_grad:
+            with torch.no_grad():
+                x = self.convnet(x)
+        else:
+            x = self.convnet(x)
+        out = self.fc(x['features'])
+        out.update(x)
+
+        return out
+
+
diff --git a/utils/net_linear_wapper.py b/utils/net_linear_wapper.py
new file mode 100644
index 0000000..9331ad6
--- /dev/null
+++ b/utils/net_linear_wapper.py
@@ -0,0 +1,20 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class LinearWapper(nn.Module):
+    # A minimal linear head that returns its logits in a dict, mirroring the
+    # other classifier wrappers. The (in_features, out_features) signature is
+    # an assumption: the original file never defined the layer's parameters.
+    def __init__(self, in_features, out_features):
+        super(LinearWapper, self).__init__()
+        self.weight = nn.Parameter(torch.empty(out_features, in_features))
+        self.bias = nn.Parameter(torch.zeros(out_features))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
+        nn.init.constant_(self.bias, 0)
+
+    def forward(self, input):
+        return {'logits': F.linear(input, self.weight, self.bias)}
diff --git a/utils/toolkit.py b/utils/toolkit.py
new file mode 100644
index 0000000..e789945
--- /dev/null
+++ b/utils/toolkit.py
@@ -0,0 +1,58 @@
+import os
+import numpy as np
+import torch
+
+
+def count_parameters(model, trainable=False):
+    if trainable:
+        return sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return sum(p.numel() for p in model.parameters())
+
+
+def tensor2numpy(x):
+    return x.cpu().data.numpy() if x.is_cuda else x.data.numpy()
+
+
+def target2onehot(targets, n_classes):
+    onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
+    onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.)
+    return onehot
+
+
+def makedirs(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def accuracy(y_pred, y_true, nb_old, increment=10):
+    assert len(y_pred) == len(y_true), 'Data length error.'
+    all_acc = {}
+    all_acc['total'] = np.around((y_pred == y_true).sum()*100 / len(y_true), decimals=2)
+
+    # Grouped accuracy
+    for class_id in range(0, np.max(y_true), increment):
+        idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0]
+        label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id+increment-1).rjust(2, '0'))
+        all_acc[label] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2)
+
+    # Old accuracy
+    idxes = np.where(y_true < nb_old)[0]
+    all_acc['old'] = 0 if len(idxes) == 0 else np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes),
+                                                         decimals=2)
+
+    # New accuracy
+    idxes = np.where(y_true >= nb_old)[0]
+    all_acc['new'] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2)
+
+    return all_acc
+
+
+def split_images_labels(imgs):
+    # split trainset.imgs in ImageFolder
+    images = []
+    labels = []
+    for item in imgs:
+        images.append(item[0])
+        labels.append(item[1])
+
+    return np.array(images), np.array(labels)
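
Note on the classifier-alignment (CA) stage in models/slca.py above: after each
task, one Gaussian per seen class is fitted in feature space, and the linear
head is finetuned on features replayed from those Gaussians. Below is a minimal
standalone sketch of that loop; it assumes precomputed per-class means and
covariances and a plain nn.Linear head, and the function name and arguments are
illustrative rather than part of the repository.

    import torch
    import torch.nn.functional as F
    from torch.distributions import MultivariateNormal

    def align_classifier(head, class_means, class_covs, epochs=5,
                         n_per_class=256, lr=0.01):
        # Finetune a linear head on features replayed from per-class Gaussians.
        # Each class_covs[c] must be positive-definite (add a small ridge if not).
        opt = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
        for _ in range(epochs):
            feats, labels = [], []
            for c in range(len(class_means)):
                dist = MultivariateNormal(class_means[c], class_covs[c])
                feats.append(dist.sample((n_per_class,)))
                labels.append(torch.full((n_per_class,), c, dtype=torch.long))
            feats, labels = torch.cat(feats), torch.cat(labels)
            perm = torch.randperm(len(feats))  # shuffle classes together
            loss = F.cross_entropy(head(feats[perm]), labels[perm])
            opt.zero_grad()
            loss.backward()
            opt.step()
        return head

The per-task logit normalization used in the patch (dividing the logits by the
mean of the per-task L2 norms before the cross-entropy) can be dropped into the
loss computation above in the same way slca.py does it.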