Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support KNN benchmark #243

Merged
merged 7 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions configs/benchmarks/classification/knn_imagenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
data_source = 'ImageNet'
dataset_type = 'SingleViewDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
pipeline = [
dict(type='Resize', size=256),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]

data = dict(
samples_per_gpu=256,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/train',
ann_file='data/imagenet/meta/train.txt',
),
pipeline=pipeline),
val=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
),
pipeline=pipeline))
8 changes: 5 additions & 3 deletions mmselfsup/models/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .accuracy import Accuracy, accuracy
from .extract_process import ExtractProcess
from .extract_process import ExtractProcess, MultiExtractProcess
from .gather_layer import GatherLayer
from .knn_classifier import knn_classifier
from .multi_pooling import MultiPooling
from .multi_prototypes import MultiPrototypes
from .position_embedding import build_2d_sincos_position_embedding
from .sobel import Sobel

__all__ = [
'Accuracy', 'accuracy', 'ExtractProcess', 'GatherLayer', 'MultiPooling',
'MultiPrototypes', 'Sobel', 'build_2d_sincos_position_embedding', 'Mixup'
'Accuracy', 'accuracy', 'ExtractProcess', 'MultiExtractProcess',
'GatherLayer', 'knn_classifier', 'MultiPooling', 'MultiPrototypes',
'build_2d_sincos_position_embedding', 'Sobel'
]
37 changes: 35 additions & 2 deletions mmselfsup/models/utils/extract_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import get_dist_info

