Commit f72fdad

delete unused code
xiaomile committed Dec 13, 2023
1 parent b374173 commit f72fdad
Showing 3 changed files with 28 additions and 35 deletions.
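In short: the unused get_weightings helper is deleted, and the module-level device() helper is replaced by a self.device attribute resolved once in ConsistencyModel.__init__. A minimal sketch of the resulting pattern, with illustrative names that are not part of mmagic:

import torch
import torch.nn as nn


class Sketch(nn.Module):
    """Cache the device once instead of querying a helper at each call."""

    def __init__(self):
        super().__init__()
        # Resolve the device a single time at construction.
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')

    def make_noise(self, batch_size, image_size):
        # Tensors are created directly on the cached device.
        return torch.randn(
            batch_size, 3, image_size, image_size, device=self.device)

One trade-off of this pattern: the device is fixed at construction time, so a later model.to('cuda') will not update self.device.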
19 changes: 11 additions & 8 deletions mmagic/models/editors/consistency_models/consistencymodel.py
@@ -13,7 +13,7 @@
 from mmagic.registry import MODELS
 from mmagic.structures import DataSample
 from mmagic.utils import ForwardInputs
-from .consistencymodel_utils import (device, get_generator, get_sample_fn,
+from .consistencymodel_utils import (get_generator, get_sample_fn,
                                      get_sigmas_karras, karras_sample)

 ModelType = Union[Dict, nn.Module]
@@ -56,6 +56,9 @@ def __init__(self,
         super().__init__(data_preprocessor=data_preprocessor)

         self.num_classes = num_classes
+        self.device = torch.device('cpu')
+        if torch.cuda.is_available():
+            self.device = torch.device('cuda')
         if 'consistency' in training_mode:
             self.distillation = True
         else:
@@ -117,7 +120,7 @@ def __init__(self,
             self.model.load_state_dict(
                 torch.load(model_path, map_location='cpu'))

-        self.model.to(device())
+        self.model.to(self.device)

         if sampler == 'multistep':
             assert len(ts) > 0
@@ -147,7 +150,7 @@ def infer(self, class_id: Optional[int] = None):
             (self.batch_size, 3, self.image_size, self.image_size),
             steps=self.steps,
             model_kwargs=self.model_kwargs,
-            device=device(),
+            device=self.device,
             clip_denoised=self.clip_denoised,
             sampler=self.sampler,
             sigma_min=self.sigma_min,
@@ -216,18 +219,18 @@ def forward(self,
                 self.sigma_min,
                 self.sigma_max,
                 self.diffusion.rho,
-                device=device())
+                device=self.device)
         else:
             sigmas = get_sigmas_karras(
                 self.steps,
                 self.sigma_min,
                 self.sigma_max,
                 self.diffusion.rho,
-                device=device())
+                device=self.device)

         noise = self.generator.randn(
             *(self.batch_size, 3, self.image_size, self.image_size),
-            device=device()) * self.sigma_max
+            device=self.device) * self.sigma_max

         sample_fn = get_sample_fn(self.sampler)

@@ -291,11 +294,11 @@ def label_fn(self, class_id):
                 'it should be within the range (0,num_classes).'
             classes = torch.tensor(
                 [int(class_id) for i in range(self.batch_size)],
-                device=device())
+                device=self.device)
         else:
             classes = torch.randint(
                 low=0,
                 high=self.num_classes,
                 size=(self.batch_size, ),
-                device=device())
+                device=self.device)
         return classes
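As the hunk above shows, label_fn either repeats a requested class id across the batch or samples random ids. A standalone equivalent of that logic, with the function signature being an illustrative assumption rather than mmagic API:

import torch

def label_fn(class_id, batch_size, num_classes, device):
    # Fixed id repeated over the batch if given, random ids otherwise.
    if class_id is not None:
        return torch.tensor([int(class_id)] * batch_size, device=device)
    return torch.randint(0, num_classes, (batch_size, ), device=device)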
29 changes: 3 additions & 26 deletions mmagic/models/editors/consistency_models/consistencymodel_utils.py
@@ -7,30 +7,6 @@
 import torch.nn as nn


-def device():
-    """return torch.device."""
-    if torch.cuda.is_available():
-        return torch.device('cuda')
-    return torch.device('cpu')
-
-
-def get_weightings(weight_schedule, snrs, sigma_data):
-    """return weightings."""
-    if weight_schedule == 'snr':
-        weightings = snrs
-    elif weight_schedule == 'snr+1':
-        weightings = snrs + 1
-    elif weight_schedule == 'karras':
-        weightings = snrs + 1.0 / sigma_data**2
-    elif weight_schedule == 'truncated-snr':
-        weightings = torch.clamp(snrs, min=1.0)
-    elif weight_schedule == 'uniform':
-        weightings = torch.ones_like(snrs)
-    else:
-        raise NotImplementedError()
-    return weightings
-
-
 class SiLU(nn.Module):
     """PyTorch 1.7 has SiLU, but we support PyTorch 1.5."""

@@ -654,7 +630,7 @@ def __init__(self, num_samples, seed=0):
         self.seed = seed
         self.rng_cpu = torch.Generator()
         if torch.cuda.is_available():
-            self.rng_cuda = torch.Generator(device())
+            self.rng_cuda = torch.Generator(torch.device('cuda'))
         self.set_seed(seed)

     def get_global_size_and_indices(self, size):
@@ -737,7 +713,8 @@ def __init__(self, num_samples, seed=0):
         self.rng_cpu = [torch.Generator() for _ in range(num_samples)]
         if torch.cuda.is_available():
             self.rng_cuda = [
-                torch.Generator(device()) for _ in range(num_samples)
+                torch.Generator(torch.device('cuda'))
+                for _ in range(num_samples)
             ]
         self.set_seed(seed)

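The generator classes above now construct per-device torch.Generator objects inline. For context, a short standalone example of how a seeded torch.Generator makes noise draws reproducible (plain PyTorch, independent of the mmagic API):

import torch

rng = torch.Generator()
rng.manual_seed(42)
a = torch.randn(2, 3, generator=rng)

rng.manual_seed(42)  # reset to the same seed
b = torch.randn(2, 3, generator=rng)

assert torch.equal(a, b)  # identical draws

if torch.cuda.is_available():
    # CUDA draws need a generator created for the CUDA device,
    # matching the torch.Generator(torch.device('cuda')) calls above.
    rng_cuda = torch.Generator(torch.device('cuda'))
    rng_cuda.manual_seed(42)
    c = torch.randn(2, 3, generator=rng_cuda, device='cuda')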
15 changes: 14 additions & 1 deletion (test file; path not captured in this view)
@@ -82,17 +82,20 @@
 @pytest.mark.skipif(
     'win' in platform.system().lower(),
     reason='skip on windows due to limited RAM.')
-class TestDeblurGanV2(TestCase):
+class TestConsistencyModels(TestCase):

     def test_init(self):
         model = ConsistencyModel(
             unet=unet_config,
             denoiser=denoiser_config,
             data_preprocessor=DataPreprocessor())
+        if torch.cuda.is_available():
+            self.assertEqual(model.device, torch.device('cuda'))
         self.assertIsInstance(model, ConsistencyModel)
         self.assertIsInstance(model.data_preprocessor, DataPreprocessor)
         self.assertIsInstance(model.model, ConsistencyUNetModel)
         self.assertIsInstance(model.diffusion, KarrasDenoiser)
+
         unet_cfg = deepcopy(unet_config)
         diffuse_cfg = deepcopy(denoiser_config)
         unet = MODELS.build(unet_cfg)
@@ -114,6 +117,11 @@ def test_onestep_infer(self):
         for datasample in result:
             assert datasample.fake_img.shape == (3, model.image_size,
                                                  model.image_size)
+        result, labels = model.infer()
+        assert len(result) == model.batch_size
+        assert len(labels) == model.batch_size
+        for datasample in result:
+            assert datasample.shape == (model.image_size, model.image_size, 3)

     def test_multistep_infer(self):
         model = MODELS.build(config_multistep)
@@ -127,6 +135,11 @@
         for datasample in result:
             assert datasample.fake_img.shape == (3, model.image_size,
                                                  model.image_size)
+        result, labels = model.infer()
+        assert len(result) == model.batch_size
+        assert len(labels) == model.batch_size
+        for datasample in result:
+            assert datasample.shape == (model.image_size, model.image_size, 3)


 def teardown_module():
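Since self.device is now a public attribute, the new assertion in test_init can also be checked in isolation. A smoke-test sketch; the helper name is hypothetical and `model` is assumed to be a ConsistencyModel built as in test_init above:

import torch

def check_cached_device(model):
    # `model` is a constructed ConsistencyModel instance (hypothetical helper).
    expected = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    assert model.device == expected
    # __init__ also moved the wrapped UNet, so its weights should agree.
    assert next(model.model.parameters()).device.type == expected.type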