from mmselfsup.utils.collect import (dist_forward_collect,
Expand All @@ -7,8 +8,40 @@


class ExtractProcess(object):
"""Extraction process for `extract.py` and `tsne_visualization.py` in
tools.
"""Global average-pooled feature extraction process."""

def __init__(self):
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

def _forward_func(self, model, **x):
"""The forward function of extract process."""
backbone_feat = model(mode='extract', **x)
pooling_feat = self.avg_pool(backbone_feat[-1])
flat_feat = pooling_feat.view(pooling_feat.size(0), -1)
return dict(feat=flat_feat.cpu())

def extract(self, model, data_loader, distributed=False):
"""The extract function to apply forward function and choose
distributed or not."""
model.eval()

# the function sent to collect function
def func(**x):
return self._forward_func(model, **x)

if distributed:
rank, world_size = get_dist_info()
results = dist_forward_collect(func, data_loader, rank,
len(data_loader.dataset))
else:
results = nondist_forward_collect(func, data_loader,
len(data_loader.dataset))
return results


class MultiExtractProcess(object):
"""Multi-stage intermediate feature extraction process for `extract.py` and
`tsne_visualization.py` in tools.

Args:
pool_type (str): Pooling type in :class:`MultiPooling`. Options are
Expand Down
58 changes: 58 additions & 0 deletions mmselfsup/models/utils/knn_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This file is borrowed from
# https://github.com/facebookresearch/dino/blob/main/eval_knn.py

import torch
import torch.nn as nn


@torch.no_grad()
def knn_classifier(train_features,
train_labels,
test_features,
test_labels,
k,
T,
num_classes=1000):
top1, top5, total = 0.0, 0.0, 0
train_features = nn.functional.normalize(train_features, dim=1)
test_features = nn.functional.normalize(test_features, dim=1)
train_features = train_features.t()
num_test_images, num_chunks = test_labels.shape[0], 100
imgs_per_chunk = num_test_images // num_chunks
retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
for idx in range(0, num_test_images, imgs_per_chunk):
# get the features for test images
features = test_features[idx:min((idx +
imgs_per_chunk), num_test_images), :]
targets = test_labels[idx:min((idx + imgs_per_chunk), num_test_images)]
batch_size = targets.shape[0]

# calculate the dot product and compute top-k neighbors
similarity = torch.mm(features, train_features)
distances, indices = similarity.topk(k, largest=True, sorted=True)
candidates = train_labels.view(1, -1).expand(batch_size, -1)
retrieved_neighbors = torch.gather(candidates, 1, indices)

retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
distances_transform = distances.clone().div_(T).exp_()
probs = torch.sum(
torch.mul(
retrieval_one_hot.view(batch_size, -1, num_classes),
distances_transform.view(batch_size, -1, 1),
),
1,
)
_, predictions = probs.sort(1, True)

# find the predictions that match the target
correct = predictions.eq(targets.data.view(-1, 1))
top1 = top1 + correct.narrow(1, 0, 1).sum().item()
top5 = top5 + correct.narrow(1, 0, min(
5, k)).sum().item() # top5 does not make sense if k < 5
total += targets.size(0)
top1 = top1 * 100.0 / total
top5 = top5 * 100.0 / total
return top1, top5
17 changes: 17 additions & 0 deletions tests/test_models/test_utils/test_knn_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmselfsup.models.utils import knn_classifier


def test_knn_classifier():
train_feats = torch.ones(200, 3)
train_labels = torch.ones(200).long()
test_feats = torch.ones(200, 3)
test_labels = torch.ones(200).long()
num_knn = [10, 20, 100, 200]
for k in num_knn:
top1, top5 = knn_classifier(train_feats, train_labels, test_feats,
test_labels, k, 0.07)
assert top1 == 100.
assert top5 == 100.
20 changes: 17 additions & 3 deletions tests/test_runtime/test_extract_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mmcv.parallel import MMDataParallel
from torch.utils.data import DataLoader, Dataset

from mmselfsup.models.utils import ExtractProcess
from mmselfsup.models.utils import ExtractProcess, MultiExtractProcess


class ExampleDataset(Dataset):
Expand Down Expand Up @@ -39,8 +39,22 @@ def train_step(self, data_batch, optimizer):


def test_extract_process():
test_dataset = ExampleDataset()
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
data_loader = DataLoader(
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
model = MMDataParallel(ExampleModel())

process = ExtractProcess()

results = process.extract(model, data_loader)
assert 'feat' in results
assert results['feat'].shape == (1, 128 * 1 * 1)


def test_multi_extract_process():
with pytest.raises(AssertionError):
process = ExtractProcess(
process = MultiExtractProcess(
pool_type='specified', backbone='resnet50', layer_indices=(-1, ))

test_dataset = ExampleDataset()
Expand All @@ -49,7 +63,7 @@ def test_extract_process():
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
model = MMDataParallel(ExampleModel())

process = ExtractProcess(
process = MultiExtractProcess(
pool_type='specified', backbone='resnet50', layer_indices=(0, 1, 2))

results = process.extract(model, data_loader)
Expand Down
4 changes: 2 additions & 2 deletions tools/analysis_tools/visualize_tsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mmselfsup.apis import set_random_seed
from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.models import build_algorithm
from mmselfsup.models.utils import ExtractProcess
from mmselfsup.models.utils import MultiExtractProcess
from mmselfsup.utils import get_root_logger


Expand Down Expand Up @@ -208,7 +208,7 @@ def main():
broadcast_buffers=False)

# build extraction processor and run
extractor = ExtractProcess(
extractor = MultiExtractProcess(
pool_type=args.pool_type, backbone='resnet50', layer_indices=layer_ind)
features = extractor.extract(model, data_loader, distributed=distributed)
labels = dataset.data_source.get_gt_labels()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env bash

set -e
set -x

CFG=$1
EPOCH=$2
PY_ARGS=${@:3}
GPUS=${GPUS:-8}
PORT=${PORT:-29500}

WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/

if [ ! -f $WORK_DIR/epoch_${EPOCH}.pth ]; then
echo "ERROR: File not exist: $WORK_DIR/epoch_${EPOCH}.pth"
exit
fi

python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
--checkpoint $WORK_DIR/epoch_${EPOCH}.pth \
--work_dir $WORK_DIR --launcher="pytorch" ${PY_ARGS}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env bash

set -e
set -x

CFG=$1
PRETRAIN=$2 # pretrained model
PY_ARGS=${@:3}
GPUS=${GPUS:-8}
PORT=${PORT:-29500}

# set work_dir according to config path and pretrained model to distinguish different models
WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"

python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
--cfg-options model.backbone.init_cfg.type=Pretrained \
model.backbone.init_cfg.checkpoint=$PRETRAIN \
--work_dir $WORK_DIR --launcher="pytorch" ${PY_ARGS}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env bash

set -e
set -x

PARTITION=$1
JOB_NAME=$2
CFG=$3
EPOCH=$4
PY_ARGS=${@:5}
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PORT=${PORT:-29500}
SRUN_ARGS=${SRUN_ARGS:-""}

WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/

if [ ! -f $WORK_DIR/epoch_${EPOCH}.pth ]; then
echo "ERROR: File not exist: $WORK_DIR/epoch_${EPOCH}.pth"
exit
fi

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
--checkpoint $WORK_DIR/epoch_${EPOCH}.pth \
--cfg-options dist_params.port=$PORT \
--work_dir $WORK_DIR --launcher="slurm" ${PY_ARGS}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env bash

set -e
set -x

PARTITION=$1
JOB_NAME=$2
CFG=$3
PRETRAIN=$4 # pretrained model
PY_ARGS=${@:5}
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PORT=${PORT:-29500}
SRUN_ARGS=${SRUN_ARGS:-""}

# set work_dir according to config path and pretrained model to distinguish different models
WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
--cfg-options model.backbone.init_cfg.type=Pretrained \
model.backbone.init_cfg.checkpoint=$PRETRAIN \
dist_params.port=$PORT \
--work_dir $WORK_DIR --launcher="slurm" ${PY_ARGS}
Loading